def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path, eval_data_path):
    def choose_train(x): return True
    # NOTE
    # if FLAGS.train_genre is not None:
    #     def choose_train(x): return x.get('genre') == FLAGS.train_genre

    def choose_eval(x): return True
    # NOTE
    # if FLAGS.eval_genre is not None:
    #     def choose_eval(x): return x.get('genre') == FLAGS.eval_genre

    if not FLAGS.expanded_eval_only_mode:
        if FLAGS.data_type == "nli":
            # Load the data.
            raw_training_data = data_manager.load_data(
                training_data_path, FLAGS.lowercase, choose_train,
                mode=FLAGS.transition_mode,
                tree_joint=FLAGS.tree_joint,
                distance_type=FLAGS.distance_type)  # only for nli data now
        else:
            # Load the data.
            raw_training_data = data_manager.load_data(
                training_data_path, FLAGS.lowercase, mode=FLAGS.transition_mode)
    else:
        raw_training_data = None

    if FLAGS.data_type == "nli":
        # Load the eval data.
        raw_eval_sets = []
        for path in eval_data_path.split(':'):
            raw_eval_data = data_manager.load_data(
                path, FLAGS.lowercase, choose_eval,
                mode=FLAGS.transition_mode,
                tree_joint=FLAGS.tree_joint,
                distance_type=FLAGS.distance_type)  # only for nli data now
            raw_eval_sets.append((path, raw_eval_data))
            # print raw_eval_data[1].keys()
            # exit(1)
    else:
        # Load the eval data.
        raw_eval_sets = []
        for path in eval_data_path.split(':'):
            raw_eval_data = data_manager.load_data(
                path, FLAGS.lowercase, mode=FLAGS.transition_mode)
            raw_eval_sets.append((path, raw_eval_data))

    # Prepare the vocabulary.
    if not data_manager.FIXED_VOCABULARY:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        vocabulary = data_manager.FIXED_VOCABULARY
        logger.Log("In fixed vocabulary mode. Training embeddings.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        simple=sequential_only(),
        allow_cropping=FLAGS.allow_cropping,
        pad_from_left=pad_from_left(),
        tree_joint=FLAGS.tree_joint) if raw_training_data is not None else None
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA
        # , train_seed=FLAGS.train_seed
    ) if raw_training_data is not None else None

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_eval_cropping,
            pad_from_left=pad_from_left(),
            tree_joint=FLAGS.tree_joint)
        eval_it = util.MakeEvalIterator(
            eval_data,
            FLAGS.batch_size * FLAGS.sample_num,  # keep the eval running speed
            FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    return vocabulary, initial_embeddings, training_data_iter, eval_iterators
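
# --- Sketch (added for illustration; not from the original repo) -------------
# load_data_and_embeddings above accepts several eval files joined by ':' and
# builds one iterator per file. A minimal self-contained demonstration of the
# convention, with hypothetical file names:

def _split_eval_paths(eval_data_path):
    """Split a colon-separated list of eval files, as the loader above does."""
    return eval_data_path.split(':')

assert _split_eval_paths("snli_dev.jsonl:snli_test.jsonl") == [
    "snli_dev.jsonl", "snli_test.jsonl"]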
def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path, eval_data_path):
    def choose_train(x):
        return True
    if FLAGS.train_genre is not None:
        def choose_train(x):
            return x.get('genre') == FLAGS.train_genre

    def choose_eval(x):
        return True
    if FLAGS.eval_genre is not None:
        def choose_eval(x):
            return x.get('genre') == FLAGS.eval_genre

    if not FLAGS.expanded_eval_only_mode:
        raw_training_data = data_manager.load_data(
            training_data_path, FLAGS.lowercase, eval_mode=False)
    else:
        raw_training_data = None

    raw_eval_sets = []
    for path in eval_data_path.split(':'):
        raw_eval_data = data_manager.load_data(
            path, FLAGS.lowercase, choose_eval, eval_mode=True)
        raw_eval_sets.append((path, raw_eval_data))

    # Prepare the vocabulary.
    if not data_manager.FIXED_VOCABULARY:
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        vocabulary = data_manager.FIXED_VOCABULARY
        logger.Log("In fixed vocabulary mode. Training embeddings from scratch.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    if raw_training_data is not None:
        training_data = util.PreprocessDataset(
            raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
            eval_mode=False, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_cropping,
            pad_from_left=pad_from_left())
        training_data_iter = util.MakeTrainingIterator(
            training_data, FLAGS.batch_size, FLAGS.smart_batching,
            FLAGS.use_peano,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
        training_data_length = len(training_data[0])
    else:
        training_data_iter = None
        training_data_length = 0

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_eval_cropping,
            pad_from_left=pad_from_left())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    return (vocabulary, initial_embeddings, training_data_iter,
            eval_iterators, training_data_length)
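
# --- Sketch (added for illustration; hypothetical call site) -----------------
# This variant also returns the number of training examples, which a caller
# could use to convert between steps and epochs. The names below are
# assumptions, not from the original repo:
#
# (vocabulary, initial_embeddings, training_data_iter,
#  eval_iterators, training_data_length) = load_data_and_embeddings(
#      FLAGS, data_manager, logger,
#      FLAGS.training_data_path, FLAGS.eval_data_path)
# steps_per_epoch = training_data_length // FLAGS.batch_size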
def run(only_forward=False):
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
    training_data_iter = util.MakeTrainingIterator(training_data, FLAGS.batch_size)

    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set, vocabulary, FLAGS.seq_length, data_manager,
            eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
        eval_iterators.append(
            (filename, util.MakeEvalIterator(
                (e_X, e_transitions, e_y, e_num_transitions), FLAGS.batch_size)))

    # Set up the placeholders.
    y = T.vector("y", dtype="int32")
    lr = T.scalar("lr")
    training_mode = T.scalar("training_mode")  # 1: Training with dropout, 0: Eval
    ground_truth_transitions_visible = T.scalar(
        "ground_truth_transitions_visible", dtype="int32")

    logger.Log("Building model.")
    vs = util.VariableStore(
        default_initializer=util.UniformInitializer(FLAGS.init_range),
        logger=logger)

    if FLAGS.model_type == "CBOW":
        model_cls = spinn.cbow.CBOW
    elif FLAGS.model_type == "RNN":
        model_cls = spinn.plain_rnn.RNN
    else:
        model_cls = getattr(spinn.fat_stack, FLAGS.model_type)

    # Random stream used to generate masks for scheduled sampling.
    numpy_random = np.random.RandomState(1234)
    ss_mask_gen = T.shared_randomstreams.RandomStreams(numpy_random.randint(999999))

    # Scheduled-sampling probability, computed from the training step.
    ss_prob = T.scalar("ss_prob")

    if data_manager.SENTENCE_PAIR_DATA:
        X = T.itensor3("X")
        transitions = T.itensor3("transitions")
        num_transitions = T.imatrix("num_transitions")

        predicted_premise_transitions, predicted_hypothesis_transitions, logits = build_sentence_pair_model(
            model_cls, len(vocabulary), FLAGS.seq_length, X, transitions,
            len(data_manager.LABEL_MAP), training_mode,
            ground_truth_transitions_visible, vs,
            initial_embeddings=initial_embeddings,
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)
    else:
        X = T.matrix("X", dtype="int32")
        transitions = T.imatrix("transitions")
        num_transitions = T.vector("num_transitions", dtype="int32")

        predicted_transitions, logits = build_sentence_model(
            model_cls, len(vocabulary), FLAGS.seq_length, X, transitions,
            len(data_manager.LABEL_MAP), training_mode,
            ground_truth_transitions_visible, vs,
            initial_embeddings=initial_embeddings,
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)

    xent_cost, acc = build_cost(logits, y)

    # Set up L2 regularization.
    l2_cost = 0.0
    for var in vs.trainable_vars:
        l2_cost += FLAGS.l2_lambda * T.sum(T.sqr(vs.vars[var]))

    # Compute cross-entropy cost on action predictions.
    if (not data_manager.SENTENCE_PAIR_DATA) and FLAGS.model_type not in ["Model0", "RNN", "CBOW"]:
        transition_cost, action_acc = build_transition_cost(
            predicted_transitions, transitions, num_transitions)
    elif data_manager.SENTENCE_PAIR_DATA and FLAGS.model_type not in ["Model0", "RNN", "CBOW"]:
        p_transition_cost, p_action_acc = build_transition_cost(
            predicted_premise_transitions, transitions[:, :, 0],
            num_transitions[:, 0])
        h_transition_cost, h_action_acc = build_transition_cost(
            predicted_hypothesis_transitions, transitions[:, :, 1],
            num_transitions[:, 1])
        transition_cost = p_transition_cost + h_transition_cost
        # TODO(SB): Average over transitions, not words.
        action_acc = (p_action_acc + h_action_acc) / 2.0
    else:
        transition_cost = T.constant(0.0)
        action_acc = T.constant(0.0)
    transition_cost = transition_cost * FLAGS.transition_cost_scale

    total_cost = xent_cost + l2_cost + transition_cost

    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path, FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = vs.load_checkpoint(
            checkpoint_path, num_extra_vars=2,
            skip_saved_unsavables=FLAGS.skip_saved_unsavables)
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Do an evaluation-only run.
    if only_forward:
        if FLAGS.eval_output_paths:
            eval_output_paths = FLAGS.eval_output_paths.strip().split(":")
            assert len(eval_output_paths) == len(eval_iterators), \
                "Invalid number of output paths."
        else:
            eval_output_paths = [
                FLAGS.experiment_name + "-" + os.path.split(eval_set[0])[1] + "-parse"
                for eval_set in eval_iterators]

        # Load model from checkpoint.
        logger.Log("Checkpointed model was trained for %d steps." % (step,))

        # Generate function for forward pass.
        logger.Log("Building forward pass.")
        if data_manager.SENTENCE_PAIR_DATA:
            eval_fn = theano.function(
                [X, transitions, y, num_transitions, training_mode,
                 ground_truth_transitions_visible, ss_prob],
                [acc, action_acc, logits, predicted_hypothesis_transitions,
                 predicted_premise_transitions],
                on_unused_input='ignore',
                allow_input_downcast=True)
        else:
            eval_fn = theano.function(
                [X, transitions, y, num_transitions, training_mode,
                 ground_truth_transitions_visible, ss_prob],
                [acc, action_acc, logits, predicted_transitions],
                on_unused_input='ignore',
                allow_input_downcast=True)

        # Generate the inverse vocabulary lookup table.
        ind_to_word = {v: k for k, v in vocabulary.iteritems()}

        # Do a forward pass and write the output to disk.
        for eval_set, eval_out_path in zip(eval_iterators, eval_output_paths):
            logger.Log("Writing eval output for %s." % (eval_set[0],))
            evaluate_expanded(
                eval_fn, eval_set, eval_out_path, logger, step,
                data_manager.SENTENCE_PAIR_DATA, ind_to_word,
                FLAGS.model_type not in ["Model0", "RNN"])
    else:
        # Train
        new_values = util.RMSprop(total_cost, vs.trainable_vars.values(), lr)
        new_values += [(key, vs.nongradient_updates[key])
                       for key in vs.nongradient_updates]
        # Training open-vocabulary embeddings is a questionable idea right now. Disabled:
        # new_values.append(
        #     util.embedding_SGD(total_cost, embedding_params, embedding_lr))

        # Create training and eval functions.
        # Unused-variable warnings are suppressed so that num_transitions can
        # be passed in when training Model 0, which ignores it. This yields
        # more readable code that is very slightly slower.
        logger.Log("Building update function.")
        update_fn = theano.function(
            [X, transitions, y, num_transitions, lr, training_mode,
             ground_truth_transitions_visible, ss_prob],
            [total_cost, xent_cost, transition_cost, action_acc, l2_cost, acc],
            updates=new_values,
            on_unused_input='ignore',
            allow_input_downcast=True)

        logger.Log("Building eval function.")
        eval_fn = theano.function(
            [X, transitions, y, num_transitions, training_mode,
             ground_truth_transitions_visible, ss_prob],
            [acc, action_acc],
            on_unused_input='ignore',
            allow_input_downcast=True)

        logger.Log("Training.")

        # Main training loop.
        for step in range(step, FLAGS.training_steps):
            if step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(eval_fn, eval_set, logger, step)
                    if (FLAGS.ckpt_on_best_dev_error and index == 0
                            and (1 - acc) < 0.99 * best_dev_error
                            and step > 1000):
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        vs.save_checkpoint(checkpoint_path + "_best",
                                           extra_vars=[step, best_dev_error])

            X_batch, transitions_batch, y_batch, num_transitions_batch = \
                training_data_iter.next()
            learning_rate = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            ret = update_fn(
                X_batch, transitions_batch, y_batch, num_transitions_batch,
                learning_rate, 1.0, 1.0,
                np.exp(step * np.log(FLAGS.scheduled_sampling_exponent_base)))
            total_cost_val, xent_cost_val, transition_cost_val, action_acc_val, l2_cost_val, acc_val = ret

            if step % FLAGS.statistics_interval_steps == 0:
                logger.Log("Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %5f"
                           % (step, acc_val, action_acc_val, total_cost_val,
                              xent_cost_val, transition_cost_val, l2_cost_val))

            if step % FLAGS.ckpt_interval_steps == 0 and step > 0:
                vs.save_checkpoint(checkpoint_path, extra_vars=[step, best_dev_error])
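
# --- Sketch (added for illustration) -----------------------------------------
# The last argument passed to update_fn above is the scheduled-sampling
# probability: exp(step * log(base)) is just base ** step, an exponential
# decay toward 0 for base < 1. The base value below is an assumed example,
# not the flag's actual default. A self-contained check:

import numpy as np

def _scheduled_sampling_prob(step, base):
    # Written exactly as in the training loop above.
    return np.exp(step * np.log(base))

# With base = 0.99999, the probability is 1.0 at step 0 and has decayed to
# roughly exp(-1) ~= 0.37 by step 100000.
assert abs(_scheduled_sampling_prob(0, 0.99999) - 1.0) < 1e-12
assert abs(_scheduled_sampling_prob(100000, 0.99999) - 0.99999 ** 100000) < 1e-9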
def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path, eval_data_path):
    choose_train = lambda x: True
    if FLAGS.train_genre is not None:
        choose_train = lambda x: x.get('genre') == FLAGS.train_genre

    choose_eval = lambda x: True
    if FLAGS.eval_genre is not None:
        choose_eval = lambda x: x.get('genre') == FLAGS.eval_genre

    if FLAGS.data_type == "snli":
        # Load the data.
        raw_training_data, vocabulary = data_manager.load_data(
            training_data_path, FLAGS.lowercase, choose_train)
    else:
        # Load the data.
        raw_training_data, vocabulary = data_manager.load_data(
            training_data_path, FLAGS.lowercase)

    if FLAGS.data_type == "snli":
        # Load the eval data.
        raw_eval_sets = []
        raw_eval_data, _ = data_manager.load_data(
            eval_data_path, FLAGS.lowercase, choose_eval)
        raw_eval_sets.append((eval_data_path, raw_eval_data))
    else:
        # Load the eval data.
        raw_eval_sets = []
        raw_eval_data, _ = data_manager.load_data(eval_data_path, FLAGS.lowercase)
        raw_eval_sets.append((eval_data_path, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    return vocabulary, initial_embeddings, training_data_iter, eval_iterators
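
# --- Sketch (added for illustration) -----------------------------------------
# The genre filters above select MultiNLI-style examples by their 'genre'
# field; passing no genre keeps everything. A minimal self-contained
# demonstration with hypothetical rows:

def _make_genre_filter(genre):
    if genre is None:
        return lambda x: True
    return lambda x: x.get('genre') == genre

_examples = [{'genre': 'fiction'}, {'genre': 'telephone'}, {}]
assert [x for x in _examples if _make_genre_filter('fiction')(x)] == [{'genre': 'fiction'}]
assert [x for x in _examples if _make_genre_filter(None)(x)] == _examples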
def run(only_forward=False):
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano)

    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set, vocabulary, FLAGS.seq_length, data_manager,
            eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
        eval_iterators.append(
            (filename, util.MakeEvalIterator(
                (e_X, e_transitions, e_y, e_num_transitions),
                FLAGS.batch_size, FLAGS.eval_data_limit)))

    # Set up the placeholders.
logger.Log("Building model.") if FLAGS.model_type == "CBOW": model_module = spinn.cbow elif FLAGS.model_type == "RNN": model_module = spinn.plain_rnn_chainer elif FLAGS.model_type == "NTI": model_module = spinn.nti elif FLAGS.model_type == "SPINN": model_module = spinn.fat_stack else: raise Exception("Requested unimplemented model type %s" % FLAGS.model_type) if data_manager.SENTENCE_PAIR_DATA: if hasattr(model_module, 'SentencePairTrainer') and hasattr( model_module, 'SentencePairModel'): trainer_cls = model_module.SentencePairTrainer model_cls = model_module.SentencePairModel else: raise Exception("Unimplemented for model type %s" % FLAGS.model_type) num_classes = len(data_manager.LABEL_MAP) classifier_trainer = build_sentence_pair_model( model_cls, trainer_cls, len(vocabulary), FLAGS.model_dim, FLAGS.word_embedding_dim, FLAGS.seq_length, num_classes, initial_embeddings, FLAGS.embedding_keep_rate, FLAGS.gpu) else: if hasattr(model_module, 'SentenceTrainer') and hasattr( model_module, 'SentenceModel'): trainer_cls = model_module.SentenceTrainer model_cls = model_module.SentenceModel else: raise Exception("Unimplemented for model type %s" % FLAGS.model_type) num_classes = len(data_manager.LABEL_MAP) classifier_trainer = build_sentence_pair_model( model_cls, trainer_cls, len(vocabulary), FLAGS.model_dim, FLAGS.word_embedding_dim, FLAGS.seq_length, num_classes, initial_embeddings, FLAGS.embedding_keep_rate, FLAGS.gpu) if ".ckpt" in FLAGS.ckpt_path: checkpoint_path = FLAGS.ckpt_path else: checkpoint_path = os.path.join(FLAGS.ckpt_path, FLAGS.experiment_name + ".ckpt") if os.path.isfile(checkpoint_path): # TODO: Check that resuming works fine with tf summaries. logger.Log("Found checkpoint, restoring.") step, best_dev_error = classifier_trainer.load(checkpoint_path) logger.Log("Resuming at step: {} with best dev accuracy: {}".format( step, 1. - best_dev_error)) else: assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint." step = 0 best_dev_error = 1.0 if FLAGS.write_summaries: from spinn.tf_logger import TFLogger train_summary_logger = TFLogger(summary_dir=os.path.join( FLAGS.summary_dir, FLAGS.experiment_name, 'train')) dev_summary_logger = TFLogger(summary_dir=os.path.join( FLAGS.summary_dir, FLAGS.experiment_name, 'dev')) # Do an evaluation-only run. if only_forward: raise Exception("Not implemented for chainer.") else: # Train logger.Log("Training.") classifier_trainer.init_optimizer( clip=FLAGS.clipping_max_value, decay=FLAGS.l2_lambda, lr=FLAGS.learning_rate, ) # New Training Loop progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar) avg_class_acc = 0 avg_trans_acc = 0 for step in range(step, FLAGS.training_steps): X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next( ) learning_rate = FLAGS.learning_rate * ( FLAGS.learning_rate_decay_per_10k_steps**(step / 10000.0)) # Reset cached gradients. classifier_trainer.optimizer.zero_grads() # Calculate loss and update parameters. ret = classifier_trainer.forward( { "sentences": X_batch, "transitions": transitions_batch, }, y_batch, train=True, predict=False) y, loss, class_acc, transition_acc = ret # Boilerplate for calculating loss. 
            xent_cost_val = loss.data
            transition_cost_val = 0
            avg_trans_acc += transition_acc
            avg_class_acc += class_acc

            if (FLAGS.show_intermediate_stats and step % 5 == 0
                    and step % FLAGS.statistics_interval_steps > 0):
                print("Accuracies so far: ",
                      avg_class_acc / (step % FLAGS.statistics_interval_steps),
                      avg_trans_acc / (step % FLAGS.statistics_interval_steps))

            total_cost_val = xent_cost_val + transition_cost_val
            loss.backward()

            if FLAGS.gradient_check:
                def get_loss():
                    _, check_loss, _, _ = classifier_trainer.forward({
                        "sentences": X_batch,
                        "transitions": transitions_batch,
                    }, y_batch, train=True, predict=False)
                    return check_loss
                gradient_check(classifier_trainer.model, get_loss)

            try:
                classifier_trainer.update()
            except:
                import ipdb
                ipdb.set_trace()

            # Accumulate accuracy for current interval.
            action_acc_val = 0.0
            acc_val = float(classifier_trainer.model.accuracy.data)

            if FLAGS.write_summaries:
                train_summary_logger.log(step=step, loss=total_cost_val,
                                         accuracy=acc_val)

            progress_bar.step(
                i=max(0, step - 1) % FLAGS.statistics_interval_steps + 1,
                total=FLAGS.statistics_interval_steps)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.finish()
                avg_class_acc /= FLAGS.statistics_interval_steps
                avg_trans_acc /= FLAGS.statistics_interval_steps
                logger.Log("Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %s"
                           % (step, avg_class_acc, avg_trans_acc, total_cost_val,
                              xent_cost_val, transition_cost_val, "l2-not-exposed"))
                avg_trans_acc = 0
                avg_class_acc = 0

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(classifier_trainer, eval_set, logger, step)
                    if (FLAGS.ckpt_on_best_dev_error and index == 0
                            and (1 - acc) < 0.99 * best_dev_error
                            and step > 1000):
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        classifier_trainer.save(checkpoint_path, step, best_dev_error)
                    if FLAGS.write_summaries:
                        dev_summary_logger.log(step=step, loss=0.0, accuracy=acc)
                progress_bar.reset()

            if FLAGS.profile and step >= FLAGS.profile_steps:
                break
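
# --- Sketch (added for illustration) -----------------------------------------
# Several training loops in this file decay the learning rate continuously,
# multiplying by learning_rate_decay_per_10k_steps once per 10,000 steps.
# The flag values below are assumptions chosen for the example:

def _decayed_learning_rate(base_lr, decay_per_10k, step):
    return base_lr * (decay_per_10k ** (step / 10000.0))

assert _decayed_learning_rate(0.001, 0.75, 0) == 0.001
assert abs(_decayed_learning_rate(0.001, 0.75, 10000) - 0.00075) < 1e-12
assert abs(_decayed_learning_rate(0.001, 0.75, 20000) - 0.0005625) < 1e-12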
def run(only_forward=False):
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    elif FLAGS.data_type == "arithmetic":
        data_manager = load_simple_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Make Metrics Logger.
    metrics_path = "{}/{}".format(FLAGS.metrics_path, FLAGS.experiment_name)
    if not os.path.exists(metrics_path):
        os.makedirs(metrics_path)
    metrics_logger = MetricsLogger(metrics_path)
    M = Accumulator(maxlen=FLAGS.deque_length)

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only(),
        use_left_padding=FLAGS.use_left_padding)
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only(),
            use_left_padding=FLAGS.use_left_padding)
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Choose model.
    model_specific_params = {}
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        model_module = spinn.cbow
    elif FLAGS.model_type == "RNN":
        model_module = spinn.plain_rnn
    elif FLAGS.model_type == "SPINN":
        model_module = spinn.fat_stack
    elif FLAGS.model_type == "RLSPINN":
        model_module = spinn.rl_spinn
    elif FLAGS.model_type == "RAESPINN":
        model_module = spinn.rae_spinn
    elif FLAGS.model_type == "GENSPINN":
        model_module = spinn.gen_spinn
    elif FLAGS.model_type == "ATTSPINN":
        model_module = spinn.att_spinn
        model_specific_params['using_diff_in_mlstm'] = FLAGS.using_diff_in_mlstm
        model_specific_params['using_prod_in_mlstm'] = FLAGS.using_prod_in_mlstm
        model_specific_params['using_null_in_attention'] = FLAGS.using_null_in_attention
        attlogger = logging.getLogger('spinn.attention')
        attlogger.setLevel(logging.INFO)
        fh = logging.FileHandler(
            os.path.join(FLAGS.log_path, '{}.att.log'.format(FLAGS.experiment_name)))
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))
        attlogger.addHandler(fh)
    else:
        raise Exception("Requested unimplemented model type %s" % FLAGS.model_type)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    if data_manager.SENTENCE_PAIR_DATA:
        trainer_cls = model_module.SentencePairTrainer
        model_cls = model_module.SentencePairModel
        use_sentence_pair = True
    else:
        trainer_cls = model_module.SentenceTrainer
        model_cls = model_module.SentenceModel
        num_classes = len(data_manager.LABEL_MAP)
        use_sentence_pair = False

    model = model_cls(
        model_dim=FLAGS.model_dim,
        word_embedding_dim=FLAGS.word_embedding_dim,
        vocab_size=vocab_size,
        initial_embeddings=initial_embeddings,
        num_classes=num_classes,
        mlp_dim=FLAGS.mlp_dim,
        embedding_keep_rate=FLAGS.embedding_keep_rate,
        classifier_keep_rate=FLAGS.semantic_classifier_keep_rate,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        transition_weight=FLAGS.transition_weight,
        encode_style=FLAGS.encode_style,
        encode_reverse=FLAGS.encode_reverse,
        encode_bidirectional=FLAGS.encode_bidirectional,
        encode_num_layers=FLAGS.encode_num_layers,
        use_sentence_pair=use_sentence_pair,
        use_skips=FLAGS.use_skips,
        lateral_tracking=FLAGS.lateral_tracking,
        use_tracking_in_composition=FLAGS.use_tracking_in_composition,
        use_difference_feature=FLAGS.use_difference_feature,
        use_product_feature=FLAGS.use_product_feature,
        num_mlp_layers=FLAGS.num_mlp_layers,
        mlp_bn=FLAGS.mlp_bn,
        rl_mu=FLAGS.rl_mu,
        rl_baseline=FLAGS.rl_baseline,
        rl_reward=FLAGS.rl_reward,
        rl_weight=FLAGS.rl_weight,
        predict_leaf=FLAGS.predict_leaf,
        gen_h=FLAGS.gen_h,
        model_specific_params=model_specific_params,
    )

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate, eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    classifier_trainer = trainer_cls(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0)
                        for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(model, eval_set, logger, metrics_logger, step, vocabulary)
    else:
        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60,
                                         enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()
            start = time.time()

            X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = \
                training_data_iter.next()
            if FLAGS.truncate_train_batch:
                X_batch, transitions_batch = truncate(
                    X_batch, transitions_batch, num_transitions_batch)
            total_tokens = num_transitions_batch.ravel().sum()

            # Reset cached gradients.
            optimizer.zero_grad()

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                           use_internal_parser=FLAGS.use_internal_parser,
                           validate_transitions=FLAGS.validate_transitions)

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))
            A.add('class_acc', class_acc)
            M.add('class_acc', class_acc)

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss/accuracy.
            transition_acc = model.transition_acc if hasattr(model, 'transition_acc') else 0.0
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None
            rl_loss = model.rl_loss if hasattr(model, 'rl_loss') else None
            policy_loss = model.policy_loss if hasattr(model, 'policy_loss') else None
            rae_loss = model.spinn.rae_loss if hasattr(model.spinn, 'rae_loss') else None
            leaf_loss = model.spinn.leaf_loss if hasattr(model.spinn, 'leaf_loss') else None
            gen_loss = model.spinn.gen_loss if hasattr(model.spinn, 'gen_loss') else None

            # Force Transition Loss Optimization
            if FLAGS.force_transition_loss:
                model.optimize_transition_loss = True

            # Accumulate stats for transition accuracy.
            if transition_loss is not None:
                preds = [m["t_preds"] for m in model.spinn.memories]
                truth = [m["t_given"] for m in model.spinn.memories]
                A.add('preds', preds)
                A.add('truth', truth)

            # Accumulate stats for leaf prediction accuracy.
            if leaf_loss is not None:
                A.add('leaf_acc', model.spinn.leaf_acc)

            # Accumulate stats for word prediction accuracy.
            if gen_loss is not None:
                A.add('gen_acc', model.spinn.gen_acc)

            # Note: Keep track of transition_acc, although this is a naive
            # average. Should be weighted by length of sequences in batch.
            M.add('transition_acc', transition_acc)

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Boilerplate for calculating loss values.
            xent_cost_val = xent_loss.data[0]
            transition_cost_val = transition_loss.data[0] if transition_loss is not None else 0.0
            l2_cost_val = l2_loss.data[0] if l2_loss is not None else 0.0
            rl_cost_val = rl_loss.data[0] if rl_loss is not None else 0.0
            policy_cost_val = policy_loss.data[0] if policy_loss is not None else 0.0
            rae_cost_val = rae_loss.data[0] if rae_loss is not None else 0.0
            leaf_cost_val = leaf_loss.data[0] if leaf_loss is not None else 0.0
            gen_cost_val = gen_loss.data[0] if gen_loss is not None else 0.0

            # Accumulate Total Loss Data
            total_cost_val = 0.0
            total_cost_val += xent_cost_val
            if transition_loss is not None and model.optimize_transition_loss:
                total_cost_val += transition_cost_val
            total_cost_val += l2_cost_val
            total_cost_val += rl_cost_val
            total_cost_val += policy_cost_val
            total_cost_val += rae_cost_val
            total_cost_val += leaf_cost_val
            total_cost_val += gen_cost_val

            M.add('total_cost', total_cost_val)
            M.add('xent_cost', xent_cost_val)
            M.add('transition_cost', transition_cost_val)
            M.add('l2_cost', l2_cost_val)

            # Logging for RL
            rl_keys = ['rl_loss', 'policy_loss', 'norm_rewards',
                       'norm_baseline', 'norm_advantage']
            for k in rl_keys:
                if hasattr(model, k):
                    val = getattr(model, k)
                    val = val.data[0] if isinstance(val, Variable) else val
                    M.add(k, val)

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            if rl_loss is not None:
                total_loss += rl_loss
            if policy_loss is not None:
                total_loss += policy_loss
            if rae_loss is not None:
                total_loss += rae_loss
            if leaf_loss is not None:
                total_loss += leaf_loss
            if gen_loss is not None:
                total_loss += gen_loss

            # Useful for debugging gradient flow.
            if FLAGS.debug:
                losses = [('total_loss', total_loss), ('xent_loss', xent_loss)]
                if l2_loss is not None:
                    losses.append(('l2_loss', l2_loss))
                if transition_loss is not None and model.optimize_transition_loss:
                    losses.append(('transition_loss', transition_loss))
                if rl_loss is not None:
                    losses.append(('rl_loss', rl_loss))
                if policy_loss is not None:
                    losses.append(('policy_loss', policy_loss))
                debug_gradient(model, losses)
                import ipdb
                ipdb.set_trace()

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * (
                    FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()
            total_time = end - start

            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

                avg_class_acc = A.get_avg('class_acc')

                if transition_loss is not None:
                    all_preds = np.array(flatten(A.get('preds')))
                    all_truth = np.array(flatten(A.get('truth')))
                    avg_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
                else:
                    avg_trans_acc = 0.0

                if leaf_loss is not None:
                    avg_leaf_acc = A.get_avg('leaf_acc')
                else:
                    avg_leaf_acc = 0.0

                if gen_loss is not None:
                    avg_gen_acc = A.get_avg('gen_acc')
                else:
                    avg_gen_acc = 0.0

                time_metric = time_per_token(A.get('total_tokens'), A.get('total_time'))

                stats_args = {
                    "step": step,
                    "class_acc": avg_class_acc,
                    "transition_acc": avg_trans_acc,
                    "total_cost": total_cost_val,
                    "xent_cost": xent_cost_val,
                    "transition_cost": transition_cost_val,
                    "l2_cost": l2_cost_val,
                    "rl_cost": rl_cost_val,
                    "policy_cost": policy_cost_val,
                    "rae_cost": rae_cost_val,
                    "leaf_acc": avg_leaf_acc,
                    "leaf_cost": leaf_cost_val,
                    "gen_acc": avg_gen_acc,
                    "gen_cost": gen_cost_val,
                    "time": time_metric,
                }

                stats_str = "Step: {step}"

                # Accuracy Component.
                stats_str += " Acc: {class_acc:.5f} {transition_acc:.5f}"
                if leaf_loss is not None:
                    stats_str += " leaf{leaf_acc:.5f}"
                if gen_loss is not None:
                    stats_str += " gen{gen_acc:.5f}"

                # Cost Component.
                stats_str += " Cost: {total_cost:.5f} {xent_cost:.5f} {transition_cost:.5f} {l2_cost:.5f}"
                if rl_loss is not None:
                    stats_str += " r{rl_cost:.5f}"
                if policy_loss is not None:
                    stats_str += " p{policy_cost:.5f}"
                if rae_loss is not None:
                    stats_str += " rae{rae_cost:.5f}"
                if leaf_loss is not None:
                    stats_str += " leaf{leaf_cost:.5f}"
                if gen_loss is not None:
                    stats_str += " gen{gen_cost:.5f}"

                # Time Component.
                stats_str += " Time: {time:.5f}"

                logger.Log(stats_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, metrics_logger, step)
                    if (FLAGS.ckpt_on_best_dev_error and index == 0
                            and (1 - acc) < 0.99 * best_dev_error
                            and step > FLAGS.ckpt_step):
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        classifier_trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                classifier_trainer.save(standard_checkpoint_path, step, best_dev_error)

            if step % FLAGS.metrics_interval_steps == 0:
                m_keys = M.cache.keys()
                for k in m_keys:
                    metrics_logger.Log(k, M.get_avg(k), step)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
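
# --- Sketch (added for illustration) -----------------------------------------
# The loop above applies *hard* (element-wise) gradient clipping: every
# gradient component is clamped to [-clip, clip]. This differs from norm-based
# clipping, which rescales the whole gradient vector. A self-contained sketch
# of the same operation on plain tensors, assuming a modern PyTorch:

import torch

def _clamp_gradients(parameters, clip):
    for p in parameters:
        if p.requires_grad and p.grad is not None:
            p.grad.data.clamp_(min=-clip, max=clip)

_w = torch.nn.Parameter(torch.zeros(3))
_w.grad = torch.tensor([-2.0, 0.5, 7.0])
_clamp_gradients([_w], clip=1.0)
assert _w.grad.tolist() == [-1.0, 0.5, 1.0]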
def run(only_forward=False):
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)
    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    # Build trainer.
    trainer = ModelTrainer(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(model, eval_set, logger, step, vocabulary)
    else:
        # Build log format strings.
        model.train()
        X_batch, transitions_batch, y_batch, num_transitions_batch = \
            get_batch(training_data_iter.next())[:4]
        model(X_batch, transitions_batch, y_batch,
              use_internal_parser=FLAGS.use_internal_parser,
              validate_transitions=FLAGS.validate_transitions)
        train_str = train_format(model)
        logger.Log("Train-Format: {}".format(train_str))
        train_extra_str = train_extra_format(model)
        logger.Log("Train-Extra-Format: {}".format(train_extra_str))

        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60,
                                         enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()
            start = time.time()

            batch = get_batch(training_data_iter.next())
            X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]

            total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

            # Reset cached gradients.
            optimizer.zero_grad()

            if FLAGS.model_type == "RLSPINN":
                model.spinn.epsilon = FLAGS.rl_epsilon * \
                    math.exp(-step / FLAGS.rl_epsilon_decay)

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                           use_internal_parser=FLAGS.use_internal_parser,
                           validate_transitions=FLAGS.validate_transitions)

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss.
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            total_loss += auxiliary_loss(model)

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * (
                    FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()
            total_time = end - start

            train_accumulate(model, data_manager, A, batch)
            A.add('class_acc', class_acc)
            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

                A.add('xent_cost', xent_loss.data[0])
                # Guard against the case where L2 regularization is disabled
                # (FLAGS.use_l2_cost is False and l2_loss is None).
                if l2_loss is not None:
                    A.add('l2_cost', l2_loss.data[0])
                stats_args = train_stats(model, optimizer, A, step)

                logger.Log(train_str.format(**stats_args))
                logger.Log(train_extra_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, step)
                    if (FLAGS.ckpt_on_best_dev_error and index == 0
                            and (1 - acc) < 0.99 * best_dev_error
                            and step > FLAGS.ckpt_step):
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                trainer.save(standard_checkpoint_path, step, best_dev_error)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
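
# --- Sketch (added for illustration) -----------------------------------------
# Two pieces of arithmetic from the loop above, made explicit.
#
# 1) Token counting: a binary shift-reduce parse of n tokens performs n shifts
#    and n - 1 reduces, i.e. 2n - 1 transitions, hence tokens = (t + 1) / 2:

def _tokens_from_transitions(num_transitions):
    return (num_transitions + 1) // 2

assert _tokens_from_transitions(9) == 5   # 5 shifts + 4 reduces

# 2) Best-dev checkpointing: a new best checkpoint is written only when the
#    dev error improves on the best seen so far by at least 1% relative (the
#    0.99 factor), which suppresses checkpoint churn from noise in dev
#    accuracy. The minimum step is FLAGS.ckpt_step in this version (a fixed
#    1000 in the Theano one); min_step below is a stand-in:

def _should_checkpoint(acc, best_dev_error, step, min_step):
    return (1 - acc) < 0.99 * best_dev_error and step > min_step

assert not _should_checkpoint(0.80, 0.20, 2000, 1000)  # error 0.20 not < 0.198
assert _should_checkpoint(0.81, 0.20, 2000, 1000)      # error 0.19 < 0.198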