def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(
        log_path(FLAGS),
        print_formatter=create_log_formatter(True, False),
        write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators, training_data_length = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    '''
    f = open("./vocab.txt", "w")
    for k in vocabulary:
        f.write("{0}\t{1}\n".format(k, vocabulary[k]))
    f.close()
    '''

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(set(data_manager.LABEL_MAP.values()))

    model = init_model(
        FLAGS, logger, initial_embeddings, vocab_size,
        num_classes, data_manager, header)
    epoch_length = int(training_data_length / FLAGS.batch_size)
    trainer = ModelTrainer(model, logger, epoch_length, vocabulary, FLAGS)

    header.start_step = trainer.step
    header.start_time = int(time.time())

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(
                FLAGS, model, eval_set, log_entry, logger, trainer,
                vocabulary, show_sample=True, eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        train_loop(
            FLAGS, model, trainer, training_data_iter,
            eval_iterators, logger, vocabulary)
def run(only_forward=False):
    logger = afs_safe_logger.Logger(log_path(FLAGS))

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)
    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Do an evaluation-only run.
    if only_forward:
        eval_str = eval_format(model)
        logger.Log("Eval-Format: {}".format(eval_str))
        eval_extra_str = eval_extra_format(model)
        logger.Log("Eval-Extra-Format: {}".format(eval_extra_str))
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(FLAGS, model, data_manager, eval_set, index,
                           logger, step, vocabulary)
    else:
        train_loop(FLAGS, data_manager, model, optimizer, trainer,
                   training_data_iter, eval_iterators, logger, step, best_dev_error)
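# All of the run() variants in this file move the model with model.cuda()/model.cpu()
# and then call recursively_set_device(optimizer.state_dict(), FLAGS.gpu), whose body
# is defined elsewhere. The sketch below is a minimal, assumed version of that helper
# (walk a nested state dict and move every tensor to the requested device); it is an
# illustration, not the project's actual implementation.
import torch


def recursively_set_device_sketch(inp, gpu):
    # Recurse into containers; move tensors at the leaves.
    if isinstance(inp, dict):
        return {k: recursively_set_device_sketch(v, gpu) for k, v in inp.items()}
    if isinstance(inp, (list, tuple)):
        return type(inp)(recursively_set_device_sketch(e, gpu) for e in inp)
    if torch.is_tensor(inp):
        # By the convention used in this file, gpu < 0 means CPU.
        return inp.cpu() if gpu < 0 else inp.cuda(gpu)
    return inp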
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(
        log_path(FLAGS),
        print_formatter=create_log_formatter(True, False),
        write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))
    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(set(data_manager.LABEL_MAP.values()))
    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size,
        num_classes, data_manager, header)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    header.start_step = step
    header.start_time = int(time.time())

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS, model, data_manager, eval_set, log_entry,
                     logger, step, vocabulary, show_sample=True, eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        train_loop(FLAGS, data_manager, model, optimizer, trainer,
                   training_data_iter, eval_iterators, logger, step,
                   best_dev_error, vocabulary)
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(
        log_path(FLAGS),
        print_formatter=create_log_formatter(True, False),
        write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))
    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    if not FLAGS.expanded_eval_only_mode:
        # Get Data and Embeddings for training.
        preprocessed_data_path = os.path.join(
            FLAGS.ckpt_path,
            'allnli_preprocessed_data_prpn-{}_train-{:d}-valid-{:d}_batch-{:d}_dist-{}.dat'.format(
                FLAGS.prpn_name, FLAGS.seq_length, FLAGS.eval_seq_length,
                FLAGS.batch_size, FLAGS.tree_joint))
        if os.path.isfile(preprocessed_data_path):
            print 'Reading dumped preprocessed data'
            vocabulary, initial_embeddings, picked_train_iter_pack, eval_iterators = cPickle.load(
                open(preprocessed_data_path, "rb"))
        else:
            vocabulary, initial_embeddings, picked_train_iter_pack, eval_iterators = \
                load_data_and_embeddings(FLAGS, data_manager, logger,
                                         FLAGS.training_data_path, FLAGS.eval_data_path)
            print 'Dumping data'
            cPickle.dump(
                (vocabulary, initial_embeddings, picked_train_iter_pack, list(eval_iterators)),
                open(preprocessed_data_path, 'wb'))
            print 'Dumping done'

        train_sources, train_batches = picked_train_iter_pack

        def unpack_pickled_train_iter(sources, batches):
            # Replay the precomputed batch index lists forever, drawing them
            # in a fresh random order on each pass through the data.
            num_batches = len(batches)
            idx = -1
            order = range(num_batches)
            random.shuffle(order)
            while True:
                idx += 1
                if idx >= num_batches:
                    # Start another epoch.
                    num_batches = len(batches)
                    idx = 0
                    order = range(num_batches)
                    random.shuffle(order)
                batch_indices = batches[order[idx]]
                # yield tuple(source[batch_indices] for source in sources if source is not None)
                yield tuple(
                    source[batch_indices] if source is not None else None
                    for source in sources)  # for the Gumbel tree model, the dist will be None

        training_data_iter = unpack_pickled_train_iter(train_sources, train_batches)
    else:
        # Get Data and Embeddings for test only.
        vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
            load_data_and_embeddings(FLAGS, data_manager, logger,
                                     FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(set(data_manager.LABEL_MAP.values()))
    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings,
                                           vocab_size, num_classes, data_manager, header)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)
    best_parsing_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name,
                                                       best=True, parsing=True)
    sl_checkpoint_path = get_checkpoint_path_for_sl(FLAGS.ckpt_path, FLAGS.experiment_name,
                                                    step=FLAGS.load_sl_step)

    # Load checkpoint if available.
    if FLAGS.customize_ckpt:
        customize_ckpt_path = FLAGS.customize_ckpt_path
        logger.Log("Found pretrained customized checkpoint, restoring.")
        step, best_dev_error, best_dev_f1_error = trainer.load(
            customize_ckpt_path, cpu=FLAGS.gpu < 0, continue_train=FLAGS.continue_train)
        best_dev_step = 0
    elif FLAGS.load_best:
        if FLAGS.test_type == 'classification' and os.path.isfile(best_checkpoint_path):
            logger.Log("Found best classification checkpoint, restoring.")
            step, best_dev_error, dev_f1_error = trainer.load(
                best_checkpoint_path, cpu=FLAGS.gpu < 0)
            logger.Log("Resuming at step: {} best dev accuracy: {} with dev f1: {}".format(
                step, 1. - best_dev_error, 1. - dev_f1_error))
            step = 0
            best_dev_step = 0
            best_dev_f1_error = dev_f1_error
        elif os.path.isfile(best_parsing_checkpoint_path):
            logger.Log("Found best parsing checkpoint, restoring.")
            step, dev_error, best_dev_f1_error = trainer.load(
                best_parsing_checkpoint_path, cpu=FLAGS.gpu < 0)
            logger.Log("Resuming at step: {} best f1: {} with dev accuracy: {}".format(
                step, 1. - best_dev_f1_error, 1. - dev_error))
        else:
            raise ValueError("Can't find the best checkpoint.")
    elif FLAGS.load_sl:
        logger.Log("Found pretrained SL checkpoint at step {:d}, restoring.".format(
            FLAGS.load_sl_step))
        step, best_dev_error, best_dev_f1_error = trainer.load(
            sl_checkpoint_path, cpu=FLAGS.gpu < 0, continue_train=FLAGS.continue_train)
        best_dev_step = 0
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error, best_dev_f1_error = trainer.load(
            standard_checkpoint_path, cpu=FLAGS.gpu < 0)
        logger.Log("Resuming at step: {} previously best dev accuracy: {} and previously best f1: {}".format(
            step, 1. - best_dev_error, 1. - best_dev_f1_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0
        best_dev_step = 0
        best_dev_f1_error = 1.0  # for best parsing checkpoint

    header.start_step = step
    header.start_time = int(time.time())

    # # Right-branching trick.
    # DefaultUniformInitializer(model.binary_tree_lstm.comp_query.weight)

    # Set temperature.
    model.binary_tree_lstm.temperature_param.data = torch.Tensor([[0.2]])

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS, model, data_manager, eval_set, log_entry,
                     logger, step, vocabulary, show_sample=True, eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        best_dev_step = 0
        train_loop(FLAGS, data_manager, model, optimizer, trainer,
                   training_data_iter, eval_iterators, logger, step, best_dev_error,
                   best_dev_step, best_dev_f1_error, vocabulary)
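# The variant above caches (sources, batch index lists) with cPickle rather than a live
# generator, because generators cannot be pickled; unpack_pickled_train_iter then
# replays the cached index lists in a fresh random order each epoch. A minimal,
# self-contained demo of that shuffle-and-cycle pattern (toy data, illustrative names):
import random

import numpy as np


def cycle_batches(sources, batches):
    # Yield batches forever, reshuffling the batch order once per epoch.
    order = []
    while True:
        if not order:
            order = list(range(len(batches)))
            random.shuffle(order)
        idx = order.pop()
        yield tuple(s[batches[idx]] if s is not None else None for s in sources)


if __name__ == "__main__":
    X = np.arange(12).reshape(6, 2)
    it = cycle_batches((X, None),
                       [np.array([0, 1]), np.array([2, 3]), np.array([4, 5])])
    print(next(it))  # one (X_batch, None) pair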
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(log_path(FLAGS))
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))
    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)
    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings,
                                           vocab_size, num_classes, data_manager, header)

    # Check whether the experiment with perturbation id 0 has a checkpoint.
    perturbation_name = FLAGS.experiment_name + "_p" + '0'
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, perturbation_name, best=True)
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, perturbation_name, best=False)

    # Defaults; overwritten below when checkpoints are restored.
    true_step = 0
    best_dev_error = 1.0
    reload_ev_step = 0

    ckpt_names = []
    if os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=True)
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found standard checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=False)
    else:
        assert not only_forward, "Can't run an eval-only run without best checkpoints. Supply best checkpoint(s)."

    header.start_step = true_step
    header.start_time = int(time.time())
    header.model_label = perturbation_name

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        assert len(ckpt_names) != 0, \
            "Can not run forward pass without best checkpoints supplied."
        log_entry = pb.SpinnEntry()
        restore_queue = mp.Queue()
        processes_restore = []
        while ckpt_names:
            pert_name = ckpt_names.pop()
            path = os.path.join(FLAGS.ckpt_path, pert_name)
            name = pert_name.replace('.ckpt_best', '')
            p_restore = mp.Process(target=restore,
                                   args=(logger, trainer, restore_queue, FLAGS, name, path))
            p_restore.start()
            processes_restore.append(p_restore)
        assert len(ckpt_names) == 0
        results = [restore_queue.get() for p in processes_restore]
        reload_ev_step = results[0][0]
        # Assumption: each restore result is a tuple carrying the restored
        # step at index 1 and the restored model at index 2.
        all_models = list(results)
        while all_models:
            p_checkpoint = all_models.pop()
            p_model = p_checkpoint[2]
            true_step = p_checkpoint[1]
            for index, eval_set in enumerate(eval_iterators):
                log_entry.Clear()
                evaluate(FLAGS, p_model, data_manager, eval_set, log_entry,
                         true_step, vocabulary)
                print(log_entry)
                logger.LogEntry(log_entry)
    else:
        # Restore models, i.e. perturbation spawns, from best checkpoints if they
        # exist, or from standard checkpoints. Get dev-set accuracies so we can
        # select which models to use for the next evolution step.
        if len(ckpt_names) != 0:
            logger.Log("Restoring models from best or standard checkpoints")
            processes_restore = []
            restore_queue = mp.Queue()
            while ckpt_names:
                pert_name = ckpt_names.pop()
                path = os.path.join(FLAGS.ckpt_path, pert_name)
                name = pert_name.replace('.ckpt_best', '')
                p_restore = mp.Process(target=restore,
                                       args=(logger, trainer, restore_queue, FLAGS, name, path))
                p_restore.start()
                processes_restore.append(p_restore)
            assert len(ckpt_names) == 0
            results = [restore_queue.get() for p in processes_restore]
            reload_ev_step = results[0][0] + 1  # the next evolution step
        else:
            id_ = "B"
            chosen_models = [(reload_ev_step, true_step, id_, best_dev_error)]
            base = True  # This is the "base" model.
            results = []

        for ev_step in range(reload_ev_step, FLAGS.es_steps):
            logger.Log("Evolution step: %i" % ev_step)

            # Choose root models for the next generation using dev-set accuracy.
            if len(results) != 0:
                base = False
                chosen_models = []
                acc_order = [
                    i[0] for i in sorted(enumerate(results),
                                         key=lambda x: x[1][3], reverse=True)
                ]
                for i in range(FLAGS.es_num_episodes):
                    id_ = acc_order[i]
                    logger.Log("Picking model %s to perturb for next evolution step." %
                               results[id_][2])
                    chosen_models.append(results[id_])

            # Flush results from the previous generation.
            results = []
            processes = []
            queue = mp.Queue()
            all_seeds, all_models = [], []
            all_steps = []
            all_dev_errs = []
            for chosen_model in chosen_models:
                perturbation_id = chosen_model[2]
                random_seed, models = generate_seeds_and_models(
                    trainer, model, perturbation_id, base=base)
                for i in range(len(models)):
                    all_seeds.append(random_seed)
                    all_steps.append(chosen_model[1])
                    all_dev_errs.append(chosen_model[3])
                all_models += models
            assert len(all_seeds) == len(all_models)
            assert len(all_steps) == len(all_seeds)

            perturbation_id = 0
            while all_models:
                perturbed_model = all_models.pop()
                true_step = all_steps.pop()
                best_dev_error = all_dev_errs.pop()
                p = mp.Process(target=rollout,
                               args=(queue, perturbed_model, FLAGS, data_manager, model,
                                     optimizer, trainer, training_data_iter, eval_iterators,
                                     logger, true_step, best_dev_error, perturbation_id,
                                     ev_step))
                p.start()
                processes.append(p)
                perturbation_id += 1
            assert len(all_models) == 0, "Not all models were trained!"

            # Collect one result per rollout process.
            results = [queue.get() for p in processes]

            # Check that the correct number of models were trained and saved.
            if ev_step == 0:
                assert len(results) == FLAGS.es_num_episodes
            else:
                assert len(results) == FLAGS.es_num_episodes ** 2
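# generate_seeds_and_models(trainer, model, perturbation_id, base=base) is called by
# both evolution-strategy variants but defined elsewhere. The sketch below shows the
# standard ES move it presumably performs -- record a random seed, then spawn copies
# of the parent model with seeded Gaussian noise added to each parameter. sigma, the
# copy count, and CPU-resident parameters are assumptions, not the project's code.
import copy

import torch


def perturb_model_sketch(model, seed, sigma=0.02, num_copies=2):
    # Seeding makes each spawn reproducible from (parent, seed) alone.
    torch.manual_seed(seed)
    spawns = []
    for _ in range(num_copies):
        spawn = copy.deepcopy(model)
        for p in spawn.parameters():
            p.data.add_(torch.randn(p.size()) * sigma)
        spawns.append(spawn)
    return seed, spawns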
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(
        log_path(FLAGS),
        print_formatter=create_log_formatter(True, False),
        write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))
    flags_dict = sorted(list(FLAGS.FlagValuesDict().items()))
    for k, v in flags_dict:
        flag = header.flags.add()
        flag.key = k
        flag.value = str(v)

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 FLAGS.training_data_path, FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)
    model, optimizer, trainer = init_model(FLAGS, logger, initial_embeddings,
                                           vocab_size, num_classes, data_manager, header)

    # Check whether the experiment with perturbation id 0 has a checkpoint.
    perturbation_name = FLAGS.experiment_name + "_p" + '0'
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, perturbation_name, best=True)
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, perturbation_name, best=False)

    ckpt_names = []
    if os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=True)
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found standard checkpoints, they will be restored.")
        ckpt_names = get_pert_names(best=False)
    else:
        assert not only_forward, "Can't run an eval-only run without best checkpoints. Supply best checkpoint(s)."
        true_step = 0
        best_dev_error = 1.0
        best_dev_step = 0
        reload_ev_step = 0

    if FLAGS.mirror:
        true_num_episodes = FLAGS.es_num_episodes * 2
    else:
        true_num_episodes = FLAGS.es_num_episodes

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), FLAGS.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        assert len(ckpt_names) != 0, \
            "Can not run forward pass without best checkpoints supplied."
        log_entry = pb.SpinnEntry()
        restore_queue = mp.Queue()
        processes_restore = []
        while ckpt_names:
            pert_name = ckpt_names.pop()
            path = os.path.join(FLAGS.ckpt_path, pert_name)
            name = pert_name.replace('.ckpt_best', '')
            p_restore = mp.Process(target=restore,
                                   args=(logger, trainer, restore_queue, FLAGS, name, path))
            p_restore.start()
            processes_restore.append(p_restore)
        assert len(ckpt_names) == 0
        results = [restore_queue.get() for p in processes_restore]
        assert len(results) != 0
        acc_order = [i[0] for i in sorted(enumerate(results), key=lambda x: x[1][3])]
        best_id = acc_order[0]
        best_name = FLAGS.experiment_name + "_p" + str(best_id)
        best_path = os.path.join(FLAGS.ckpt_path, best_name + ".ckpt_best")
        ev_step, true_step, dev_error, best_dev_step = trainer.load(
            best_path, cpu=FLAGS.gpu < 0)
        print "Picking best perturbation/model %s to run evaluation, with best dev accuracy of %f" % (
            best_name, 1. - dev_error)
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS, model, eval_set, log_entry, true_step, vocabulary,
                     show_sample=True, eval_index=index)
            print(log_entry)
            logger.LogEntry(log_entry)
    # Train the model.
    else:
        # Restore models, i.e. perturbation spawns, from best checkpoints.
        # Get dev-set accuracies so we can select which models to use for the
        # next evolution step.
        if len(ckpt_names) != 0:
            logger.Log("Restoring models from best checkpoints")
            processes_restore = []
            restore_queue = mp.Queue()
            while ckpt_names:
                pert_name = ckpt_names.pop()
                path = os.path.join(FLAGS.ckpt_path, pert_name)
                name = pert_name.replace('.ckpt_best', '')
                p_restore = mp.Process(target=restore,
                                       args=(logger, trainer, restore_queue, FLAGS, name, path))
                p_restore.start()
                processes_restore.append(p_restore)
            assert len(ckpt_names) == 0
            results = [restore_queue.get() for p in processes_restore]
            reload_ev_step = results[0][0] + 1  # the next evolution step
        else:
            id_ = "B"
            chosen_models = [(reload_ev_step, true_step, id_, best_dev_error, best_dev_step)]
            base = True  # This is the "base" model.
            results = []

        for ev_step in range(reload_ev_step, FLAGS.es_steps):
            logger.Log("Evolution step: %i" % ev_step)

            # Downsample each dev set for evaluation runs during training.
            eval_iterators_ = []
            if FLAGS.eval_sample_size is not None:
                for eval_filename, eval_batches in eval_iterators:
                    full = len(eval_batches)
                    subsample = int(full * FLAGS.eval_sample_size)
                    eval_batches = random.sample(eval_batches, subsample)
                    eval_iterators_.append((eval_filename, eval_batches))
            else:
                eval_iterators_ = eval_iterators

            # Choose root models for the next generation using dev-set accuracy.
            if len(results) != 0:
                base = False
                chosen_models = []
                acc_order = [i[0] for i in sorted(enumerate(results), key=lambda x: x[1][3])]
                for i in range(FLAGS.es_num_roots):
                    id_ = acc_order[i]
                    logger.Log("Picking model %s to perturb for next evolution step." %
                               results[id_][2])
                    chosen_models.append(results[id_])

            # Early stopping based on the current best model.
            best_current = chosen_models[0]
            best_current_step = best_current[1]  # true_step
            best_current_dev_step = best_current[4]  # best_dev_step
            if (best_current_step - best_current_dev_step) > FLAGS.early_stopping_steps_to_wait:
                logger.Log('No improvement after ' +
                           str(FLAGS.early_stopping_steps_to_wait) +
                           ' steps. Stopping training.')
                break

            # Flush results from the previous generation.
            results = []
            processes = []
            queue = mp.Queue()
            all_seeds, all_models, all_roots, all_steps, all_dev_errs, all_best_dev_steps = (
                [] for i in range(6))
            for chosen_model in chosen_models:
                perturbation_id = chosen_model[2]
                random_seed, models, true_step, best_dev_step = generate_seeds_and_models(
                    trainer, model, perturbation_id, base=base)
                for i in range(len(models)):
                    all_seeds.append(random_seed)
                    all_steps.append(true_step)
                    all_dev_errs.append(chosen_model[3])
                    all_roots.append(perturbation_id)
                    all_best_dev_steps.append(best_dev_step)
                all_models += models
            assert len(all_seeds) == len(all_models)
            assert len(all_steps) == len(all_seeds)

            perturbation_id = 0
            j = 0
            while all_models:
                perturbed_model = all_models.pop()
                true_step = all_steps.pop()
                best_dev_error = all_dev_errs.pop()
                root_id = all_roots.pop()
                best_dev_step = all_best_dev_steps.pop()
                p = mp.Process(
                    target=rollout,
                    args=(queue, perturbed_model, FLAGS, model, optimizer, trainer,
                          training_data_iter, eval_iterators_, logger, true_step,
                          best_dev_error, perturbation_id, ev_step, header, root_id,
                          vocabulary, best_dev_step))
                p.start()
                processes.append(p)
                perturbation_id += 1
                j += 1
            assert len(all_models) == 0, "Not all models were trained!"

            # Drain the queue before joining so each worker's queue feeder
            # thread can flush, then wait for the processes to exit.
            results = [queue.get() for p in processes]
            for p in processes:
                p.join()

            # Check that the correct number of models were trained and saved.
            if ev_step == 0:
                assert len(results) == true_num_episodes
            else:
                assert len(results) == true_num_episodes * FLAGS.es_num_roots
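# Both ES variants fan work out with one mp.Process per perturbed model (restore() or
# rollout()) and fan the results back in through a shared mp.Queue. A minimal,
# self-contained version of that fan-out/fan-in pattern with a toy worker:
import multiprocessing as mp


def _square(queue, task_id):
    # Stand-in for restore()/rollout(): do some work, report one result tuple.
    queue.put((task_id, task_id * task_id))


if __name__ == "__main__":
    queue = mp.Queue()
    processes = [mp.Process(target=_square, args=(queue, i)) for i in range(4)]
    for p in processes:
        p.start()
    # Drain the queue before joining: a child cannot exit until everything it
    # put on the queue has been flushed, so join-before-drain can deadlock.
    results = [queue.get() for _ in processes]
    for p in processes:
        p.join()
    print(sorted(results))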
def run(only_forward=False):
    logger = afs_safe_logger.ProtoLogger(
        log_path(FLAGS),
        print_formatter=create_log_formatter(True, False),
        write_proto=FLAGS.write_proto_to_log)
    header = pb.SpinnHeader()

    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Get Data and Embeddings
    vocabulary, initial_embeddings, training_data_iter, eval_iterators, training_data_length, target_vocabulary = \
        load_data_and_embeddings(FLAGS, data_manager, logger,
                                 "", FLAGS.eval_data_path)

    # Build model.
    vocab_size = len(vocabulary)
    if FLAGS.data_type != "mt":
        num_classes = len(set(data_manager.LABEL_MAP.values()))
    else:
        num_classes = None
    model = init_model(FLAGS, logger, initial_embeddings, vocab_size, num_classes,
                       data_manager, header, target_vocabulary=target_vocabulary)
    time_to_wait_to_lower_lr = min(10000, int(training_data_length / FLAGS.batch_size))
    trainer = ModelTrainer(model, logger, time_to_wait_to_lower_lr, vocabulary, FLAGS)

    header.start_step = trainer.step
    header.start_time = int(time.time())

    logger.LogHeader(header)  # Start log_entry logging.

    # Do an evaluation-only run.
    if only_forward:
        log_entry = pb.SpinnEntry()
        for index, eval_set in enumerate(eval_iterators):
            log_entry.Clear()
            evaluate(FLAGS, model, eval_set, log_entry, logger, trainer,
                     vocabulary, show_sample=True, eval_index=index,
                     target_vocabulary=target_vocabulary)
            print(log_entry)
            logger.LogEntry(log_entry)
    else:
        train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
                   logger, vocabulary, target_vocabulary)
def run(only_forward=False):
    logger = afs_safe_logger.Logger(
        os.path.join(FLAGS.log_path, FLAGS.experiment_name) + ".log")

    # Select data format.
    data_manager = get_data_manager(FLAGS.data_type)

    logger.Log("Flag Values:\n" +
               json.dumps(FLAGS.FlagValuesDict(), indent=4, sort_keys=True))

    # Load the data.
    raw_training_data, vocabulary = data_manager.load_data(
        FLAGS.training_data_path, FLAGS.lowercase)

    # Load the eval data.
    raw_eval_sets = []
    if FLAGS.eval_data_path:
        for eval_filename in FLAGS.eval_data_path.split(":"):
            raw_eval_data, _ = data_manager.load_data(eval_filename, FLAGS.lowercase)
            raw_eval_sets.append((eval_filename, raw_eval_data))

    # Prepare the vocabulary.
    if not vocabulary:
        logger.Log("In open vocabulary mode. Using loaded embeddings without fine-tuning.")
        train_embeddings = False
        vocabulary = util.BuildVocabulary(
            raw_training_data, raw_eval_sets, FLAGS.embedding_data_path,
            logger=logger, sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)
    else:
        logger.Log("In fixed vocabulary mode. Training embeddings.")
        train_embeddings = True

    # Load pretrained embeddings.
    if FLAGS.embedding_data_path:
        logger.Log("Loading vocabulary with " + str(len(vocabulary)) +
                   " words from " + FLAGS.embedding_data_path)
        initial_embeddings = util.LoadEmbeddingsFromText(
            vocabulary, FLAGS.word_embedding_dim, FLAGS.embedding_data_path)
    else:
        initial_embeddings = None

    # Trim dataset, convert token sequences to integer sequences, crop, and pad.
    logger.Log("Preprocessing training data.")
    training_data = util.PreprocessDataset(
        raw_training_data, vocabulary, FLAGS.seq_length, data_manager,
        eval_mode=False, logger=logger,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
        for_rnn=sequential_only())
    training_data_iter = util.MakeTrainingIterator(
        training_data, FLAGS.batch_size, FLAGS.smart_batching, FLAGS.use_peano,
        sentence_pair_data=data_manager.SENTENCE_PAIR_DATA)

    # Preprocess eval sets.
    eval_iterators = []
    for filename, raw_eval_set in raw_eval_sets:
        logger.Log("Preprocessing eval data: " + filename)
        eval_data = util.PreprocessDataset(
            raw_eval_set, vocabulary,
            FLAGS.eval_seq_length if FLAGS.eval_seq_length is not None else FLAGS.seq_length,
            data_manager, eval_mode=True, logger=logger,
            sentence_pair_data=data_manager.SENTENCE_PAIR_DATA,
            for_rnn=sequential_only())
        eval_it = util.MakeEvalIterator(
            eval_data, FLAGS.batch_size, FLAGS.eval_data_limit,
            bucket_eval=FLAGS.bucket_eval, shuffle=FLAGS.shuffle_eval,
            rseed=FLAGS.shuffle_eval_seed)
        eval_iterators.append((filename, eval_it))

    # Build model and trainer.
    vocab_size = len(vocabulary)
    num_classes = len(data_manager.LABEL_MAP)
    model, optimizer, trainer = init_model(
        FLAGS, logger, initial_embeddings, vocab_size, num_classes, data_manager)

    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Load checkpoint if available.
    if FLAGS.load_best and os.path.isfile(best_checkpoint_path):
        logger.Log("Found best checkpoint, restoring.")
        step, best_dev_error = trainer.load(best_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    elif os.path.isfile(standard_checkpoint_path):
        logger.Log("Found checkpoint, restoring.")
        step, best_dev_error = trainer.load(standard_checkpoint_path)
        logger.Log("Resuming at step: {} with best dev accuracy: {}".format(
            step, 1. - best_dev_error))
    else:
        assert not only_forward, "Can't run an eval-only run without a checkpoint. Supply a checkpoint."
        step = 0
        best_dev_error = 1.0

    # GPU support.
    the_gpu.gpu = FLAGS.gpu
    if FLAGS.gpu >= 0:
        model.cuda()
    else:
        model.cpu()
    recursively_set_device(optimizer.state_dict(), the_gpu.gpu)

    # Debug
    def set_debug(self):
        self.debug = FLAGS.debug
    model.apply(set_debug)

    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Do an evaluation-only run.
    if only_forward:
        for index, eval_set in enumerate(eval_iterators):
            acc = evaluate(model, eval_set, logger, step, vocabulary)
    else:
        # Build log format strings.
        model.train()
        X_batch, transitions_batch, y_batch, num_transitions_batch = \
            get_batch(training_data_iter.next())[:4]
        model(X_batch, transitions_batch, y_batch,
              use_internal_parser=FLAGS.use_internal_parser,
              validate_transitions=FLAGS.validate_transitions)

        train_str = train_format(model)
        logger.Log("Train-Format: {}".format(train_str))
        train_extra_str = train_extra_format(model)
        logger.Log("Train-Extra-Format: {}".format(train_extra_str))

        # Train
        logger.Log("Training.")

        # New Training Loop
        progress_bar = SimpleProgressBar(msg="Training", bar_length=60,
                                         enabled=FLAGS.show_progress_bar)
        progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

        for step in range(step, FLAGS.training_steps):
            model.train()
            start = time.time()

            batch = get_batch(training_data_iter.next())
            X_batch, transitions_batch, y_batch, num_transitions_batch = batch[:4]
            total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

            # Reset cached gradients.
            optimizer.zero_grad()

            if FLAGS.model_type == "RLSPINN":
                model.spinn.epsilon = FLAGS.rl_epsilon * \
                    math.exp(-step / FLAGS.rl_epsilon_decay)

            # Run model.
            output = model(X_batch, transitions_batch, y_batch,
                           use_internal_parser=FLAGS.use_internal_parser,
                           validate_transitions=FLAGS.validate_transitions)

            # Normalize output.
            logits = F.log_softmax(output)

            # Calculate class accuracy.
            target = torch.from_numpy(y_batch).long()
            pred = logits.data.max(1)[1].cpu()  # get the index of the max log-probability
            class_acc = pred.eq(target).sum() / float(target.size(0))

            # Calculate class loss.
            xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

            # Optionally calculate transition loss.
            transition_loss = model.transition_loss if hasattr(model, 'transition_loss') else None

            # Extract L2 Cost
            l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

            # Accumulate Total Loss Variable
            total_loss = 0.0
            total_loss += xent_loss
            if l2_loss is not None:
                total_loss += l2_loss
            if transition_loss is not None and model.optimize_transition_loss:
                total_loss += transition_loss
            total_loss += auxiliary_loss(model)

            # Backward pass.
            total_loss.backward()

            # Hard Gradient Clipping
            clip = FLAGS.clipping_max_value
            for p in model.parameters():
                if p.requires_grad:
                    p.grad.data.clamp_(min=-clip, max=clip)

            # Learning Rate Decay
            if FLAGS.actively_decay_learning_rate:
                optimizer.lr = FLAGS.learning_rate * \
                    (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))

            # Gradient descent step.
            optimizer.step()

            end = time.time()
            total_time = end - start

            train_accumulate(model, data_manager, A, batch)
            A.add('class_acc', class_acc)
            A.add('total_tokens', total_tokens)
            A.add('total_time', total_time)

            if step % FLAGS.statistics_interval_steps == 0:
                progress_bar.step(i=FLAGS.statistics_interval_steps,
                                  total=FLAGS.statistics_interval_steps)
                progress_bar.finish()
                A.add('xent_cost', xent_loss.data[0])
                if l2_loss is not None:
                    A.add('l2_cost', l2_loss.data[0])
                stats_args = train_stats(model, optimizer, A, step)
                logger.Log(train_str.format(**stats_args))
                logger.Log(train_extra_str.format(**stats_args))

            if step > 0 and step % FLAGS.eval_interval_steps == 0:
                for index, eval_set in enumerate(eval_iterators):
                    acc = evaluate(model, eval_set, logger, step)
                    if FLAGS.ckpt_on_best_dev_error and index == 0 and \
                            (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                        best_dev_error = 1 - acc
                        logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                        trainer.save(best_checkpoint_path, step, best_dev_error)
                progress_bar.reset()

            if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
                logger.Log("Checkpointing.")
                trainer.save(standard_checkpoint_path, step, best_dev_error)

            progress_bar.step(i=step % FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
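# The training loop above calls l2_cost(model, FLAGS.l2_lambda), defined elsewhere in
# the project. A plausible minimal version -- the usual L2 penalty, lambda/2 times the
# sum of squared trainable parameters -- is sketched here as an assumption, not the
# repo's actual definition.
import torch


def l2_cost_sketch(model, l2_lambda):
    # Accumulate sum of squares over all parameters that receive gradients.
    cost = 0.0
    for p in model.parameters():
        if p.requires_grad:
            cost = cost + torch.sum(p * p)
    return 0.5 * l2_lambda * cost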