def run(options):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    train_dataset, validation_dataset = get_train_and_validation(options)
    train_iterator = get_train_iterator(options, train_dataset)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    embeddings = train_dataset['embeddings']

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    logger.info('Model:')
    for name, p in trainer.net.named_parameters():
        logger.info('{} {}'.format(name, p.shape))

    if options.save_init:
        logger.info('Saving model (init).')
        trainer.save_model(os.path.join(options.experiment_path, 'model_init.pt'))

    if options.parse_only:
        run_parse(options, train_iterator, trainer, validation_iterator)
        sys.exit()

    run_train(options, train_iterator, trainer, validation_iterator)
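
# `generate_seeds(options.max_epoch, options.seed)` used below is assumed to be
# imported from the repo's utilities: it yields one reproducible integer seed
# per epoch, derived from the master seed. A minimal sketch of that assumed
# behavior (not the repo's implementation):
#
#     import random
#
#     def generate_seeds(n, seed):
#         random.seed(seed)
#         return [random.randint(0, 2**16) for _ in range(n)]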
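
# `tree_to_spans`, `str_to_tuple`, `replace_leaves`, and `precision_and_recall`
# used in `run_train` below are assumed to come from the repo's parse-evaluation
# utilities. A rough sketch of the assumed span extraction (an assumption, not
# the repo's implementation):
#
#     def tree_to_spans(tree):
#         # Collect (start, end) token indices for every internal constituent
#         # of a nested-tuple tree; each leaf contributes length 1.
#         spans = []
#         def helper(node, pos=0):
#             if not isinstance(node, (list, tuple)):
#                 return 1
#             size = 0
#             for child in node:
#                 size += helper(child, pos + size)
#             spans.append((pos, pos + size))
#             return size
#         helper(tree)
#         return spans
#
# `precision_and_recall(gold_spans, pred_spans)` is then assumed to return raw
# matched-span counts plus a normalizer, which `run_train` accumulates per epoch.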

def run_train(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    seeds = generate_seeds(options.max_epoch, options.seed)

    step = 0

    # Lookup table and CKY parser used to evaluate parse quality during training.
    idx2word = {v: k for k, v in train_iterator.word2idx.items()}
    parse_predictor = CKY(net=trainer.net.diora, word2idx=train_iterator.word2idx)

    for epoch, seed in zip(range(options.max_epoch), seeds):
        # --- Train --- #

        # Per-epoch accumulators for unlabeled span precision/recall.
        precision = 0
        recall = 0
        total_len = 0

        logger.info('epoch={} seed={}'.format(epoch, seed))

        def myiterator():
            it = train_iterator.get_iterator(random_seed=seed)
            count = 0
            for batch_map in it:
                # TODO: Skip short examples (optionally).
                if batch_map['length'] <= 2:
                    continue
                yield count, batch_map
                count += 1

        for batch_idx, batch_map in myiterator():
            if options.finetune and step >= options.finetune_after:
                trainer.freeze_diora()

            result = trainer.step(batch_map)

            # Parse the current batch with CKY and score the predicted trees
            # against the gold parses.
            trainer.net.eval()
            sentences = batch_map['sentences']
            trees = parse_predictor.parse_batch(batch_map)
            o_list = []
            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)
                o_list.append(o['tree'])

                # Gold trees may be serialized as strings.
                if isinstance(batch_map['parse_tree'][ii], str):
                    parse_tree_tuple = str_to_tuple(batch_map['parse_tree'][ii])
                else:
                    parse_tree_tuple = batch_map['parse_tree'][ii]

                o_spans = tree_to_spans(o['tree'])
                batch_spans = tree_to_spans(parse_tree_tuple[0])
                p, r, t = precision_and_recall(batch_spans, o_spans)
                precision += p
                recall += r
                total_len += t
            trainer.net.train()

            experiment_logger.record(result)

            if step % options.log_every_batch == 0:
                experiment_logger.log_batch(epoch, step, batch_idx, batch_size=options.batch_size)

            # -- Periodic Checkpoints -- #
            if not options.multigpu or options.local_rank == 0:
                if step % options.save_latest == 0 and step >= options.save_after:
                    logger.info('Saving model (periodic).')
                    trainer.save_model(
                        os.path.join(options.experiment_path, 'model_periodic.pt'))
                    save_experiment(
                        os.path.join(options.experiment_path, 'experiment_periodic.json'), step)

                if step % options.save_distinct == 0 and step >= options.save_after:
                    logger.info('Saving model (distinct).')
                    trainer.save_model(
                        os.path.join(options.experiment_path, 'model.step_{}.pt'.format(step)))
                    save_experiment(
                        os.path.join(options.experiment_path, 'experiment.step_{}.json'.format(step)), step)

            del result

            step += 1

        # Report epoch-level parse scores (guard against an empty epoch).
        if total_len > 0:
            logger.info('precision={} recall={} total={}'.format(precision, recall, total_len))
            logger.info('precision={:.4f} recall={:.4f}'.format(
                precision / total_len, recall / total_len))

        experiment_logger.log_epoch(epoch, step)

        if options.max_step is not None and step >= options.max_step:
            logger.info('Max-Step={} Quitting.'.format(options.max_step))
            sys.exit()
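
# Epoch-level unlabeled F1 follows from the accumulators above via the usual
# harmonic mean. (`unlabeled_f1` is a hypothetical helper added here for
# illustration; it is not part of the original repo.)
def unlabeled_f1(precision, recall, total_len):
    # Normalize the raw counts, then take the harmonic mean. The guards keep
    # an empty epoch from dividing by zero.
    if total_len == 0 or (precision + recall) == 0:
        return 0.0
    p = precision / total_len
    r = recall / total_len
    return 2.0 * p * r / (p + r)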