def test_load_embeddings_from_ascii():
    """Check the matrix LoadEmbeddingsFromASCII builds for a 3-word vocabulary.

    The expected matrix has one row per vocabulary index and five columns.
    """
    embedding_dim = 5
    word_to_index = {"strange_and_exotic_word": 0, "the": 1, ".": 2}

    # Expected rows, listed in vocabulary-index order.
    # Row 0 is all zeros -- presumably because that word is absent from the
    # test embedding file (TODO confirm against the fixture).
    row_unknown = [0, 0, 0, 0, 0]
    row_the = [0.418, 0.24968, -0.41242, 0.1217, 0.34527]
    row_period = [0.15164, 0.30177, -0.16763, 0.17684, 0.31719]
    want = np.asarray([row_unknown, row_the, row_period], dtype=np.float32)

    got = util.LoadEmbeddingsFromASCII(word_to_index, embedding_dim,
                                       TEST_EMBEDDING_MATRIX)

    np.testing.assert_array_equal(got, want)
def run(only_forward=False):
    """Build the thin-stack model graph, then either evaluate or train it.

    All configuration is read from the module-level ``FLAGS`` object. The
    function loads training/eval data through the data manager selected by
    ``FLAGS.data_type``, prepares the vocabulary and (optionally) pretrained
    embeddings, preprocesses the datasets, builds the symbolic model and its
    costs, restores a checkpoint when one exists, and finally runs either an
    evaluation-only pass or the training loop.

    Args:
        only_forward: When true, skip training and write expanded eval output
            for every eval set. A checkpoint file must already exist (this is
            asserted).

    Returns:
        None. Returns early (after logging "Bad data type.") when
        ``FLAGS.data_type`` is not one of "bl", "sst" or "snli".

    Raises:
        NotImplementedError: If ``FLAGS.model_type`` is not "Model0".
    """
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Pick the data-loading module for the requested dataset.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    if FLAGS.model_type != "Model0":
        raise NotImplementedError("Only basic model 0 (SPINN-PI, SPINN-PI-NT) "
                                  "is supported in the thin-stack "
                                  "implementation.")

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data. FLAGS.eval_data_path is a colon-separated list of
    # files; the vocabulary returned for eval files is discarded.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary. A falsy vocabulary from the loader selects
    # "open vocabulary" mode, in which embeddings are not trained.
    if not vocabulary:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    training_data_iter = util.MakeTrainingIterator(training_data,
                                                   FLAGS.batch_size)

    # One (filename, iterator) pair per eval file.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
        eval_iterators.append(
            (filename, util.MakeEvalIterator(
                (e_X, e_transitions, e_y, e_num_transitions),
                FLAGS.batch_size)))

    # Set up the placeholders.
    y = T.vector("y", dtype="int32")
    lr = T.scalar("lr")
    training_mode = T.scalar(
        "training_mode")  # 1: Training with dropout, 0: Eval
    ground_truth_transitions_visible = T.scalar(
        "ground_truth_transitions_visible", dtype="int32")

    logger.Log("Building model.")
    vs = util.VariableStore(default_initializer=util.UniformInitializer(
        FLAGS.init_range), logger=logger)

    # NOTE(review): the guard near the top raises unless model_type is
    # "Model0", so the "RNN" branch below cannot be taken in this function.
    if FLAGS.model_type == "RNN":
        model_cls = spinn.plain_rnn.RNN
    else:
        model_cls = getattr(recurrences, FLAGS.model_type)

    # Generator of mask for scheduled sampling (seeded from a fixed
    # RandomState so the stream seed is reproducible).
    numpy_random = np.random.RandomState(1234)
    ss_mask_gen = T.shared_randomstreams.RandomStreams(
        numpy_random.randint(999999))

    # Training step number
    ss_prob = T.scalar("ss_prob")

    if data_manager.SENTENCE_PAIR_DATA:
        X = T.itensor3("X")
        transitions = T.itensor3("transitions")
        num_transitions = T.imatrix("num_transitions")

        premise_model, hypothesis_model, logits, zero_fn = build_sentence_pair_model(
            model_cls,
            len(vocabulary),
            FLAGS.seq_length,
            X,
            transitions,
            len(data_manager.LABEL_MAP),
            training_mode,
            ground_truth_transitions_visible,
            vs,
            initial_embeddings=initial_embeddings,
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)

        premise_stack_top = premise_model.sentence_embeddings
        hypothesis_stack_top = hypothesis_model.sentence_embeddings
        predicted_premise_transitions = premise_model.transitions_pred
        predicted_hypothesis_transitions = hypothesis_model.transitions_pred
    else:
        X = T.matrix("X", dtype="int32")
        transitions = T.imatrix("transitions")
        num_transitions = T.vector("num_transitions", dtype="int32")
        model, logits, zero_fn = build_sentence_model(
            model_cls,
            len(vocabulary),
            FLAGS.seq_length,
            X,
            transitions,
            len(data_manager.LABEL_MAP),
            training_mode,
            ground_truth_transitions_visible,
            vs,
            initial_embeddings=initial_embeddings,
            project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)
        stack_top = model.sentence_embeddings
        predicted_transitions = model.transitions_pred

    xent_cost, acc = build_cost(logits, y)

    # Set up L2 regularization.
    l2_cost = 0.0
    for var in vs.trainable_vars:
        l2_cost += FLAGS.l2_lambda * T.sum(T.sqr(vs.vars[var]))

    # Compute cross-entropy cost on action predictions.
    # `predicted_transitions` is only bound in the single-sentence branch
    # above and the `predicted_*_transitions` pair only in the sentence-pair
    # branch; the short-circuiting `and` keeps each test from touching a name
    # that was never assigned.
    if (not data_manager.SENTENCE_PAIR_DATA
        ) and predicted_transitions is not None:
        transition_cost, action_acc = build_transition_cost(
            predicted_transitions, transitions, num_transitions)
    elif data_manager.SENTENCE_PAIR_DATA and predicted_hypothesis_transitions is not None:
        # Index 0 on the last axis is used for the premise, index 1 for the
        # hypothesis.
        p_transition_cost, p_action_acc = build_transition_cost(
            predicted_premise_transitions, transitions[:, :, 0],
            num_transitions[:, 0])
        h_transition_cost, h_action_acc = build_transition_cost(
            predicted_hypothesis_transitions, transitions[:, :, 1],
            num_transitions[:, 1])
        transition_cost = p_transition_cost + h_transition_cost
        action_acc = (p_action_acc + h_action_acc
                      ) / 2.0  # TODO(SB): Average over transitions, not words.
    else:
        transition_cost = T.constant(0.0)
        action_acc = T.constant(0.0)
    transition_cost = transition_cost * FLAGS.transition_cost_scale

    total_cost = xent_cost + l2_cost + transition_cost

    # FLAGS.ckpt_path is used as-is when it contains ".ckpt"; otherwise it is
    # treated as a directory and the experiment name is appended.
    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path,
                                       FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        # The two extra values read back are the step counter and the best
        # dev error, matching the `extra_vars` passed to save_checkpoint below.
        step, best_dev_error = vs.load_checkpoint(
            checkpoint_path,
            num_extra_vars=2,
            skip_saved_unsavables=FLAGS.skip_saved_unsavables)
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Do an evaluation-only run.
    if only_forward:
        if FLAGS.eval_output_paths:
            eval_output_paths = FLAGS.eval_output_paths.strip().split(":")
            assert len(eval_output_paths) == len(
                eval_iterators), "Invalid no. of output paths."
        else:
            # Default output name: <experiment>-<eval file basename>-parse.
            eval_output_paths = [
                FLAGS.experiment_name + "-" + os.path.split(eval_set[0])[1] +
                "-parse" for eval_set in eval_iterators
            ]

        # Load model from checkpoint.
        logger.Log("Checkpointed model was trained for %d steps." % (step, ))

        # Generate function for forward pass.
        logger.Log("Building forward pass.")
        if data_manager.SENTENCE_PAIR_DATA:
            eval_fn = theano.function([
                X, transitions, y, num_transitions, training_mode,
                ground_truth_transitions_visible, ss_prob
            ], [
                acc, action_acc, logits, predicted_hypothesis_transitions,
                predicted_premise_transitions
            ],
                                      on_unused_input='warn',
                                      allow_input_downcast=True)
        else:
            eval_fn = theano.function([
                X, transitions, y, num_transitions, training_mode,
                ground_truth_transitions_visible, ss_prob
            ], [acc, action_acc, logits, predicted_transitions],
                                      on_unused_input='warn',
                                      allow_input_downcast=True)

        # Generate the inverse vocabulary lookup table.
        ind_to_word = {v: k for k, v in vocabulary.iteritems()}

        # Do a forward pass and write the output to disk.
        for eval_set, eval_out_path in zip(eval_iterators, eval_output_paths):
            logger.Log("Writing eval output for %s." % (eval_set[0], ))
            evaluate_expanded(eval_fn, eval_set, eval_out_path, logger, step,
                              data_manager.SENTENCE_PAIR_DATA, ind_to_word,
                              zero_fn)
    else:
        # Train
        extra_cost_inputs = [
            y, training_mode, ground_truth_transitions_visible
        ]
        if data_manager.SENTENCE_PAIR_DATA:
            # The two models use slices of the original data.
            # Pass the original data as a non-sequence input as well.
            extra_cost_inputs += [X, transitions]
            premise_error_signal = T.grad(total_cost, premise_stack_top)
            premise_model.make_backprop_scan(
                premise_error_signal,
                extra_cost_inputs=extra_cost_inputs,
                compute_embedding_gradients=False)

            # The hypothesis backprop additionally receives the premise
            # model's stack and auxiliary stacks as cost inputs.
            extra_cost_inputs += [premise_model.stack
                                  ] + premise_model.aux_stacks
            hypothesis_error_signal = T.grad(total_cost, hypothesis_stack_top)
            hypothesis_model.make_backprop_scan(
                hypothesis_error_signal,
                extra_cost_inputs=extra_cost_inputs,
                compute_embedding_gradients=False)

            # Merge the two gradient maps, summing entries present in both.
            # Note this mutates premise_model.gradients in place.
            gradients = premise_model.gradients
            hypothesis_gradients = hypothesis_model.gradients
            for key in hypothesis_gradients:
                if key in gradients:
                    gradients[key] += hypothesis_gradients[key]
                else:
                    gradients[key] = hypothesis_gradients[key]

            new_values = util.merge_updates(
                premise_model.scan_updates + premise_model.bscan_updates,
                hypothesis_model.scan_updates +
                hypothesis_model.bscan_updates).items()
            # Trainable variables that belong to neither model.
            other_params = set(vs.trainable_vars.keys()) - premise_model._vars
            other_params -= hypothesis_model._vars
        else:
            error_signal = T.grad(total_cost, stack_top)
            model.make_backprop_scan(
                error_signal,
                extra_cost_inputs=extra_cost_inputs,
                compute_embedding_gradients=train_embeddings)
            if train_embeddings:
                model.gradients[model.embeddings] = model.embedding_gradients
            gradients = model.gradients
            # Relies on .items() returning lists (Python 2) so that `+` and
            # the later `+=` concatenate.
            new_values = model.scan_updates.items(
            ) + model.bscan_updates.items()
            other_params = set(vs.trainable_vars.keys()) - model._vars

        # Remove null stack parameter gradients.
        null_gradients = set()
        for key, val in gradients.iteritems():
            if val is None:
                null_gradients.add(key)
        if null_gradients:
            logger.Log(
                "The following parameters have null (disconnected) cost "
                "gradients and will not be trained: %s" %
                ", ".join(str(k) for k in null_gradients), logger.WARNING)
            for key in null_gradients:
                del gradients[key]

        # Calculate gradients for items before/after stack fprop.
        other_params = [vs.vars[param] for param in other_params]
        other_grads = T.grad(total_cost, wrt=other_params)
        gradients.update(zip(other_params, other_grads))

        new_values += util.RMSprop(total_cost,
                                   gradients.keys(),
                                   lr,
                                   grads=gradients.values())
        new_values += [(key, vs.nongradient_updates[key])
                       for key in vs.nongradient_updates]

        # Create training and eval functions.
        # num_transitions is always passed in, even when training Model 0,
        # which ignores it. This yields more readable code that is very
        # slightly slower.
        # NOTE(review): this comment originally said unused-variable warnings
        # are suppressed, but on_unused_input is 'warn' below, so they are
        # still emitted.
        logger.Log("Building update function.")
        update_fn = theano.function(
            [
                X, transitions, y, num_transitions, lr, training_mode,
                ground_truth_transitions_visible, ss_prob
            ],
            [total_cost, xent_cost, transition_cost, action_acc, l2_cost, acc],
            updates=new_values,
            on_unused_input='warn',
            allow_input_downcast=True)
        logger.Log("Building eval function.")
        eval_fn = theano.function([
            X, transitions, y, num_transitions, training_mode,
            ground_truth_transitions_visible, ss_prob
        ], [acc, action_acc],
                                  on_unused_input='warn',
                                  allow_input_downcast=True)
        logger.Log("Training.")

        # Main training loop. Starts from the restored step when a checkpoint
        # was loaded.
        for step in range(step, FLAGS.training_steps):
            if step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    # From here on `acc` holds the value returned by
                    # evaluate(), not the symbolic accuracy; both Theano
                    # functions were compiled before this point.
                    acc = evaluate(eval_fn, eval_set, logger, step, zero_fn)
                    # Only the first eval set drives best-model checkpointing,
                    # and only after step 1000 with a >1% relative improvement.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                            1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log(
                            "Checkpointing with new best dev accuracy of %f" %
                            acc)
                        vs.save_checkpoint(checkpoint_path + "_best",
                                           extra_vars=[step, best_dev_error])

            X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next(
            )

            # HACK: Drop training batches which aren't well-sized. (Will only
            # trigger for the final batch in a dataset.)
            # The `continue` also skips the rest of the loop body for this
            # step (update, statistics, periodic checkpoint, zero_fn).
            if X_batch.shape[0] != FLAGS.batch_size:
                continue

            # Exponential learning-rate decay, parameterised per 10k steps.
            learning_rate = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps**(step / 10000.0))
            # The two 1.0 arguments are training_mode and
            # ground_truth_transitions_visible; the last argument is
            # scheduled_sampling_exponent_base ** step.
            ret = update_fn(
                X_batch, transitions_batch, y_batch, num_transitions_batch,
                learning_rate, 1.0, 1.0,
                np.exp(step * np.log(FLAGS.scheduled_sampling_exponent_base)))
            total_cost_val, xent_cost_val, transition_cost_val, action_acc_val, l2_cost_val, acc_val = ret

            if step % FLAGS.statistics_interval_steps == 0:
                logger.Log("Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %5f" %
                           (step, acc_val, action_acc_val, total_cost_val,
                            xent_cost_val, transition_cost_val, l2_cost_val))

            if step % FLAGS.ckpt_interval_steps == 0 and step > 0:
                vs.save_checkpoint(checkpoint_path,
                                   extra_vars=[step, best_dev_error])

            # Zero out all auxiliary variables.
            zero_fn()
def run(only_forward=False):
    """Build the model graph, then either evaluate or train it.

    All configuration is read from the module-level ``FLAGS`` object. The
    function loads training/eval data through the data manager selected by
    ``FLAGS.data_type``, prepares the vocabulary and (optionally) pretrained
    embeddings, preprocesses the datasets, builds the symbolic model and its
    costs for the class selected by ``FLAGS.model_type``, restores a
    checkpoint when one exists, and finally runs either an evaluation-only
    pass or the training loop.

    Args:
        only_forward: When true, skip training and write expanded eval output
            for every eval set. A checkpoint file must already exist (this is
            asserted).

    Returns:
        None. Returns early (after logging "Bad data type.") when
        ``FLAGS.data_type`` is not one of "bl", "sst" or "snli".
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Pick the data-loading module for the requested dataset.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data. FLAGS.eval_data_path is a colon-separated list of
    # files; the vocabulary returned for eval files is discarded.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary. A falsy vocabulary from the loader selects
    # "open vocabulary" mode, in which embeddings are not trained.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size)

    # One (filename, iterator) pair per eval file.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set, vocabulary, FLAGS.seq_length, data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW")
        eval_iterators.append((filename,
            util.MakeEvalIterator((e_X, e_transitions, e_y, e_num_transitions), FLAGS.batch_size)))

    # Set up the placeholders.
    y = T.vector("y", dtype="int32")
    lr = T.scalar("lr")
    training_mode = T.scalar("training_mode")  # 1: Training with dropout, 0: Eval
    ground_truth_transitions_visible = T.scalar("ground_truth_transitions_visible", dtype="int32")

    logger.Log("Building model.")
    vs = util.VariableStore(
        default_initializer=util.UniformInitializer(FLAGS.init_range), logger=logger)

    # "CBOW" and "RNN" map to dedicated classes; any other model_type is
    # looked up by name in spinn.fat_stack.
    if FLAGS.model_type == "CBOW":
        model_cls = spinn.cbow.CBOW
    elif FLAGS.model_type == "RNN":
        model_cls = spinn.plain_rnn.RNN
    else:
        model_cls = getattr(spinn.fat_stack, FLAGS.model_type)

    # Generator of mask for scheduled sampling (seeded from a fixed
    # RandomState so the stream seed is reproducible).
    numpy_random = np.random.RandomState(1234)
    ss_mask_gen = T.shared_randomstreams.RandomStreams(numpy_random.randint(999999))

    # Training step number
    ss_prob = T.scalar("ss_prob")

    if data_manager.SENTENCE_PAIR_DATA:
        X = T.itensor3("X")
        transitions = T.itensor3("transitions")
        num_transitions = T.imatrix("num_transitions")

        predicted_premise_transitions, predicted_hypothesis_transitions, logits = build_sentence_pair_model(
            model_cls, len(vocabulary), FLAGS.seq_length,
            X, transitions, len(data_manager.LABEL_MAP), training_mode, ground_truth_transitions_visible, vs,
            initial_embeddings=initial_embeddings, project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)
    else:
        X = T.matrix("X", dtype="int32")
        transitions = T.imatrix("transitions")
        num_transitions = T.vector("num_transitions", dtype="int32")

        predicted_transitions, logits = build_sentence_model(
            model_cls, len(vocabulary), FLAGS.seq_length,
            X, transitions, len(data_manager.LABEL_MAP), training_mode, ground_truth_transitions_visible, vs,
            initial_embeddings=initial_embeddings, project_embeddings=(not train_embeddings),
            ss_mask_gen=ss_mask_gen,
            ss_prob=ss_prob)

    xent_cost, acc = build_cost(logits, y)

    # Set up L2 regularization.
    l2_cost = 0.0
    for var in vs.trainable_vars:
        l2_cost += FLAGS.l2_lambda * T.sum(T.sqr(vs.vars[var]))

    # Compute cross-entropy cost on action predictions.
    # "Model0", "RNN" and "CBOW" get a constant-zero transition cost instead.
    if (not data_manager.SENTENCE_PAIR_DATA) and FLAGS.model_type not in ["Model0", "RNN", "CBOW"]:
        transition_cost, action_acc = build_transition_cost(predicted_transitions, transitions, num_transitions)
    elif data_manager.SENTENCE_PAIR_DATA and FLAGS.model_type not in ["Model0", "RNN", "CBOW"]:
        # Index 0 on the last axis is used for the premise, index 1 for the
        # hypothesis.
        p_transition_cost, p_action_acc = build_transition_cost(predicted_premise_transitions, transitions[:, :, 0],
            num_transitions[:, 0])
        h_transition_cost, h_action_acc = build_transition_cost(predicted_hypothesis_transitions, transitions[:, :, 1],
            num_transitions[:, 1])
        transition_cost = p_transition_cost + h_transition_cost
        action_acc = (p_action_acc + h_action_acc) / 2.0  # TODO(SB): Average over transitions, not words.
    else:
        transition_cost = T.constant(0.0)
        action_acc = T.constant(0.0)
    transition_cost = transition_cost * FLAGS.transition_cost_scale

    total_cost = xent_cost + l2_cost + transition_cost

    # FLAGS.ckpt_path is used as-is when it contains ".ckpt"; otherwise it is
    # treated as a directory and the experiment name is appended.
    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path, FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        # The two extra values read back are the step counter and the best
        # dev error, matching the `extra_vars` passed to save_checkpoint below.
        step, best_dev_error = vs.load_checkpoint(checkpoint_path, num_extra_vars=2,
                                                  skip_saved_unsavables=FLAGS.skip_saved_unsavables)
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Do an evaluation-only run.
    if only_forward:
        if FLAGS.eval_output_paths:
            eval_output_paths = FLAGS.eval_output_paths.strip().split(":")
            assert len(eval_output_paths) == len(eval_iterators), "Invalid no. of output paths."
        else:
            # Default output name: <experiment>-<eval file basename>-parse.
            eval_output_paths = [FLAGS.experiment_name + "-" + os.path.split(eval_set[0])[1] + "-parse"
                                  for eval_set in eval_iterators]

        # Load model from checkpoint.
        logger.Log("Checkpointed model was trained for %d steps." % (step,))

        # Generate function for forward pass.
        logger.Log("Building forward pass.")
        if data_manager.SENTENCE_PAIR_DATA:
            eval_fn = theano.function(
                [X, transitions, y, num_transitions, training_mode, ground_truth_transitions_visible, ss_prob],
                [acc, action_acc, logits, predicted_hypothesis_transitions, predicted_premise_transitions],
                on_unused_input='ignore',
                allow_input_downcast=True)
        else:
            eval_fn = theano.function(
                [X, transitions, y, num_transitions, training_mode, ground_truth_transitions_visible, ss_prob],
                [acc, action_acc, logits, predicted_transitions],
                on_unused_input='ignore',
                allow_input_downcast=True)

        # Generate the inverse vocabulary lookup table.
        ind_to_word = {v : k for k, v in vocabulary.iteritems()}

        # Do a forward pass and write the output to disk.
        # The final argument is true for every model type other than
        # "Model0", "RNN" and "CBOW".
        for eval_set, eval_out_path in zip(eval_iterators, eval_output_paths):
            logger.Log("Writing eval output for %s." % (eval_set[0],))
            evaluate_expanded(eval_fn, eval_set, eval_out_path, logger, step,
                              data_manager.SENTENCE_PAIR_DATA, ind_to_word, FLAGS.model_type not in ["Model0", "RNN", "CBOW"])
    else:
        # Train

        # Parameter updates from RMSprop over all trainable variables, plus
        # the variable store's non-gradient updates.
        new_values = util.RMSprop(total_cost, vs.trainable_vars.values(), lr)
        new_values += [(key, vs.nongradient_updates[key]) for key in vs.nongradient_updates]
        # Training open-vocabulary embeddings is a questionable idea right now. Disabled:
        # new_values.append(
        #     util.embedding_SGD(total_cost, embedding_params, embedding_lr))

        # Create training and eval functions.
        # Unused variable warnings are suppressed so that num_transitions can be passed in when training Model 0,
        # which ignores it. This yields more readable code that is very slightly slower.
        logger.Log("Building update function.")
        update_fn = theano.function(
            [X, transitions, y, num_transitions, lr, training_mode, ground_truth_transitions_visible, ss_prob],
            [total_cost, xent_cost, transition_cost, action_acc, l2_cost, acc],
            updates=new_values,
            on_unused_input='ignore',
            allow_input_downcast=True)
        logger.Log("Building eval function.")
        eval_fn = theano.function([X, transitions, y, num_transitions, training_mode, ground_truth_transitions_visible, ss_prob], [acc, action_acc],
            on_unused_input='ignore',
            allow_input_downcast=True)
        logger.Log("Training.")

        # Main training loop. Starts from the restored step when a checkpoint
        # was loaded.
        for step in range(step, FLAGS.training_steps):
            if step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    # From here on `acc` holds the value returned by
                    # evaluate(), not the symbolic accuracy; both Theano
                    # functions were compiled before this point.
                    acc = evaluate(eval_fn, eval_set, logger, step)
                    # Only the first eval set drives best-model checkpointing,
                    # and only after step 1000 with a >1% relative improvement.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        vs.save_checkpoint(checkpoint_path + "_best", extra_vars=[step, best_dev_error])

            X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next()
            # Exponential learning-rate decay, parameterised per 10k steps.
            learning_rate = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            # The two 1.0 arguments are training_mode and
            # ground_truth_transitions_visible; the last argument is
            # scheduled_sampling_exponent_base ** step.
            ret = update_fn(X_batch, transitions_batch, y_batch, num_transitions_batch,
                            learning_rate, 1.0, 1.0, np.exp(step*np.log(FLAGS.scheduled_sampling_exponent_base)))
            total_cost_val, xent_cost_val, transition_cost_val, action_acc_val, l2_cost_val, acc_val = ret

            if step % FLAGS.statistics_interval_steps == 0:
                logger.Log(
                    "Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %5f"
                    % (step, acc_val, action_acc_val, total_cost_val, xent_cost_val, transition_cost_val,
                       l2_cost_val))

            if step % FLAGS.ckpt_interval_steps == 0 and step > 0:
                vs.save_checkpoint(checkpoint_path, extra_vars=[step, best_dev_error])
def run(only_forward=False):
    """Build a classifier trainer for the selected model and train it.

    All configuration is read from the module-level ``FLAGS`` object. The
    function loads training/eval data through the data manager selected by
    ``FLAGS.data_type``, prepares the vocabulary and (optionally) pretrained
    embeddings, preprocesses the datasets, builds a trainer for the model
    module selected by ``FLAGS.model_type``, restores a checkpoint when one
    exists, and runs the training loop with periodic statistics, evaluation
    and best-dev checkpointing.

    Args:
        only_forward: Evaluation-only mode. Not implemented in this version:
            after setup, an exception is raised when this is true.

    Returns:
        None. Returns early (after logging "Bad data type.") when
        ``FLAGS.data_type`` is not one of "bl", "sst" or "snli".

    Raises:
        Exception: If ``FLAGS.model_type`` is not one of "CBOW", "RNN", "NTI"
            or "SPINN"; if the selected module lacks the trainer/model
            classes required for the data kind; or if ``only_forward`` is
            true.
    """
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Pick the data-loading module for the requested dataset.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path)

    # Load the eval data. FLAGS.eval_data_path is a colon-separated list of
    # files; the vocabulary returned for eval files is discarded.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            eval_data, _ = data_manager.load_data(eval_filename)
            raw_eval_sets.append((eval_filename, eval_data))

    # Prepare the vocabulary. A falsy vocabulary from the loader selects
    # "open vocabulary" mode.
    if not vocabulary:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromASCII(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad. The same `for_rnn` setting is used for training and eval data.
    for_rnn = FLAGS.model_type == "RNN" or FLAGS.model_type == "CBOW"

    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=for_rnn)
    training_data_iter = util.MakeTrainingIterator(training_data,
                                                   FLAGS.batch_size,
                                                   FLAGS.smart_batching,
                                                   FLAGS.use_peano)

    # One (filename, iterator) pair per eval file.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        e_X, e_transitions, e_y, e_num_transitions = util.PreprocessDataset(
            raw_eval_set,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=for_rnn)
        eval_iterators.append(
            (filename, util.MakeEvalIterator(
                (e_X, e_transitions, e_y, e_num_transitions),
                FLAGS.batch_size, FLAGS.eval_data_limit)))

    # Select the module that provides the model and trainer classes.
    logger.Log("Building model.")

    if FLAGS.model_type == "CBOW":
        model_module = spinn.cbow
    elif FLAGS.model_type == "RNN":
        model_module = spinn.plain_rnn_chainer
    elif FLAGS.model_type == "NTI":
        model_module = spinn.nti
    elif FLAGS.model_type == "SPINN":
        model_module = spinn.fat_stack
    else:
        raise Exception("Requested unimplemented model type %s" %
                        FLAGS.model_type)

    if data_manager.SENTENCE_PAIR_DATA:
        if hasattr(model_module, 'SentencePairTrainer') and hasattr(
                model_module, 'SentencePairModel'):
            trainer_cls = model_module.SentencePairTrainer
            model_cls = model_module.SentencePairModel
        else:
            raise Exception("Unimplemented for model type %s" %
                            FLAGS.model_type)

        num_classes = len(data_manager.LABEL_MAP)
        classifier_trainer = build_sentence_pair_model(
            model_cls, trainer_cls, len(vocabulary), FLAGS.model_dim,
            FLAGS.word_embedding_dim, FLAGS.seq_length, num_classes,
            initial_embeddings, FLAGS.embedding_keep_rate, FLAGS.gpu)
    else:
        if hasattr(model_module, 'SentenceTrainer') and hasattr(
                model_module, 'SentenceModel'):
            trainer_cls = model_module.SentenceTrainer
            model_cls = model_module.SentenceModel
        else:
            raise Exception("Unimplemented for model type %s" %
                            FLAGS.model_type)

        num_classes = len(data_manager.LABEL_MAP)
        # NOTE(review): the single-sentence branch also calls
        # build_sentence_pair_model (with the Sentence* classes). This may be
        # intentional if the builder is class-agnostic -- confirm, or switch
        # to a dedicated single-sentence builder if one exists.
        classifier_trainer = build_sentence_pair_model(
            model_cls, trainer_cls, len(vocabulary), FLAGS.model_dim,
            FLAGS.word_embedding_dim, FLAGS.seq_length, num_classes,
            initial_embeddings, FLAGS.embedding_keep_rate, FLAGS.gpu)

    # FLAGS.ckpt_path is used as-is when it contains ".ckpt"; otherwise it is
    # treated as a directory and the experiment name is appended.
    if ".ckpt" in FLAGS.ckpt_path:
        checkpoint_path = FLAGS.ckpt_path
    else:
        checkpoint_path = os.path.join(FLAGS.ckpt_path,
                                       FLAGS.experiment_name + ".ckpt")
    if os.path.isfile(checkpoint_path):
        # TODO: Check that resuming works fine with tf summaries.
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    if FLAGS.write_summaries:
        # Imported lazily so the summary logger is only required when
        # summaries are requested.
        from spinn.tf_logger import TFLogger
        train_summary_logger = TFLogger(summary_dir=os.path.join(
            FLAGS.summary_dir, FLAGS.experiment_name, 'train'))
        dev_summary_logger = TFLogger(summary_dir=os.path.join(
            FLAGS.summary_dir, FLAGS.experiment_name, 'dev'))

    # Do an evaluation-only run.
    if only_forward:
        raise Exception("Not implemented for chainer.")
    else:
        # Train
        logger.Log("Training.")

        classifier_trainer.init_optimizer(
            clip=FLAGS.clipping_max_value,
            decay=FLAGS.l2_lambda,
            lr=FLAGS.learning_rate, )

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training",
                                         bar_length=60,
                                         enabled=FLAGS.show_progress_bar)
        # Running sums over the current statistics interval; reset after each
        # statistics log line.
        avg_class_acc = 0
        avg_trans_acc = 0
        for step in range(step, FLAGS.training_steps):
            X_batch, transitions_batch, y_batch, num_transitions_batch = training_data_iter.next(
            )
            # NOTE(review): the decayed rate is computed but never handed to
            # the optimizer in this function; the optimizer was initialised
            # once with FLAGS.learning_rate above. Confirm whether decay is
            # meant to be applied here.
            learning_rate = FLAGS.learning_rate * (
                FLAGS.learning_rate_decay_per_10k_steps**(step / 10000.0))

            # Reset cached gradients.
            classifier_trainer.optimizer.zero_grads()

            # Calculate loss and update parameters.
            ret = classifier_trainer.forward(
                {
                    "sentences": X_batch,
                    "transitions": transitions_batch,
                },
                y_batch,
                train=True,
                predict=False)
            # The first returned value is not used here.
            _, loss, class_acc, transition_acc = ret

            # Boilerplate for calculating loss.
            xent_cost_val = loss.data
            transition_cost_val = 0
            avg_trans_acc += transition_acc
            avg_class_acc += class_acc

            # The `> 0` test also guards the divisions below against a zero
            # divisor.
            if FLAGS.show_intermediate_stats and step % 5 == 0 and step % FLAGS.statistics_interval_steps > 0:
                print("Accuracies so far : ",
                      avg_class_acc / (step % FLAGS.statistics_interval_steps),
                      avg_trans_acc / (step % FLAGS.statistics_interval_steps))

            total_cost_val = xent_cost_val + transition_cost_val
            loss.backward()

            if FLAGS.gradient_check:

                def get_loss():
                    # Recompute the loss on the current batch for the checker.
                    _, check_loss, _, _ = classifier_trainer.forward(
                        {
                            "sentences": X_batch,
                            "transitions": transitions_batch,
                        },
                        y_batch,
                        train=True,
                        predict=False)
                    return check_loss

                gradient_check(classifier_trainer.model, get_loss)

            try:
                classifier_trainer.update()
            except Exception:
                # Deliberate debugging hook: on a failed update, drop into the
                # debugger and then continue with the next step. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit stop the run as expected.
                logger.Log(
                    "Parameter update failed at step %d; entering debugger." %
                    step)
                import ipdb
                ipdb.set_trace()

            # Accuracy of the current step, as reported by the model.
            acc_val = float(classifier_trainer.model.accuracy.data)

            if FLAGS.write_summaries:
                train_summary_logger.log(step=step,
                                         loss=total_cost_val,
                                         accuracy=acc_val)

            progress_bar.step(
                i=max(0, step - 1) % FLAGS.statistics_interval_steps + 1,
                total=FLAGS.statistics_interval_steps)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.finish()
                avg_class_acc /= FLAGS.statistics_interval_steps
                avg_trans_acc /= FLAGS.statistics_interval_steps
                logger.Log(
                    "Step: %i\tAcc: %f\t%f\tCost: %5f %5f %5f %s" %
                    (step, avg_class_acc, avg_trans_acc, total_cost_val,
                     xent_cost_val, transition_cost_val, "l2-not-exposed"))
                avg_trans_acc = 0
                avg_class_acc = 0

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(classifier_trainer, eval_set, logger, step)
                    # Only the first eval set drives checkpointing, and only
                    # after step 1000 with a >1% relative improvement.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (
                            1 - acc) < 0.99 * best_dev_error and step > 1000:
                        best_dev_error = 1 - acc
                        logger.Log(
                            "Checkpointing with new best dev accuracy of %f" %
                            acc)
                        classifier_trainer.save(checkpoint_path, step,
                                                best_dev_error)
                    if FLAGS.write_summaries:
                        dev_summary_logger.log(step=step,
                                               loss=0.0,
                                               accuracy=acc)
                progress_bar.reset()

            # Stop early when profiling.
            if FLAGS.profile and step >= FLAGS.profile_steps:
                break