def test_valid_transitions_eval(self):
    """Eval-mode SNLI preprocessing must yield valid transition sequences.

    Both the hypothesis (column 0) and premise (column 1) transition
    sequences of every example are checked with ``t_is_valid``.

    TODO: Check on shorter length.
    """
    manager = load_snli_data
    examples, _ = manager.load_data(snli_data_path)
    vocab = util.BuildVocabulary(
        examples,
        [(snli_data_path, examples)],
        embedding_data_path,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA)
    # Neither the embeddings nor the '.' id are inspected below; the calls
    # are kept so this test still exercises them.
    _embeddings = util.LoadEmbeddingsFromText(
        vocab, word_embedding_dim, embedding_data_path)
    _eos = vocab["."]

    tokens, transitions, _labels, num_transitions, _ = util.PreprocessDataset(
        examples,
        vocab,
        150,  # max sequence length
        manager,
        eval_mode=True,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA,
        for_rnn=False,
        use_left_padding=True)

    for _pair, pair_transitions, (_hyp_count, _prem_count) in zip(
            tokens, transitions, num_transitions):
        for side in (0, 1):  # 0 = hypothesis, 1 = premise
            assert t_is_valid(pair_transitions[:, side])
def test_preprocess(self):
    """Train-mode SNLI preprocessing pads sentences right and transitions left.

    Also checks that ``num_transitions`` equals the number of non-SKIP
    transitions for both the hypothesis (column 0) and premise (column 1).
    """
    max_len = 25
    manager = load_snli_data
    examples, _ = manager.load_data(snli_data_path)
    vocab = util.BuildVocabulary(
        examples,
        [(snli_data_path, examples)],
        embedding_data_path,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA)
    _embeddings = util.LoadEmbeddingsFromText(
        vocab, word_embedding_dim, embedding_data_path)
    eos = vocab["."]

    processed = util.PreprocessDataset(
        examples,
        vocab,
        max_len,
        manager,
        eval_mode=False,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA,
        for_rnn=False)
    tokens, transitions, _labels, num_transitions = processed[:4]

    # Examples whose lengths exceed max_len are filtered out, leaving 2 pairs.
    assert tokens.shape == (2, max_len, 2)
    assert transitions.shape == (2, max_len, 2)

    for pair_tokens, pair_transitions, (hyp_count, prem_count) in zip(
            tokens, transitions, num_transitions):
        sentences = (pair_tokens[:, 0], pair_tokens[:, 1])
        sequences = (pair_transitions[:, 0], pair_transitions[:, 1])

        # Sentences start with a word and end with EOS...
        for sentence in sentences:
            assert s_is_left_to_right(sentence, eos)
        # ...and are padded on the right.
        for sentence in sentences:
            assert not s_is_left_padded(sentence)

        # num_transitions counts the non-SKIP transitions.
        for sequence, expected in zip(sequences, (hyp_count, prem_count)):
            assert sum(1 for op in sequence if op != T_SKIP) == expected

        # Transitions start with SKIP and end with REDUCE (ignoring SKIPs)...
        for sequence in sequences:
            assert t_is_left_to_right(sequence)
        # ...and are padded on the left.
        for sequence in sequences:
            assert t_is_left_padded(sequence)
def test_load_embed(self):
    """Embeddings loaded for the SNLI fixture vocabulary have shape (10, 5)."""
    manager = load_snli_data
    examples, _ = manager.load_data(snli_data_path)
    vocab = util.BuildVocabulary(
        examples,
        [(snli_data_path, examples)],
        embedding_data_path,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA)
    embeddings = util.LoadEmbeddingsFromText(
        vocab, word_embedding_dim, embedding_data_path)
    assert embeddings.shape == (10, 5)
def test_preprocess(self):
    """Train-mode SST preprocessing pads sentences right and transitions left.

    Also checks that ``num_transitions`` equals the number of non-SKIP
    transitions for every example.
    """
    max_len = 30
    manager = load_sst_data
    examples = manager.load_data(sst_data_path, eval_mode=True)
    vocab = util.BuildVocabulary(
        examples,
        [(sst_data_path, examples)],
        embedding_data_path,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA)
    _embeddings = util.LoadEmbeddingsFromText(
        vocab, word_embedding_dim, embedding_data_path)
    eos = vocab["."]

    processed = util.PreprocessDataset(
        examples,
        vocab,
        max_len,
        manager,
        eval_mode=False,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA,
        simple=False)
    tokens, transitions, _labels, num_transitions = processed[:4]

    # Examples whose lengths exceed max_len are filtered out, leaving 10.
    assert tokens.shape == (10, max_len)
    assert transitions.shape == (10, max_len)

    for sentence, sequence, expected in zip(tokens, transitions, num_transitions):
        # Sentence starts with a word, ends with EOS, and is right-padded.
        assert s_is_left_to_right(sentence, eos)
        assert not s_is_left_padded(sentence)
        # num_transitions counts the non-SKIP transitions.
        assert sum(1 for op in sequence if op != T_SKIP) == expected
        # Transitions start with SKIP, end with REDUCE (ignoring SKIPs), and
        # are left-padded.
        assert t_is_left_to_right(sequence)
        assert t_is_left_padded(sequence)
def test_valid_transitions_train(self):
    """Train-mode NLI preprocessing must yield valid transition sequences.

    Both the hypothesis (column 0) and premise (column 1) transition
    sequences of every example are checked with ``t_is_valid``.

    TODO: Check on shorter length.
    """
    manager = load_nli_data
    examples = manager.load_data(nli_data_path)
    vocab = util.BuildVocabulary(
        examples,
        [(nli_data_path, examples)],
        embedding_data_path,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA)
    # The embeddings are not inspected; the call is kept so it is exercised.
    _embeddings = util.LoadEmbeddingsFromText(
        vocab, word_embedding_dim, embedding_data_path)

    processed = util.PreprocessDataset(
        examples,
        vocab,
        150,  # max sequence length
        manager,
        eval_mode=False,
        logger=MockLogger(),
        sentence_pair_data=manager.SENTENCE_PAIR_DATA,
        simple=False)
    tokens, transitions, _labels, num_transitions = processed[:4]

    for _pair, pair_transitions, (_hyp_count, _prem_count) in zip(
            tokens, transitions, num_transitions):
        for side in (0, 1):  # 0 = hypothesis, 1 = premise
            assert t_is_valid(pair_transitions[:, side])
def load_data_and_embeddings(
        FLAGS, data_manager, logger, training_data_path, eval_data_path):
    """Load raw data, build the vocabulary/embeddings, and create iterators.

    Args:
        FLAGS: flag object; genre filters, paths, sequence lengths, batching
            and cropping options are read from it.
        data_manager: object providing ``load_data``, ``FIXED_VOCABULARY``
            and ``SENTENCE_PAIR_DATA``.
        logger: object with a ``Log(str)`` method.
        training_data_path: training file; not loaded when
            ``FLAGS.expanded_eval_only_mode`` is set.
        eval_data_path: one or more eval file paths separated by ':'.

    Returns:
        ``(vocabulary, initial_embeddings, training_data_iter,
        eval_iterators, training_data_length)``. ``initial_embeddings`` is
        None without ``FLAGS.embedding_data_path``; in eval-only mode
        ``training_data_iter`` is None and ``training_data_length`` is 0.
        ``eval_iterators`` is a list of ``(path, iterator)`` pairs.
    """
    # Example filters: accept everything unless a genre flag is set.
    def choose_train(x): return True
    if FLAGS.train_genre is not None:
        def choose_train(x): return x.get('genre') == FLAGS.train_genre

    def choose_eval(x): return True
    if FLAGS.eval_genre is not None:
        def choose_eval(x): return x.get('genre') == FLAGS.eval_genre

    if not FLAGS.expanded_eval_only_mode:
        # Filter the training data by genre the same way the eval sets are
        # filtered below; without this, --train_genre is silently ignored.
        raw_training_data = data_manager.load_data(
            training_data_path, FLAGS.lowercase, choose_train, eval_mode=False)
    else:
        raw_training_data = None

    raw_eval_sets = []
    for path in eval_data_path.split(':'):
        raw_eval_data = data_manager.load_data(
            path, FLAGS.lowercase, choose_eval, eval_mode=True)
        raw_eval_sets.append((path, raw_eval_data))

    # Prepare the vocabulary.
    if not data_manager.FIXED_VOCABULARY:
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        vocabulary = data_manager.FIXED_VOCABULARY
        logger.Log("In fixed vocabulary mode. Training embeddings from scratch.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    if raw_training_data is not None:
        training_data = util.PreprocessDataset(
            raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
            eval_mode=False, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_cropping,
            pad_from_left=pad_from_left())
        training_data_iter = util.MakeTrainingIterator(
            training_data, FLAGS.batch_size, FLAGS.smart_batching,
            FLAGS.use_peano,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
        training_data_length = len(training_data[0])
    else:
        training_data_iter = None
        training_data_length = 0

    # Eval sets fall back to the training sequence length when no separate
    # eval length is configured.
    eval_seq_length = (FLAGS.eval_seq_length
                       if FLAGS.eval_seq_length is not None
                       else FLAGS.seq_length)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary, eval_seq_length, data_manager,
            eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_eval_cropping,
            pad_from_left=pad_from_left())
        eval_it = util.MakeEvalIterator(
            eval_data,
            FLAGS.batch_size,
            FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    return (vocabulary, initial_embeddings, training_data_iter,
            eval_iterators, training_data_length)
def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path, eval_data_path):
    """Load one train and one eval file, build vocabulary/embeddings, make iterators.

    The loader returns ``(examples, vocabulary)``. An empty vocabulary from
    the training load switches to open-vocabulary mode, where the vocabulary
    is built from the data and ``FLAGS.embedding_data_path``.

    Returns ``(vocabulary, initial_embeddings, training_data_iter,
    eval_iterators)``. ``initial_embeddings`` is None when no
    ``FLAGS.embedding_data_path`` is given, and ``eval_iterators`` is a
    one-element list of ``(eval_data_path, iterator)``.
    """
    if FLAGS.train_genre is None:
        def choose_train(x):
            return True
    else:
        def choose_train(x):
            return x.get('genre') == FLAGS.train_genre

    if FLAGS.eval_genre is None:
        def choose_eval(x):
            return True
    else:
        def choose_eval(x):
            return x.get('genre') == FLAGS.eval_genre

    # The genre filter is only passed to the loader for NLI data.
    is_nli = FLAGS.data_type == "nli"
    train_filter = (choose_train,) if is_nli else ()
    eval_filter = (choose_eval,) if is_nli else ()

    raw_training_data, vocabulary = data_manager.load_data(
        training_data_path, FLAGS.lowercase, *train_filter)
    eval_examples, _ = data_manager.load_data(
        eval_data_path, FLAGS.lowercase, *eval_filter)
    raw_eval_sets = [(eval_data_path, eval_examples)]

    if vocabulary:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
    else:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    initial_embeddings = None
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)

    # Map tokens to ids, then crop and pad each split.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data,
        vocabulary,
        FLAGS.seq_length,
        data_manager,
        eval_mode=False,
        logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data,
        FLAGS.batch_size,
        FLAGS.smart_batching,
        FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    eval_iterators = []
    for name, examples in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + name)
        if FLAGS.eval_seq_length is None:
            max_len = FLAGS.seq_length
        else:
            max_len = FLAGS.eval_seq_length
        processed = util.PreprocessDataset(
            examples,
            vocabulary,
            max_len,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        iterator = util.MakeEvalIterator(
            processed,
            FLAGS.batch_size,
            FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((name, iterator))

    return vocabulary, initial_embeddings, training_data_iter, eval_iterators
def load_data_and_embeddings(FLAGS, data_manager, logger, training_data_path, eval_data_path):
    """Load train/eval data, build vocabulary and embeddings, and make iterators.

    ``eval_data_path`` may list several files separated by ':'. Training data
    is skipped entirely when ``FLAGS.expanded_eval_only_mode`` is set.

    Returns ``(vocabulary, initial_embeddings, training_data_iter,
    eval_iterators)``. ``training_data_iter`` is None in eval-only mode,
    ``initial_embeddings`` is None when no ``FLAGS.embedding_data_path`` is
    given, and ``eval_iterators`` is a list of ``(path, iterator)`` pairs.
    """
    # Genre filtering is disabled here: both predicates accept every example.
    def choose_train(x):
        return True

    def choose_eval(x):
        return True

    def load_split(path, keep):
        """Load one file; tree_joint/distance_type are passed only for NLI data."""
        if FLAGS.data_type == "nli":
            return data_manager.load_data(
                path,
                FLAGS.lowercase,
                keep,
                mode=FLAGS.transition_mode,
                tree_joint=FLAGS.tree_joint,
                distance_type=FLAGS.distance_type)
        return data_manager.load_data(
            path, FLAGS.lowercase, mode=FLAGS.transition_mode)

    raw_training_data = None
    if not FLAGS.expanded_eval_only_mode:
        raw_training_data = load_split(training_data_path, choose_train)

    raw_eval_sets = [(path, load_split(path, choose_eval))
                     for path in eval_data_path.split(':')]

    if data_manager.FIXED_VOCABULARY:
        vocabulary = data_manager.FIXED_VOCABULARY
        logger.Log("In fixed vocabulary mode. Training embeddings.")
    else:
        logger.Log(
            "In open vocabulary mode. Using loaded embeddings without fine-tuning."
        )
        vocabulary = util.BuildVocabulary(
            raw_training_data,
            raw_eval_sets,
            FLAGS.embedding_data_path,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    initial_embeddings = None
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)

    # Map tokens to ids, then crop and pad each split.
    logger.Log("Preprocessing training data.")
    training_data_iter = None
    if raw_training_data is not None:
        training_data = util.PreprocessDataset(
            raw_training_data,
            vocabulary,
            FLAGS.seq_length,
            data_manager,
            eval_mode=False,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_cropping,
            pad_from_left=pad_from_left(),
            tree_joint=FLAGS.tree_joint)
        training_data_iter = util.MakeTrainingIterator(
            training_data,
            FLAGS.batch_size,
            FLAGS.smart_batching,
            FLAGS.use_peano,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    eval_iterators = []
    for name, examples in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + name)
        if FLAGS.eval_seq_length is None:
            max_len = FLAGS.seq_length
        else:
            max_len = FLAGS.eval_seq_length
        processed = util.PreprocessDataset(
            examples,
            vocabulary,
            max_len,
            data_manager,
            eval_mode=True,
            logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            simple=sequential_only(),
            allow_cropping=FLAGS.allow_eval_cropping,
            pad_from_left=pad_from_left(),
            tree_joint=FLAGS.tree_joint)
        # Eval batches are scaled by sample_num to keep eval running speed.
        iterator = util.MakeEvalIterator(
            processed,
            FLAGS.batch_size * FLAGS.sample_num,
            FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((name, iterator))

    return vocabulary, initial_embeddings, training_data_iter, eval_iterators
def run(only_forward=False):
    """Train the classifier configured by the global FLAGS, or only evaluate it.

    Args:
        only_forward: when True, skip training and call ``evaluate`` on every
            eval set. An existing checkpoint is required in this mode.

    Returns None. Returns early, after logging "Bad data type.", when
    ``FLAGS.data_type`` is not one of "bl", "sst", "snli", "arithmetic".
    Side effects: writes ``<log_path>/<experiment_name>.log``, creates the
    metrics directory if missing, and saves checkpoints to the paths returned
    by ``get_checkpoint_path`` while training.
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    if FLAGS.data_type == "bl":
        data_manager = load_boolean_data
    elif FLAGS.data_type == "sst":
        data_manager = load_sst_data
    elif FLAGS.data_type == "snli":
        data_manager = load_snli_data
    elif FLAGS.data_type == "arithmetic":
        data_manager = load_simple_data
    else:
        logger.Log("Bad data type.")
        return

    pp = pprint.PrettyPrinter(indent=4)
    logger.Log("Flag values:\n" + pp.pformat(FLAGS.FlagValuesDict()))

    # Make Metrics Logger.
    metrics_path = "{}/{}".format(FLAGS.metrics_path, FLAGS.experiment_name)
    if not os.path.exists(metrics_path):
        os.makedirs(metrics_path)
    metrics_logger = MetricsLogger(metrics_path)
    M = Accumulator(maxlen=FLAGS.deque_length)

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary. An empty vocabulary from the loader means the
    # vocabulary is built from the data (open vocabulary mode).
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only(),
        use_left_padding=FLAGS.use_left_padding)
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_seq_length = (FLAGS.eval_seq_length
                       if FLAGS.eval_seq_length is not None
                       else FLAGS.seq_length)
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary, eval_seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only(),
            use_left_padding=FLAGS.use_left_padding)
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval, rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Choose model.
    model_specific_params = {}
    logger.Log("Building model.")
    if FLAGS.model_type == "CBOW":
        model_module = spinn.cbow
    elif FLAGS.model_type == "RNN":
        model_module = spinn.plain_rnn
    elif FLAGS.model_type == "SPINN":
        model_module = spinn.fat_stack
    elif FLAGS.model_type == "RLSPINN":
        model_module = spinn.rl_spinn
    elif FLAGS.model_type == "RAESPINN":
        model_module = spinn.rae_spinn
    elif FLAGS.model_type == "GENSPINN":
        model_module = spinn.gen_spinn
    elif FLAGS.model_type == "ATTSPINN":
        model_module = spinn.att_spinn
        model_specific_params['using_diff_in_mlstm'] = FLAGS.using_diff_in_mlstm
        model_specific_params['using_prod_in_mlstm'] = FLAGS.using_prod_in_mlstm
        model_specific_params['using_null_in_attention'] = FLAGS.using_null_in_attention
        # Attention diagnostics go to a separate <experiment_name>.att.log file.
        attlogger = logging.getLogger('spinn.attention')
        attlogger.setLevel(logging.INFO)
        fh = logging.FileHandler(os.path.join(FLAGS.log_path, '{}.att.log'.format(FLAGS.experiment_name)))
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))
        attlogger.addHandler(fh)
    else:
        raise Exception("Requested unimplemented model type %s" % FLAGS.model_type)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    if data_manager.SENTENCE_PAIR_DATA:
        trainer_cls = model_module.SentencePairTrainer
        model_cls = model_module.SentencePairModel
        use_sentence_pair = True
    else:
        trainer_cls = model_module.SentenceTrainer
        model_cls = model_module.SentenceModel
        use_sentence_pair = False

    model = model_cls(model_dim=FLAGS.model_dim,
                      word_embedding_dim=FLAGS.word_embedding_dim,
                      vocab_size=vocab_size,
                      initial_embeddings=initial_embeddings,
                      num_classes=num_classes,
                      mlp_dim=FLAGS.mlp_dim,
                      embedding_keep_rate=FLAGS.embedding_keep_rate,
                      classifier_keep_rate=FLAGS.semantic_classifier_keep_rate,
                      tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
                      transition_weight=FLAGS.transition_weight,
                      encode_style=FLAGS.encode_style,
                      encode_reverse=FLAGS.encode_reverse,
                      encode_bidirectional=FLAGS.encode_bidirectional,
                      encode_num_layers=FLAGS.encode_num_layers,
                      use_sentence_pair=use_sentence_pair,
                      use_skips=FLAGS.use_skips,
                      lateral_tracking=FLAGS.lateral_tracking,
                      use_tracking_in_composition=FLAGS.use_tracking_in_composition,
                      use_difference_feature=FLAGS.use_difference_feature,
                      use_product_feature=FLAGS.use_product_feature,
                      num_mlp_layers=FLAGS.num_mlp_layers,
                      mlp_bn=FLAGS.mlp_bn,
                      rl_mu=FLAGS.rl_mu,
                      rl_baseline=FLAGS.rl_baseline,
                      rl_reward=FLAGS.rl_reward,
                      rl_weight=FLAGS.rl_weight,
                      predict_leaf=FLAGS.predict_leaf,
                      gen_h=FLAGS.gen_h,
                      model_specific_params=model_specific_params,
                      )

    # Build optimizer.
    if FLAGS.optimizer_type == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=FLAGS.learning_rate,
                               betas=(0.9, 0.999), eps=1e-08)
    elif FLAGS.optimizer_type == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=FLAGS.learning_rate, eps=1e-08)
    else:
        raise NotImplementedError

    # Build trainer.
    classifier_trainer = trainer_cls(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = classifier_trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # Print model size.
    logger.Log("Architecture: {}".format(model))
    total_params = sum([reduce(lambda x, y: x * y, w.size(), 1.0) for w in model.parameters()])
    logger.Log("Total params: {}".format(total_params))

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()

    # Debug
    def set_debug(module):
        module.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for eval_set in eval_iterators:
            evaluate(model, eval_set, logger, metrics_logger, step, vocabulary)
        return

    # Train
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    for step in range(step, FLAGS.training_steps):
        model.train()

        start = time.time()

        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = training_data_iter.next()

        if FLAGS.truncate_train_batch:
            X_batch, transitions_batch = truncate(
                X_batch, transitions_batch, num_transitions_batch)

        total_tokens = num_transitions_batch.ravel().sum()

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        A.add('class_acc', class_acc)
        M.add('class_acc', class_acc)

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss/accuracy.
        transition_acc = model.transition_acc if hasattr(model, 'transition_acc') else 0.0
        transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None
        rl_loss = model.rl_loss if hasattr(model, 'rl_loss') else None
        policy_loss = model.policy_loss if hasattr(model, 'policy_loss') else None
        rae_loss = model.spinn.rae_loss if hasattr(model.spinn, 'rae_loss') else None
        leaf_loss = model.spinn.leaf_loss if hasattr(model.spinn, 'leaf_loss') else None
        gen_loss = model.spinn.gen_loss if hasattr(model.spinn, 'gen_loss') else None

        # Force Transition Loss Optimization
        if FLAGS.force_transition_loss:
            model.optimize_transition_loss = True

        # Accumulate stats for transition accuracy.
        if transition_loss is not None:
            preds = [m["t_preds"] for m in model.spinn.memories]
            truth = [m["t_given"] for m in model.spinn.memories]
            A.add('preds', preds)
            A.add('truth', truth)

        # Accumulate stats for leaf prediction accuracy.
        if leaf_loss is not None:
            A.add('leaf_acc', model.spinn.leaf_acc)

        # Accumulate stats for word prediction accuracy.
        if gen_loss is not None:
            A.add('gen_acc', model.spinn.gen_acc)

        # Note: Keep track of transition_acc, although this is a naive average.
        # Should be weighted by length of sequences in batch.
        M.add('transition_acc', transition_acc)

        # Extract L2 Cost
        l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

        # Boilerplate for calculating loss values.
        xent_cost_val = xent_loss.data[0]
        transition_cost_val = transition_loss.data[0] if transition_loss is not None else 0.0
        l2_cost_val = l2_loss.data[0] if l2_loss is not None else 0.0
        rl_cost_val = rl_loss.data[0] if rl_loss is not None else 0.0
        policy_cost_val = policy_loss.data[0] if policy_loss is not None else 0.0
        rae_cost_val = rae_loss.data[0] if rae_loss is not None else 0.0
        leaf_cost_val = leaf_loss.data[0] if leaf_loss is not None else 0.0
        gen_cost_val = gen_loss.data[0] if gen_loss is not None else 0.0

        # Accumulate Total Loss Data (absent components contribute 0.0).
        total_cost_val = 0.0
        total_cost_val += xent_cost_val
        if transition_loss is not None and model.optimize_transition_loss:
            total_cost_val += transition_cost_val
        total_cost_val += l2_cost_val
        total_cost_val += rl_cost_val
        total_cost_val += policy_cost_val
        total_cost_val += rae_cost_val
        total_cost_val += leaf_cost_val
        total_cost_val += gen_cost_val

        M.add('total_cost', total_cost_val)
        M.add('xent_cost', xent_cost_val)
        M.add('transition_cost', transition_cost_val)
        M.add('l2_cost', l2_cost_val)

        # Logging for RL
        rl_keys = ['rl_loss', 'policy_loss', 'norm_rewards', 'norm_baseline', 'norm_advantage']
        for k in rl_keys:
            if hasattr(model, k):
                val = getattr(model, k)
                val = val.data[0] if isinstance(val, Variable) else val
                M.add(k, val)

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        if rl_loss is not None:
            total_loss += rl_loss
        if policy_loss is not None:
            total_loss += policy_loss
        if rae_loss is not None:
            total_loss += rae_loss
        if leaf_loss is not None:
            total_loss += leaf_loss
        if gen_loss is not None:
            total_loss += gen_loss

        # Useful for debugging gradient flow.
        if FLAGS.debug:
            losses = [('total_loss', total_loss), ('xent_loss', xent_loss)]
            if l2_loss is not None:
                losses.append(('l2_loss', l2_loss))
            if transition_loss is not None and model.optimize_transition_loss:
                losses.append(('transition_loss', transition_loss))
            if rl_loss is not None:
                losses.append(('rl_loss', rl_loss))
            if policy_loss is not None:
                losses.append(('policy_loss', policy_loss))

            debug_gradient(model, losses)
            import ipdb; ipdb.set_trace()

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping. Parameters that received no gradient this
        # step have grad None and are skipped instead of crashing.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning Rate Decay. torch.optim optimizers read the learning rate
        # from each param group; assigning ``optimizer.lr`` has no effect.
        if FLAGS.actively_decay_learning_rate:
            decayed_lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            for param_group in optimizer.param_groups:
                param_group['lr'] = decayed_lr

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
            progress_bar.finish()
            avg_class_acc = A.get_avg('class_acc')
            if transition_loss is not None:
                all_preds = np.array(flatten(A.get('preds')))
                all_truth = np.array(flatten(A.get('truth')))
                avg_trans_acc = (all_preds == all_truth).sum() / float(all_truth.shape[0])
            else:
                avg_trans_acc = 0.0
            if leaf_loss is not None:
                avg_leaf_acc = A.get_avg('leaf_acc')
            else:
                avg_leaf_acc = 0.0
            if gen_loss is not None:
                avg_gen_acc = A.get_avg('gen_acc')
            else:
                avg_gen_acc = 0.0
            time_metric = time_per_token(A.get('total_tokens'), A.get('total_time'))
            stats_args = {
                "step": step,
                "class_acc": avg_class_acc,
                "transition_acc": avg_trans_acc,
                "total_cost": total_cost_val,
                "xent_cost": xent_cost_val,
                "transition_cost": transition_cost_val,
                "l2_cost": l2_cost_val,
                "rl_cost": rl_cost_val,
                "policy_cost": policy_cost_val,
                "rae_cost": rae_cost_val,
                "leaf_acc": avg_leaf_acc,
                "leaf_cost": leaf_cost_val,
                "gen_acc": avg_gen_acc,
                "gen_cost": gen_cost_val,
                "time": time_metric,
            }
            stats_str = "Step: {step}"

            # Accuracy Component.
            stats_str += " Acc: {class_acc:.5f} {transition_acc:.5f}"
            if leaf_loss is not None:
                stats_str += " leaf{leaf_acc:.5f}"
            if gen_loss is not None:
                stats_str += " gen{gen_acc:.5f}"

            # Cost Component.
            stats_str += " Cost: {total_cost:.5f} {xent_cost:.5f} {transition_cost:.5f} {l2_cost:.5f}"
            if rl_loss is not None:
                stats_str += " r{rl_cost:.5f}"
            if policy_loss is not None:
                stats_str += " p{policy_cost:.5f}"
            if rae_loss is not None:
                stats_str += " rae{rae_cost:.5f}"
            if leaf_loss is not None:
                stats_str += " leaf{leaf_cost:.5f}"
            if gen_loss is not None:
                stats_str += " gen{gen_cost:.5f}"

            # Time Component.
            stats_str += " Time: {time:.5f}"

            logger.Log(stats_str.format(**stats_args))

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            for index, eval_set in enumerate(eval_iterators):
                acc = evaluate(model, eval_set, logger, metrics_logger, step)
                # Only the first eval set decides "best" checkpoints; a new
                # best must beat the previous error by at least 1%.
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    classifier_trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            logger.Log("Checkpointing.")
            classifier_trainer.save(standard_checkpoint_path, step, best_dev_error)

        if step % FLAGS.metrics_interval_steps == 0:
            m_keys = M.cache.keys()
            for k in m_keys:
                metrics_logger.Log(k, M.get_avg(k), step)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
def run(only_forward=False):
    """Load data, build the model, restore a checkpoint if present, then train or evaluate.

    All configuration is read from the global ``FLAGS``.

    Args:
        only_forward: If True, skip training and call ``evaluate`` once on every
            eval set with the restored model. An eval-only run requires an
            existing checkpoint; otherwise the assertion below fails.
    """
    logger = afs_safe_logger.Logger(os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" + json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data. Several eval files may be given, separated by ":".
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary. If the data manager returned no vocabulary, build
    # one from the training/eval data and the embedding file.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True
    # NOTE(review): train_embeddings is never read below (it is not passed to
    # init_model), so the log messages above may not describe what actually
    # happens to the embeddings -- confirm.

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary))
                   + " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and
    # pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager, eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets. eval_seq_length overrides seq_length when set.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit, bucket_eval=FLAGS.bucket_eval,
            shuffle=FLAGS.shuffle_eval, rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)

    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    # Build trainer.
    # NOTE(review): this replaces the trainer returned by init_model above;
    # confirm which of the two is intended.
    trainer = ModelTrainer(model, optimizer)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available. The best-dev checkpoint is preferred only
    # when FLAGS.load_best is set; otherwise the standard checkpoint is used.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug: propagate FLAGS.debug to every submodule.
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            # NOTE(review): the training loop below calls evaluate() without
            # `vocabulary`; confirm evaluate accepts both call forms.
            acc = evaluate(model, eval_set, logger, step, vocabulary)
    else:
        # Build log format strings. One forward pass is run first; presumably
        # train_format / train_extra_format depend on model state it sets.
        model.train()
        X_batch, transitions_batch, y_batch, num_transitions_batch = get_batch(next(training_data_iter))[:4]
        model(X_batch, transitions_batch, y_batch,
              use_internal_parser=FLAGS.use_internal_parser,
              validate_transitions=FLAGS.validate_transitions
              )

        train_str = train_format(model)
        logger.Log("Train-Format: {}".format(train_str))
        train_extra_str = train_extra_format(model)
        logger.Log("Train-Extra-Format: {}".format(train_extra_str))

        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()

            start = time.time()

            # next() works on both Python 2 and 3 iterators; .next() is Py2-only.
            batch = get_batch(next(training_data_iter))
            X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]

            # Presumably (nt + 1) / 2 recovers the token count from a
            # shift/reduce sequence of 2n - 1 transitions -- confirm.
            total_tokens = sum([(nt+1)/2 for nt in num_transitions_batch.reshape(-1)])

            # Reset cached gradients.
            optimizer.zero_grad()

            if FLAGS.model_type == "RLSPINN":
                model.spinn.epsilon = FLAGS.rl_epsilon * math.exp(-step/FLAGS.rl_epsilon_decay)

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                           use_internal_parser=FLAGS.use_internal_parser,
                           validate_transitions=FLAGS.validate_transitions
                           )

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss.
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

            # Extract L2 Cost (None when FLAGS.use_l2_cost is off).
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            total_loss += auxiliary_loss(model)

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping. A parameter that received no gradient in
            # this backward pass has p.grad == None; skip it instead of crashing.
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                decayed_lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
                # Keep the attribute for any code that reads optimizer.lr ...
                optimizer.lr = decayed_lr
                # ... but torch.optim optimizers take their rate from
                # param_groups, so the decay must be written there to apply.
                for param_group in getattr(optimizer, 'param_groups', []):
                    param_group['lr'] = decayed_lr

            # Gradient descent step.
            optimizer.step()

            end = time.time()

            total_time = end - start

            train_accumulate(model, data_manager, A, batch)
            A.add('class_acc', class_acc)
            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
                progress_bar.finish()

                A.add('xent_cost', xent_loss.data[0])
                # l2_loss is None when L2 is disabled; record 0.0 rather than
                # raising AttributeError on None.data.
                A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
                stats_args = train_stats(model, optimizer, A, step)

                logger.Log(train_str.format(**stats_args))
                logger.Log(train_extra_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, step)
                    # Only the first eval set drives best-checkpointing, and a
                    # new best must cut the dev error by more than 1% relative.
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                trainer.save(standard_checkpoint_path, step, best_dev_error)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)