def main():
    # Argument passing/parsing
    args, model_args = config_utils.initialize_argparser(
        MODELS, _command_args, custom_argparsers.DialogArgumentParser)
    hparams, hparams_dict = config_utils.create_or_load_hparams(
        args, model_args, args.cfg)
    pprint(hparams_dict)

    if hparams.test_mode == 'wow':
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/wow_pretrained'):
            fname = 'wow_pretrained.zip'
            gd_id = '1lkF1QENr45j0vl-Oja3wEiqkxoNTxkXT'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/wow_pretrained', 'ckpt-46070')
    else:
        raise ValueError("Only 'wow' is currently supported")

    # Set environment variables & gpus
    set_logger()
    set_gpus(hparams.gpus)
    set_tcmalloc()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus, 'GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Set random seed
    #tf.random.set_seed(hparams.random_seed)
    #np.random.seed(hparams.random_seed)
    #random.seed(hparams.random_seed)

    # Set gpu
    assert hparams.num_gpus == 1
    mirrored_strategy = None

    # Make dataset reader
    os.makedirs(hparams.cache_dir, exist_ok=True)
    reader = WowDatasetReader(
        hparams.batch_size, hparams.num_epochs,
        buffer_size=hparams.buffer_size,
        bucket_width=hparams.bucket_width,
        max_length=hparams.max_length,
        max_episode_length=hparams.max_episode_length,
        max_knowledge=hparams.max_knowledge,
        knowledge_truncate=hparams.knowledge_truncate,
        cache_dir=hparams.cache_dir,
        bert_dir=hparams.bert_dir,
    )
    train_dataset, iters_in_train = reader.read('train', mirrored_strategy)
    test_dataset, iters_in_test = reader.read('test', mirrored_strategy)
    vocabulary = reader.vocabulary

    # Build model & optimizer & trainer
    model = MODELS[hparams.model](hparams, vocabulary)
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                         clipnorm=hparams.clipnorm)
    trainer = Trainer(model, optimizer, mirrored_strategy,
                      hparams.enable_function,
                      WowDatasetReader.remove_pad)

    # Setup checkpoint
    global_step = tf.compat.v1.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     optimizer_step=global_step)

    train_example = next(iter(train_dataset))
    _ = trainer.train_step(train_example)
    checkpoint.restore(ckpt_fname)

    # Load retriever and input processor
    dictionary = reader._dictionary
    tokenize_fn = lambda x: [data_vocab.BERT_CLS_ID] \
        + dictionary.convert_tokens_to_ids(dictionary.tokenize(x)) \
        + [data_vocab.BERT_SEP_ID]
    input_processor = InteractiveInputProcessor(tokenize_fn, 5)

    # Compile graph
    colorlog.info("Compile model")
    dummy_input = input_processor.get_dummy_input()
    for _ in trange(5, ncols=70):
        trainer.test_step(dummy_input)

    # Module for interactive mode
    wiki_tfidf_retriever = WikiTfidfRetriever(hparams.cache_dir)
    topics_generator = TopicsGenerator(hparams.cache_dir)
    interactive_world = InteractiveWorld(
        responder=trainer,
        input_processor=input_processor,
        wiki_retriever=wiki_tfidf_retriever,
        topics_generator=topics_generator)

    # Loop!
    while True:
        interactive_world.run()
        interactive_world.reset()
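# NOTE: `download_from_google_drive` and `unzip` are repo utilities whose bodies are
# not shown here. The sketch below is a hypothetical stand-in (names and behavior
# assumed, not the repo's actual implementation) illustrating the usual requests-based
# pattern for fetching a shared Google Drive file, including the large-file
# "confirm" token handshake.
import requests

def download_from_google_drive(gd_id, destination, chunk_size=32768):
    """Stream the publicly shared Google Drive file `gd_id` to `destination`."""
    url = 'https://docs.google.com/uc?export=download'
    session = requests.Session()
    response = session.get(url, params={'id': gd_id}, stream=True)

    # Large files trigger a virus-scan warning page; re-request with the confirm token.
    token = next((value for key, value in response.cookies.items()
                  if key.startswith('download_warning')), None)
    if token:
        response = session.get(url, params={'id': gd_id, 'confirm': token}, stream=True)

    with open(destination, 'wb') as f:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)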
def main():
    # Argument passing/parsing
    args, model_args = config_utils.initialize_argparser(
        MODELS, _command_args, custom_argparsers.DialogArgumentParser)
    hparams, hparams_dict = config_utils.create_or_load_hparams(
        args, model_args, args.cfg)
    pprint(hparams_dict)

    # Set environment variables & gpus
    set_logger()
    set_gpus(hparams.gpus)
    set_tcmalloc()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus, 'GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Set random seed
    tf.random.set_seed(hparams.random_seed)
    np.random.seed(hparams.random_seed)
    random.seed(hparams.random_seed)

    # For multi-gpu
    if hparams.num_gpus > 1:
        mirrored_strategy = tf.distribute.MirroredStrategy()  # NCCL will be used as default
    else:
        mirrored_strategy = None

    # Download BERT pretrained model
    if not os.path.exists(hparams.bert_dir):
        os.makedirs(hparams.bert_dir)
        fname = 'uncased_L-12_H-768_A-12.zip'
        gd_id = '17rfV9CleFBwwfS7m5Yd72vvxdPLWBHl6'
        download_from_google_drive(gd_id, os.path.join(hparams.bert_dir, fname))
        unzip(hparams.bert_dir, fname)

    # Make dataset reader
    os.makedirs(hparams.cache_dir, exist_ok=True)
    if hparams.data_name == "wizard_of_wikipedia":
        reader_cls = WowDatasetReader
    elif hparams.data_name == "holle":
        reader_cls = HolleDatasetReader
    else:
        raise ValueError("data_name must be one of 'wizard_of_wikipedia' and 'holle'")
    reader = reader_cls(
        hparams.batch_size, hparams.num_epochs,
        buffer_size=hparams.buffer_size,
        bucket_width=hparams.bucket_width,
        max_length=hparams.max_length,
        max_episode_length=hparams.max_episode_length,
        max_knowledge=hparams.max_knowledge,
        knowledge_truncate=hparams.knowledge_truncate,
        cache_dir=hparams.cache_dir,
        bert_dir=hparams.bert_dir,
    )
    train_dataset, iters_in_train = reader.read('train', mirrored_strategy)
    test_dataset, iters_in_test = reader.read('test', mirrored_strategy)
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_dataset, iters_in_unseen = reader.read('test_unseen', mirrored_strategy)
    vocabulary = reader.vocabulary

    # Build model & optimizer & trainer
    if mirrored_strategy:
        with mirrored_strategy.scope():
            model = MODELS[hparams.model](hparams, vocabulary)
            optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                                 clipnorm=hparams.clipnorm)
    else:
        model = MODELS[hparams.model](hparams, vocabulary)
        optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                             clipnorm=hparams.clipnorm)
    trainer = Trainer(model, optimizer, mirrored_strategy,
                      hparams.enable_function,
                      WowDatasetReader.remove_pad)

    # misc (tensorboard, checkpoints)
    file_writer = tf.summary.create_file_writer(hparams.checkpoint_dir)
    file_writer.set_as_default()
    global_step = tf.compat.v1.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     optimizer_step=global_step)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=hparams.checkpoint_dir,
        max_to_keep=hparams.max_to_keep)
    checkpoint_tracker = CheckpointTracker(
        hparams.checkpoint_dir, max_to_keep=BEST_N_CHECKPOINTS)

    # Main loop!
    train_dataset_iter = iter(train_dataset)
    for epoch in range(hparams.num_epochs):
        print(hparams.checkpoint_dir)
        base_description = f"(Train) Epoch {epoch}, GPU {hparams.gpus}"
        train_tqdm = trange(iters_in_train, ncols=120, desc=base_description)
        for current_step in train_tqdm:
            example = next(train_dataset_iter)
            global_step.assign_add(1)
            _global_step = int(global_step)

            # Train
            output_dict = trainer.train_step(example)

            # Print model
            if _global_step == 1:
                model.print_model()

            loss_str = str(output_dict['loss'].numpy())
            train_tqdm.set_description(f"{base_description}, Loss {loss_str}")
            with file_writer.as_default():
                if _global_step % int(hparams.logging_step) == 0:
                    tf.summary.histogram('train/vocab', output_dict['sample_ids'], step=_global_step)
                    tf.summary.scalar('train/loss', output_dict['loss'], step=_global_step)
                    tf.summary.scalar('train/gen_loss', output_dict['gen_loss'], step=_global_step)
                    tf.summary.scalar('train/knowledge_loss', output_dict['knowledge_loss'], step=_global_step)
                    tf.summary.scalar('train/kl_loss', output_dict['kl_loss'], step=_global_step)

            # Test
            if _global_step % int(iters_in_train * hparams.evaluation_epoch) == 0:
                checkpoint_manager.save(global_step)

                test_loop_outputs = trainer.test_loop(test_dataset, iters_in_test, epoch, 'seen')
                if hparams.data_name == 'wizard_of_wikipedia':
                    unseen_loop_outputs = trainer.test_loop(unseen_dataset, iters_in_unseen, epoch, 'unseen')

                test_summaries, log_dict = run_wow_evaluation(
                    test_loop_outputs, hparams.checkpoint_dir, 'seen')
                if hparams.data_name == 'wizard_of_wikipedia':
                    unseen_summaries, unseen_log_dict = run_wow_evaluation(
                        unseen_loop_outputs, hparams.checkpoint_dir, 'unseen')

                # Logging
                tqdm.write(colorful.bold_green("seen").styled_string)
                tqdm.write(colorful.bold_red(pformat(log_dict)).styled_string)
                if hparams.data_name == 'wizard_of_wikipedia':
                    tqdm.write(colorful.bold_green("unseen").styled_string)
                    tqdm.write(colorful.bold_red(pformat(unseen_log_dict)).styled_string)

                with file_writer.as_default():
                    for family, test_summary in test_summaries.items():
                        for key, value in test_summary.items():
                            tf.summary.scalar(f'{family}/{key}', value, step=_global_step)
                    if hparams.data_name == 'wizard_of_wikipedia':
                        for family, unseen_summary in unseen_summaries.items():
                            for key, value in unseen_summary.items():
                                tf.summary.scalar(f'{family}/{key}', value, step=_global_step)

                if hparams.keep_best_checkpoint:
                    current_score = log_dict["rouge1"]
                    checkpoint_tracker.update(current_score, _global_step)
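# NOTE: `CheckpointTracker` is a repo utility (keeping the best N checkpoints by
# ROUGE-1); its implementation is not shown here. The sketch below is a hypothetical
# best-N tracker with the same `update(score, step)` interface, included only to
# illustrate the idea (assumed behavior, not the repo's actual code).
import heapq

class CheckpointTracker:
    def __init__(self, checkpoint_dir, max_to_keep=5):
        self.checkpoint_dir = checkpoint_dir
        self.max_to_keep = max_to_keep
        self._heap = []  # min-heap of (score, step); the worst retained score sits on top

    def update(self, score, step):
        """Record (score, step); return True if this checkpoint ranks among the best N."""
        if len(self._heap) < self.max_to_keep:
            heapq.heappush(self._heap, (score, step))
            return True
        if score > self._heap[0][0]:
            heapq.heapreplace(self._heap, (score, step))
            return True
        return False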
def main():
    # Argument passing/parsing
    args, model_args = config_utils.initialize_argparser(
        MODELS, _command_args, custom_argparsers.DialogArgumentParser)
    hparams, hparams_dict = config_utils.create_or_load_hparams(
        args, model_args, args.cfg)
    pprint(hparams_dict)

    if hparams.test_mode == 'wow':
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/wow_pretrained'):
            fname = 'wow_pretrained.zip'
            gd_id = '1lkF1QENr45j0vl-Oja3wEiqkxoNTxkXT'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/wow_pretrained', 'ckpt-46070')
    elif hparams.test_mode == "holle_1":
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/holle_pretrained_1'):
            fname = 'holle_pretrained_1.zip'
            gd_id = '1o1-Gv5PScxlSzxW6DyZnSp3gDI5zXOhh'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/holle_pretrained_1', 'ckpt-1th-best')
    elif hparams.test_mode == "holle_2":
        os.makedirs('./tmp', exist_ok=True)
        if not os.path.exists('tmp/holle_pretrained_2'):
            fname = 'holle_pretrained_2.zip'
            gd_id = '13FkCjuC0aBEenlSf-NAAgOfoWVPhqFSc'
            colorlog.info(f"Download pretrained checkpoint {fname}")
            download_from_google_drive(gd_id, os.path.join('tmp', fname))
            unzip('tmp', fname)
        ckpt_fname = os.path.join('tmp/holle_pretrained_2', 'ckpt-1th-best')
    else:
        raise ValueError("Only 'wow', 'holle_1', and 'holle_2' are currently supported")

    # Set environment variables & gpus
    set_logger()
    set_gpus(hparams.gpus)
    set_tcmalloc()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus, 'GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    # Set random seed
    tf.random.set_seed(hparams.random_seed)
    np.random.seed(hparams.random_seed)
    random.seed(hparams.random_seed)

    # For multi-gpu
    if hparams.num_gpus > 1:
        mirrored_strategy = tf.distribute.MirroredStrategy()  # NCCL will be used as default
    else:
        mirrored_strategy = None

    # Download BERT pretrained model
    if not os.path.exists(hparams.bert_dir):
        os.makedirs(hparams.bert_dir)
        fname = 'uncased_L-12_H-768_A-12.zip'
        gd_id = '17rfV9CleFBwwfS7m5Yd72vvxdPLWBHl6'
        download_from_google_drive(gd_id, os.path.join(hparams.bert_dir, fname))
        unzip(hparams.bert_dir, fname)

    # Make dataset reader
    os.makedirs(hparams.cache_dir, exist_ok=True)
    if hparams.data_name == 'wizard_of_wikipedia':
        reader_cls = WowDatasetReader
    elif hparams.data_name == 'holle':
        reader_cls = HolleDatasetReader
    else:
        raise ValueError(
            "data_name must be one of 'wizard_of_wikipedia' and 'holle'")
    reader = reader_cls(
        hparams.batch_size, hparams.num_epochs,
        buffer_size=hparams.buffer_size,
        bucket_width=hparams.bucket_width,
        max_length=hparams.max_length,
        max_episode_length=hparams.max_episode_length,
        max_knowledge=hparams.max_knowledge,
        knowledge_truncate=hparams.knowledge_truncate,
        cache_dir=hparams.cache_dir,
        bert_dir=hparams.bert_dir,
    )
    train_dataset, iters_in_train = reader.read('train', mirrored_strategy)
    test_dataset, iters_in_test = reader.read('test', mirrored_strategy)
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_dataset, iters_in_unseen = reader.read('test_unseen', mirrored_strategy)
    vocabulary = reader.vocabulary

    # Build model & optimizer & trainer
    if mirrored_strategy:
        with mirrored_strategy.scope():
            model = MODELS[hparams.model](hparams, vocabulary)
            optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                                 clipnorm=hparams.clipnorm)
    else:
        model = MODELS[hparams.model](hparams, vocabulary)
        optimizer = tf.keras.optimizers.Adam(learning_rate=hparams.init_lr,
                                             clipnorm=hparams.clipnorm)
    trainer = Trainer(model, optimizer, mirrored_strategy,
                      hparams.enable_function,
                      WowDatasetReader.remove_pad)

    # Setup checkpoint
    global_step = tf.compat.v1.train.get_or_create_global_step()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     model=model,
                                     optimizer_step=global_step)

    # Load
    train_example = next(iter(train_dataset))
    _ = trainer.train_step(train_example)
    #checkpoint.restore(ckpt_fname).assert_consumed()
    #checkpoint.restore(ckpt_fname).expect_partial()
    checkpoint.restore(ckpt_fname)

    # Test
    test_loop_outputs = trainer.test_loop(test_dataset, iters_in_test, 0, 'seen')
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_loop_outputs = trainer.test_loop(unseen_dataset, iters_in_unseen, 0, 'unseen')

    test_summaries, log_dict = run_wow_evaluation(
        test_loop_outputs, hparams.checkpoint_dir, 'seen')
    if hparams.data_name == 'wizard_of_wikipedia':
        unseen_summaries, unseen_log_dict = run_wow_evaluation(
            unseen_loop_outputs, hparams.checkpoint_dir, 'unseen')

    # Logging
    tqdm.write(colorful.bold_green("seen").styled_string)
    tqdm.write(colorful.bold_red(pformat(log_dict)).styled_string)
    if hparams.data_name == 'wizard_of_wikipedia':
        tqdm.write(colorful.bold_green("unseen").styled_string)
        tqdm.write(colorful.bold_red(pformat(unseen_log_dict)).styled_string)
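# NOTE: The three `main()` functions above (interactive demo, training, inference)
# presumably live in separate scripts and are each launched with the standard
# Python entry-point guard, e.g.:
if __name__ == '__main__':
    main()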