def main():
    """Train a knowledge-grounded dialogue model (TF2).

    End-to-end driver: parses CLI args into hparams, configures logging /
    GPUs / seeds, downloads BERT weights if absent, builds the dataset
    reader, model, optimizer and Trainer, then runs the training loop with
    periodic evaluation, TensorBoard summaries and checkpointing.
    """
    # Argument passing/parsing
    args, model_args = config_utils.initialize_argparser(
        MODELS, _command_args, custom_argparsers.DialogArgumentParser)
    hparams, hparams_dict = config_utils.create_or_load_hparams(
        args, model_args, args.cfg)
    pprint(hparams_dict)

    # Set environment variables & gpus
    set_logger()
    set_gpus(hparams.gpus)
    set_tcmalloc()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus, 'GPU')
    for gpu in gpus:
        # Grow GPU memory on demand instead of grabbing it all up front.
        tf.config.experimental.set_memory_growth(gpu, True)

    # Set random seed (TF, NumPy and stdlib all seeded for reproducibility)
    tf.random.set_seed(hparams.random_seed)
    np.random.seed(hparams.random_seed)
    random.seed(hparams.random_seed)

    # For multi-gpu
    if hparams.num_gpus > 1:
        mirrored_strategy = tf.distribute.MirroredStrategy()  # NCCL will be used as default
    else:
        mirrored_strategy = None

    # Download BERT pretrained model (skipped when bert_dir already exists)
    if not os.path.exists(hparams.bert_dir):
        os.makedirs(hparams.bert_dir)
        fname = 'uncased_L-12_H-768_A-12.zip'
        gd_id = '17rfV9CleFBwwfS7m5Yd72vvxdPLWBHl6'
        download_from_google_drive(gd_id, os.path.join(hparams.bert_dir, fname))
        unzip(hparams.bert_dir, fname)

    # Make dataset reader
    os.makedirs(hparams.cache_dir, exist_ok=True)
    if hparams.data_name == "wizard_of_wikipedia":
        reader_cls = WowDatasetReader
    elif hparams.data_name == "holle":
        reader_cls = HolleDatasetReader
    else:
        raise ValueError("data_name must be one of 'wizard_of_wikipedia' and 'holle'")
    reader = reader_cls(
        hparams.batch_size, hparams.num_epochs,
        buffer_size=hparams.buffer_size,
        bucket_width=hparams.bucket_width,
        max_length=hparams.max_length,
        max_episode_length=hparams.max_episode_length,
        max_knowledge=hparams.max_knowledge,
        knowledge_truncate=hparams.knowledge_truncate,
        cache_dir=hparams.cache_dir,
        bert_dir=hparams.bert_dir,
    )
    train_dataset, iters_in_train = reader.read('train', mirrored_strategy)
    test_dataset, iters_in_test = reader.read('test', mirrored_strategy)
    # WoW additionally ships a "test_unseen" split (topics not seen in training).
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_dataset, iters_in_unseen = reader.read('test_unseen', mirrored_strategy)
    vocabulary = reader.vocabulary

    # Build model & optimizer & trainer.
    # Variables must be created under the strategy scope for multi-GPU runs.
    if mirrored_strategy:
        with mirrored_strategy.scope():
            model = MODELS[hparams.model](hparams, vocabulary)
            optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                                 clipnorm=hparams.clipnorm)
    else:
        model = MODELS[hparams.model](hparams, vocabulary)
        optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                             clipnorm=hparams.clipnorm)
    trainer = Trainer(model, optimizer, mirrored_strategy,
                      hparams.enable_function,
                      WowDatasetReader.remove_pad)

    # misc (tensorboard, checkpoints)
    file_writer = tf.summary.create_file_writer(hparams.checkpoint_dir)
    file_writer.set_as_default()
    global_step = tf.compat.v1.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,
                                     optimizer_step=global_step)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                    directory=hparams.checkpoint_dir,
                                                    max_to_keep=hparams.max_to_keep)
    # Tracks the best-scoring checkpoints separately from the rolling manager.
    checkpoint_tracker = CheckpointTracker(
        hparams.checkpoint_dir, max_to_keep=BEST_N_CHECKPOINTS)

    # Main loop!
    train_dataset_iter = iter(train_dataset)
    for epoch in range(hparams.num_epochs):
        print(hparams.checkpoint_dir)
        base_description = f"(Train) Epoch {epoch}, GPU {hparams.gpus}"
        train_tqdm = trange(iters_in_train, ncols=120, desc=base_description)
        for current_step in train_tqdm:
            example = next(train_dataset_iter)
            global_step.assign_add(1)
            _global_step = int(global_step)

            # Train
            output_dict = trainer.train_step(example)

            # Print model summary once, after the first step has built variables.
            if _global_step == 1:
                model.print_model()

            loss_str = str(output_dict['loss'].numpy())
            train_tqdm.set_description(f"{base_description}, Loss {loss_str}")
            with file_writer.as_default():
                if _global_step % int(hparams.logging_step) == 0:
                    tf.summary.histogram('train/vocab', output_dict['sample_ids'], step=_global_step)
                    tf.summary.scalar('train/loss', output_dict['loss'], step=_global_step)
                    tf.summary.scalar('train/gen_loss', output_dict['gen_loss'], step=_global_step)
                    tf.summary.scalar('train/knowledge_loss', output_dict['knowledge_loss'], step=_global_step)
                    tf.summary.scalar('train/kl_loss', output_dict['kl_loss'], step=_global_step)

            # Test: evaluate every `evaluation_epoch` fraction of an epoch.
            if _global_step % int(iters_in_train * hparams.evaluation_epoch) == 0:
                # save() takes a checkpoint_number; the global step names the file.
                checkpoint_manager.save(global_step)

                test_loop_outputs = trainer.test_loop(test_dataset, iters_in_test, epoch, 'seen')
                if hparams.data_name == 'wizard_of_wikipedia':
                    unseen_loop_outputs = trainer.test_loop(unseen_dataset, iters_in_unseen, epoch, 'unseen')

                test_summaries, log_dict = run_wow_evaluation(
                    test_loop_outputs, hparams.checkpoint_dir, 'seen')
                if hparams.data_name == 'wizard_of_wikipedia':
                    unseen_summaries, unseen_log_dict = run_wow_evaluation(
                        unseen_loop_outputs, hparams.checkpoint_dir, 'unseen')

                # Logging (tqdm.write keeps output above the progress bar)
                tqdm.write(colorful.bold_green("seen").styled_string)
                tqdm.write(colorful.bold_red(pformat(log_dict)).styled_string)
                if hparams.data_name == 'wizard_of_wikipedia':
                    tqdm.write(colorful.bold_green("unseen").styled_string)
                    tqdm.write(colorful.bold_red(pformat(unseen_log_dict)).styled_string)

                with file_writer.as_default():
                    for family, test_summary in test_summaries.items():
                        for key, value in test_summary.items():
                            tf.summary.scalar(f'{family}/{key}', value, step=_global_step)
                    if hparams.data_name == 'wizard_of_wikipedia':
                        for family, unseen_summary in unseen_summaries.items():
                            for key, value in unseen_summary.items():
                                tf.summary.scalar(f'{family}/{key}', value, step=_global_step)

                # Keep the N best checkpoints ranked by seen-split ROUGE-1.
                if hparams.keep_best_checkpoint:
                    current_score = log_dict["rouge1"]
                    checkpoint_tracker.update(current_score, _global_step)
def main():
    """Evaluate a pretrained knowledge-grounded dialogue model (TF2).

    Downloads the pretrained checkpoint selected by ``hparams.test_mode``
    ('wow', 'holle_1' or 'holle_2'), rebuilds the model/optimizer/Trainer
    exactly as in training, restores the checkpoint, and runs the test
    loop(s), printing evaluation metrics.
    """
    # Argument passing/parsing
    args, model_args = config_utils.initialize_argparser(
        MODELS, _command_args, custom_argparsers.DialogArgumentParser)
    hparams, hparams_dict = config_utils.create_or_load_hparams(
        args, model_args, args.cfg)
    pprint(hparams_dict)

    # Select and (if needed) download the pretrained checkpoint for this mode.
    if hparams.test_mode == 'wow':
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/wow_pretrained'):
            fname = 'wow_pretrained.zip'
            gd_id = '1lkF1QENr45j0vl-Oja3wEiqkxoNTxkXT'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/wow_pretrained', 'ckpt-46070')
    elif hparams.test_mode == "holle_1":
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/holle_pretrained_1'):
            fname = 'holle_pretrained_1.zip'
            gd_id = '1o1-Gv5PScxlSzxW6DyZnSp3gDI5zXOhh'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/holle_pretrained_1', 'ckpt-1th-best')
    elif hparams.test_mode == "holle_2":
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/holle_pretrained_2'):
            fname = 'holle_pretrained_2.zip'
            gd_id = '13FkCjuC0aBEenlSf-NAAgOfoWVPhqFSc'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/holle_pretrained_2', 'ckpt-1th-best')
    else:
        raise ValueError("'wow' and 'holle' is currently supported")

    # Set environment variables & gpus
    set_logger()
    set_gpus(hparams.gpus)
    set_tcmalloc()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus, 'GPU')
    for gpu in gpus:
        # Grow GPU memory on demand instead of grabbing it all up front.
        tf.config.experimental.set_memory_growth(gpu, True)

    # Set random seed
    tf.random.set_seed(hparams.random_seed)
    np.random.seed(hparams.random_seed)
    random.seed(hparams.random_seed)

    # For multi-gpu
    if hparams.num_gpus > 1:
        mirrored_strategy = tf.distribute.MirroredStrategy()  # NCCL will be used as default
    else:
        mirrored_strategy = None

    # Download BERT pretrained model (skipped when bert_dir already exists)
    if not os.path.exists(hparams.bert_dir):
        os.makedirs(hparams.bert_dir)
        fname = 'uncased_L-12_H-768_A-12.zip'
        gd_id = '17rfV9CleFBwwfS7m5Yd72vvxdPLWBHl6'
        download_from_google_drive(gd_id, os.path.join(hparams.bert_dir, fname))
        unzip(hparams.bert_dir, fname)

    # Make dataset reader
    os.makedirs(hparams.cache_dir, exist_ok=True)
    if hparams.data_name == 'wizard_of_wikipedia':
        reader_cls = WowDatasetReader
    elif hparams.data_name == 'holle':
        reader_cls = HolleDatasetReader
    else:
        raise ValueError(
            "data_name must be one of 'wizard_of_wikipedia' and 'holle'")
    reader = reader_cls(
        hparams.batch_size, hparams.num_epochs,
        buffer_size=hparams.buffer_size,
        bucket_width=hparams.bucket_width,
        max_length=hparams.max_length,
        max_episode_length=hparams.max_episode_length,
        max_knowledge=hparams.max_knowledge,
        knowledge_truncate=hparams.knowledge_truncate,
        cache_dir=hparams.cache_dir,
        bert_dir=hparams.bert_dir,
    )
    train_dataset, iters_in_train = reader.read('train', mirrored_strategy)
    test_dataset, iters_in_test = reader.read('test', mirrored_strategy)
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_dataset, iters_in_unseen = reader.read('test_unseen',
                                                      mirrored_strategy)
    vocabulary = reader.vocabulary

    # Build model & optimizer & trainer.
    # Variables must be created under the strategy scope for multi-GPU runs.
    if mirrored_strategy:
        with mirrored_strategy.scope():
            model = MODELS[hparams.model](hparams, vocabulary)
            optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                                 clipnorm=hparams.clipnorm)
    else:
        model = MODELS[hparams.model](hparams, vocabulary)
        optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                             clipnorm=hparams.clipnorm)
    trainer = Trainer(model, optimizer, mirrored_strategy,
                      hparams.enable_function,
                      WowDatasetReader.remove_pad)

    # Setup checkpoint
    global_step = tf.compat.v1.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,
                                     optimizer_step=global_step)

    # Load: run a single train step first so all variables are created
    # before restore; otherwise the restore would have nothing to match.
    train_example = next(iter(train_dataset))
    _ = trainer.train_step(train_example)
    # Stricter restore variants kept for reference/debugging:
    #checkpoint.restore(ckpt_fname).assert_consumed()
    #checkpoint.restore(ckpt_fname).expect_partial()
    checkpoint.restore(ckpt_fname)

    # Test
    test_loop_outputs = trainer.test_loop(test_dataset, iters_in_test, 0, 'seen')
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_loop_outputs = trainer.test_loop(unseen_dataset, iters_in_unseen, 0, 'unseen')

    test_summaries, log_dict = run_wow_evaluation(test_loop_outputs,
                                                  hparams.checkpoint_dir, 'seen')
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_summaries, unseen_log_dict = run_wow_evaluation(
            unseen_loop_outputs, hparams.checkpoint_dir, 'unseen')

    # Logging (tqdm.write keeps output above any progress bars)
    tqdm.write(colorful.bold_green("seen").styled_string)
    tqdm.write(colorful.bold_red(pformat(log_dict)).styled_string)
    if hparams.data_name == 'wizard_of_wikipedia':
        tqdm.write(colorful.bold_green("unseen").styled_string)
        tqdm.write(colorful.bold_red(pformat(unseen_log_dict)).styled_string)
def main():
    """Run the model in interactive chat mode (TF2, WoW checkpoint only).

    Restores the pretrained Wizard-of-Wikipedia checkpoint, builds the
    tokenizer/input processor, TF-IDF wiki retriever and topic generator,
    warms up the compiled graph, then loops forever serving interactive
    conversations.
    """
    # Argument passing/parsing
    args, model_args = config_utils.initialize_argparser(
        MODELS, _command_args, custom_argparsers.DialogArgumentParser)
    hparams, hparams_dict = config_utils.create_or_load_hparams(
        args, model_args, args.cfg)
    pprint(hparams_dict)

    # Download the pretrained WoW checkpoint if not cached locally.
    if hparams.test_mode == 'wow':
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/wow_pretrained'):
            fname = 'wow_pretrained.zip'
            gd_id = '1lkF1QENr45j0vl-Oja3wEiqkxoNTxkXT'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/wow_pretrained', 'ckpt-46070')
    else:
        raise ValueError("Only 'wow' is currently supported")

    # Set environment variables & gpus
    set_logger()
    set_gpus(hparams.gpus)
    set_tcmalloc()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus, 'GPU')
    for gpu in gpus:
        # Grow GPU memory on demand instead of grabbing it all up front.
        tf.config.experimental.set_memory_growth(gpu, True)

    # Set random seed — deliberately disabled here so interactive
    # sampling is not fixed across runs (kept for reference).
    #tf.random.set_seed(hparams.random_seed)
    #np.random.seed(hparams.random_seed)
    #random.seed(hparams.random_seed)

    # Set gpu: interactive mode only supports a single GPU, no strategy.
    assert hparams.num_gpus == 1
    mirrored_strategy = None

    # Make dataset reader
    os.makedirs(hparams.cache_dir, exist_ok=True)
    reader = WowDatasetReader(
        hparams.batch_size, hparams.num_epochs,
        buffer_size=hparams.buffer_size,
        bucket_width=hparams.bucket_width,
        max_length=hparams.max_length,
        max_episode_length=hparams.max_episode_length,
        max_knowledge=hparams.max_knowledge,
        knowledge_truncate=hparams.knowledge_truncate,
        cache_dir=hparams.cache_dir,
        bert_dir=hparams.bert_dir,
    )
    train_dataset, iters_in_train = reader.read('train', mirrored_strategy)
    test_dataset, iters_in_test = reader.read('test', mirrored_strategy)
    vocabulary = reader.vocabulary

    # Build model & optimizer & trainer
    model = MODELS[hparams.model](hparams, vocabulary)
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                         clipnorm=hparams.clipnorm)
    trainer = Trainer(model, optimizer, mirrored_strategy,
                      hparams.enable_function,
                      WowDatasetReader.remove_pad)

    # Setup checkpoint
    global_step = tf.compat.v1.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,
                                     optimizer_step=global_step)
    # Run one train step so all variables exist before restoring weights.
    train_example = next(iter(train_dataset))
    _ = trainer.train_step(train_example)
    checkpoint.restore(ckpt_fname)

    # Load retriever and input processor.
    # tokenize_fn wraps BERT wordpiece ids with [CLS] ... [SEP].
    dictionary = reader._dictionary
    tokenize_fn = lambda x: [data_vocab.BERT_CLS_ID] \
        + dictionary.convert_tokens_to_ids(dictionary.tokenize(x)) \
        + [data_vocab.BERT_SEP_ID]
    input_processor = InteractiveInputProcessor(tokenize_fn, 5)

    # Compile graph: warm up tf.function tracing with dummy inputs so the
    # first real user turn is not slowed by compilation.
    colorlog.info("Compile model")
    dummy_input = input_processor.get_dummy_input()
    for _ in trange(5, ncols=70):
        trainer.test_step(dummy_input)

    # Module for interactive mode
    wiki_tfidf_retriever = WikiTfidfRetriever(hparams.cache_dir)
    topics_generator = TopicsGenerator(hparams.cache_dir)
    interactive_world = InteractiveWorld(
        responder=trainer,
        input_processor=input_processor,
        wiki_retriever=wiki_tfidf_retriever,
        topics_generator=topics_generator)

    # Loop! One run() per conversation, then reset state for the next.
    while True:
        interactive_world.run()
        interactive_world.reset()
def main(args):
    """Legacy TF1 graph/session training driver.

    Builds separate train and test graphs/sessions, optionally applies
    pretrained word vectors to a fresh model, logs hparams and example
    outputs to TensorBoard, and trains until ``cfg.end_epoch`` with
    periodic checkpointing and evaluation.

    Args:
        args: parsed command-line arguments consumed by
            ``create_or_load_hparams``.
    """
    cfg = create_or_load_hparams(args)
    set_logger()

    # Build graph — train and test each get their own Graph so the test
    # model can be (re)loaded from checkpoints independently.
    train_graph = tf.Graph()
    test_graph = tf.Graph()
    train_model, iters_in_train, train_iterator_init = create_graph(
        cfg, train_graph, "train")
    test_model, iters_in_test, test_iterator_init = create_graph(
        cfg, test_graph, "test")

    # Build session (one per graph, sharing the same GPU options)
    config_proto = tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True, per_process_gpu_memory_fraction=0.95),
        allow_soft_placement=True, log_device_placement=False)
    train_session = tf.Session(config=config_proto, graph=train_graph)
    test_session = tf.Session(config=config_proto, graph=test_graph)

    # Load pretrained word vectors
    pretrained_word_vector = load_pretrained_word_vector(cfg)

    # Initialize models/vocab/iterator
    colorlog.info("Initialize models, vocab, iterator")
    with train_graph.as_default():
        loaded_train_model, is_fresh = create_or_load_model(
            train_model, cfg.checkpoint_dir, train_session, "train")
        # Only overwrite embeddings when starting from scratch; a restored
        # model keeps its trained embeddings.
        if is_fresh and pretrained_word_vector is not None:
            _ = loaded_train_model.apply_word_vector(train_graph, train_session,
                                                     pretrained_word_vector)
        train_session.run(train_iterator_init)

    # Summary writer
    summary_writer = tf.summary.FileWriter(cfg.checkpoint_dir, train_graph)

    # Print Hparams
    cfg_dict = json.loads(cfg.to_json())
    for key, value in cfg_dict.items():
        print(key, ":", value)

    # Add Hparams to Tensorboard (as a text summary table)
    with train_graph.as_default():
        config_summary = tf.summary.text(
            'TrainConfig', tf.convert_to_tensor(dict_to_matrix(cfg_dict)))
        summary_writer.add_summary(config_summary.eval(session=train_session))

    # Build graph in order to add examples to tensorboard
    with test_graph.as_default():
        examples_placeholder = tf.placeholder(tf.string, [None, 2])
        examples_summary = tf.summary.text("GT and Pred", examples_placeholder)

    # Build auxiliary checkpoint saving directory
    # NOTE(review): no exist_ok — this raises if 'auxiliary' already exists;
    # presumably checkpoint_dir is always fresh here — confirm.
    os.makedirs(os.path.join(cfg.checkpoint_dir, 'auxiliary'))

    colorlog.info("Start training")
    current_epoch = 0
    while True:
        current_epoch += 1
        # Train
        tqdm.write(cfg.other_info)
        for current_step in tqdm(range(iters_in_train),
                                 desc="Epoch {} (train), {}, {}".format(
                                     current_epoch, cfg.model_name,
                                     cfg.data_name),
                                 ncols=50):
            step_result = loaded_train_model.train(train_session)
            global_step = step_result['global_step']
            summaries = step_result['summaries']

            # Logging every 20 steps
            if (global_step + 1) % 20 == 0:
                for summary in summaries:
                    summary_writer.add_summary(summary, global_step)

            # Test: every `evaluation_epoch` fraction of an epoch
            if (global_step + 1) % int(
                    iters_in_train * cfg.evaluation_epoch) == 0:
                # Save checkpoint
                loaded_train_model.saver.save(
                    train_session,
                    os.path.join(cfg.checkpoint_dir, "model.ckpt"),
                    global_step=loaded_train_model.global_step)
                # Load checkpoint into the separate test graph/session
                with test_graph.as_default():
                    loaded_test_model, _ = create_or_load_model(
                        test_model, cfg.checkpoint_dir, test_session, "test")
                # Run test
                test_session.run(test_iterator_init)
                test_steps = min(cfg.num_test_steps, iters_in_test)
                test_summary, log_dict, examples = run_evaluation(
                    loaded_test_model, test_session, current_epoch, test_steps,
                    cfg, global_step)
                summary_writer.add_summary(test_summary, global_step)

                # Add examples to tensorboard via the string placeholder
                examples_summary_str = test_session.run(
                    examples_summary,
                    feed_dict={examples_placeholder: examples})
                summary_writer.add_summary(examples_summary_str, global_step)
        tqdm.write(cfg.other_info)
        if current_epoch == cfg.end_epoch:
            break