def run_train(args, hparams):
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    print("Loading training trees from {}...".format(args.train_path))
    if hparams.predict_tags and args.train_path.endswith('10way.clean'):
        print("WARNING: The data distributed with this repository contains "
              "predicted part-of-speech tags only (not gold tags!) We do not "
              "recommend enabling predict_tags in this configuration.")
    train_treebank = trees.load_trees(args.train_path)
    if hparams.max_len_train > 0:
        train_treebank = [tree for tree in train_treebank
                          if len(list(tree.leaves())) <= hparams.max_len_train]
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = trees.load_trees(args.dev_path)
    if hparams.max_len_dev > 0:
        dev_treebank = [tree for tree in dev_treebank
                        if len(list(tree.leaves())) <= hparams.max_len_dev]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]

    print("Constructing vocabularies...")

    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())

    char_set = set()

    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)

    print("Initializing model...")

    load_path = args.load_path
    if load_path is not None:
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])
    else:
        parser = parse_nk.NKChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )

    print("Initializing optimizer...")
    trainable_parameters = [param for param in parser.parameters() if param.requires_grad]
    trainer = torch.optim.Adam(trainable_parameters, lr=1., betas=(0.9, 0.98), eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer, 'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm

    print("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0

    start_time = time.time()

    def check_dev():
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index+args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted)

        print(
            "dev-fscore {} "
            "dev-elapsed {} "
            "total-elapsed {}".format(
                dev_fscore,
                format_elapsed(dev_start_time),
                format_elapsed(start_time),
            )
        )

        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed

            print("Saving new best model to {}...".format(best_dev_model_path))
            torch.save({
                'spec': parser.spec,
                'state_dict': parser.state_dict(),
                'trainer': trainer.state_dict(),
            }, best_dev_model_path + ".pt")

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in tqdm(range(0, len(train_parse), args.batch_size)):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index + args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                               for tree in batch_trees]
            batch_num_tokens = sum(len(sentence) for sentence in batch_sentences)

            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)

                if hparams.predict_tags:
                    loss = loss[0] / len(batch_trees) + loss[1] / batch_num_tokens
                else:
                    loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters, grad_clip_threshold)

            trainer.step()

            # if start_index // args.batch_size + 1 == int(np.ceil(len(train_parse) / args.batch_size)):
            #     print(
            #         "epoch {:,} "
            #         "batch {:,}/{:,} "
            #         "processed {:,} "
            #         "batch-loss {:.4f} "
            #         "grad-norm {:.4f} "
            #         "epoch-elapsed {} "
            #         "total-elapsed {}".format(
            #             epoch,
            #             start_index // args.batch_size + 1,
            #             int(np.ceil(len(train_parse) / args.batch_size)),
            #             total_processed,
            #             batch_loss_value,
            #             grad_norm,
            #             format_elapsed(epoch_start_time),
            #             format_elapsed(start_time),
            #         )
            #     )

            if current_processed >= check_every:
                current_processed -= check_every
                # print('\nEpoch {}, weights {}'.format(epoch, parser.weighted_layer.weight.data))
                check_dev()

        # adjust learning rate at the end of an epoch
        if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
            scheduler.step(best_dev_fscore)

        if (total_processed - best_dev_processed) > (
                (hparams.step_decay_patience + 1) * hparams.max_consecutive_decays * len(train_parse)):
            print("Terminating due to lack of improvement in dev fscore.")
            print("The layer weights are: ", parser.weighted_layer.weight)
            break
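
# --- Illustrative sketch (not part of the training code above) --------------
# A minimal, self-contained example of the learning-rate schedule that
# run_train implements: the LR is written directly into the optimizer's
# param_groups during a linear warmup, after which ReduceLROnPlateau
# (mode='max') decays it once per epoch when the tracked dev metric stops
# improving. The toy model, warmup length, base LR, patience, and the 0.90
# dev F-score below are illustrative assumptions, not values from this repo.
def _lr_schedule_sketch(num_epochs=3, batches_per_epoch=100):
    import torch

    toy_model = torch.nn.Linear(4, 4)
    toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1.,
                                     betas=(0.9, 0.98), eps=1e-9)
    base_lr = 0.0008          # stands in for hparams.learning_rate
    warmup_steps = 160        # stands in for hparams.learning_rate_warmup_steps
    warmup_coeff = base_lr / warmup_steps
    toy_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        toy_optimizer, 'max', factor=0.5, patience=2)

    def toy_schedule_lr(iteration):
        # Linear warmup: LR grows from warmup_coeff up to base_lr, then is
        # left untouched so that ReduceLROnPlateau controls it from there on.
        iteration = iteration + 1
        if iteration <= warmup_steps:
            for param_group in toy_optimizer.param_groups:
                param_group['lr'] = iteration * warmup_coeff

    step = 0
    for _ in range(num_epochs):
        for _batch in range(batches_per_epoch):
            toy_schedule_lr(step)
            # ... forward / backward / toy_optimizer.step() would go here ...
            step += 1
        # As in run_train, the plateau scheduler is stepped once per epoch,
        # only after warmup has finished, using the best dev metric so far.
        if step > warmup_steps:
            toy_scheduler.step(0.90)   # 0.90 is a made-up dev F-score
# -----------------------------------------------------------------------------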
def run_train(args, hparams):
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed), flush=True)
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy, flush=True)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:", flush=True)
    hparams.print()

    print("Loading training trees from {}...".format(args.train_path), flush=True)
    train_treebank = trees.load_trees(args.train_path)
    if hparams.max_len_train > 0:
        train_treebank = [
            tree for tree in train_treebank
            if len(list(tree.leaves())) <= hparams.max_len_train
        ]
    print("Loaded {:,} training examples.".format(len(train_treebank)), flush=True)

    print("Loading development trees from {}...".format(args.dev_path), flush=True)
    dev_treebank = trees.load_trees(args.dev_path)
    if hparams.max_len_dev > 0:
        dev_treebank = [
            tree for tree in dev_treebank
            if len(list(tree.leaves())) <= hparams.max_len_dev
        ]
    print("Loaded {:,} development examples.".format(len(dev_treebank)), flush=True)

    print("Processing trees for training...", flush=True)
    train_parse = [tree.convert() for tree in train_treebank]

    print("Constructing vocabularies...", flush=True)

    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())

    char_set = set()

    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)),
            flush=True)

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)

    print("Initializing model...", flush=True)

    load_path = None
    if load_path is not None:
        print(f"Loading parameters from {load_path}", flush=True)
        info = torch_load(load_path)
        parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])
    else:
        parser = parse_nk.NKChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )

    print("Initializing optimizer...", flush=True)
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    trainer = torch.optim.Adam(trainable_parameters, lr=1., betas=(0.9, 0.98), eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer, 'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm

    print("Training...", flush=True)
    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    dev_efscore = None

    start_time = time.time()

    def check_dev():
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal dev_efscore

        dev_start_time = time.time()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index + args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted)

        print(" dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(
                  dev_fscore,
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ),
              flush=True)

        dev_efscore = evaluate_EDITED.Evaluate(dev_treebank, dev_predicted)
        print(" dev-Efscore: {}".format(dev_efscore), flush=True)

        # MJ - keep model with best efscore
        if dev_efscore.efscore > best_dev_fscore:
            best_dev_fscore = dev_efscore.efscore

            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print(" Removing previous model file {}...".format(path),
                              flush=True)
                        os.remove(path)

            best_dev_model_path = "{}_Edev={:.4}".format(
                args.model_path_base, best_dev_fscore)
            print(" Saving new best model to {}...".format(best_dev_model_path),
                  flush=True)
            torch.save(
                {
                    'spec': parser.spec,
                    'state_dict': parser.state_dict(),
                    'trainer': trainer.state_dict(),
                }, best_dev_model_path + ".pt")

    def check_hurdle(epoch, hurdle):
        if best_dev_fscore < hurdle:
            message = ("FAILURE: Epoch {} hurdle failed, stopping now!\n"
                       "best_dev_fscore = {} < epoch{}_hurdle = {}".format(
                           epoch, best_dev_fscore, epoch, hurdle))
            print(message, flush=True)
            if args.results_path:
                print(message, file=open(args.results_path, 'w'), flush=True)
            sys.exit(message)

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        sum_grad_norm = 0
        sum_batch_loss_value = 0
        batch_processed = 0

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index + args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                               for tree in batch_trees]

            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)
                loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                len_subbatch_trees = len(subbatch_trees)
                total_processed += len_subbatch_trees
                current_processed += len_subbatch_trees
                batch_processed += len_subbatch_trees

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)
            sum_batch_loss_value += batch_loss_value
            sum_grad_norm += grad_norm

            trainer.step()

            if args.print_minibatches:
                print(" epoch {:,} "
                      "batch {:,}/{:,} "
                      "processed {:,} "
                      "batch-loss {:.4f} "
                      "grad-norm {:.4f} "
                      "epoch-elapsed {} "
                      "total-elapsed {}".format(
                          epoch,
                          start_index // args.batch_size + 1,
                          int(np.ceil(len(train_parse) / args.batch_size)),
                          total_processed,
                          batch_loss_value,
                          grad_norm,
                          format_elapsed(epoch_start_time),
                          format_elapsed(start_time)),
                      flush=True)

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()

        # adjust learning rate at the end of an epoch
        if hparams.step_decay:
            if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
                scheduler.step(best_dev_fscore)

        print("Epoch {:,} "
              "total-processed {:,} "
              "average-batch-loss {:.4} "
              "average-grad-norm {:.4} "
              "total-elapsed {}".format(epoch,
                                        total_processed,
                                        sum_batch_loss_value / batch_processed,
                                        sum_grad_norm / batch_processed,
                                        format_elapsed(start_time)),
              flush=True)

        if epoch == 1:
            check_hurdle(epoch, args.epoch1_hurdle)
        elif epoch == 10:
            check_hurdle(epoch, args.epoch10_hurdle)

    if args.results_path:
        outf = open(args.results_path, 'w')
        print(dev_efscore.table(), file=outf, flush=True)
    else:
        print(dev_efscore.table(), flush=True)
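
# --- Illustrative usage sketch (assumptions flagged below) -------------------
# How the run_train variant above might be driven directly, without a
# command-line front end. The attribute names mirror exactly what run_train
# reads from `args`; the concrete values and file paths are hypothetical, and
# make_hparams() is assumed to be a helper defined elsewhere in this module
# that builds the hparams object consumed by hparams.set_from_args(args).
def _example_invocation():
    import argparse

    args = argparse.Namespace(
        numpy_seed=42,
        train_path="data/train.trees",        # hypothetical paths
        dev_path="data/dev.trees",
        evalb_dir="EVALB/",
        model_path_base="models/example",
        load_path=None,
        epochs=50,
        checks_per_epoch=4,
        batch_size=250,
        eval_batch_size=100,
        subbatch_max_tokens=2000,
        print_vocabs=False,
        print_minibatches=False,
        results_path=None,
        epoch1_hurdle=0.0,
        epoch10_hurdle=0.0,
    )
    hparams = make_hparams()   # assumed helper defined elsewhere in this module
    run_train(args, hparams)
# -----------------------------------------------------------------------------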