def test_preprocess_with_monolingual_with_tgt_chars(self):
    """
    This is just a correctness test to make sure no errors are thrown when
    all the required args are passed. Actual parsing code is tested by
    test_data.py
    """
    args = self.get_common_data_args_namespace()
    args.task = constants.SEMI_SUPERVISED_TASK
    args.train_mono_source_text_file = self.source_text_file
    args.train_mono_target_text_file = self.target_text_file
    args.arch = "char_aware_hybrid"
    args.char_source_max_vocab_size = 30
    args.char_target_max_vocab_size = 30
    args.char_source_vocab_file = test_utils.make_temp_file()
    args.char_target_vocab_file = test_utils.make_temp_file()
    preprocess.preprocess_corpora(args)
    for file_type in (
        "train_source_binary_path",
        "train_target_binary_path",
        "eval_source_binary_path",
        "eval_target_binary_path",
        "train_mono_source_binary_path",
        "train_mono_target_binary_path",
    ):
        file_path = getattr(args, file_type)
        assert file_path and os.path.isfile(file_path)
        assert file_path.endswith(".npz")
def main(args):
    # We preprocess the data (generating vocab files and binarized data files
    # if needed) outside of the train clones to prevent them from having to
    # wait while the master clone is doing this.
    preprocess.preprocess_corpora(args)

    # Set distributed training parameters for a single node.
    args.distributed_world_size = torch.cuda.device_count()
    args.distributed_init_method = f"tcp://localhost:{random.randint(10000, 20000)}"

    if args.distributed_world_size == 1:
        return single_process_main(args)

    mp = multiprocessing.get_context("spawn")

    # Create a thread to listen for errors in the child processes.
    error_queue = mp.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # Train with multiprocessing.
    procs = []
    for i in range(args.distributed_world_size):
        args.distributed_rank = i
        args.device_id = i
        procs.append(
            mp.Process(target=run, args=(args, error_queue), daemon=True)
        )
        procs[i].start()
        error_handler.add_child(procs[i].pid)
    for p in procs:
        p.join()
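# The ErrorHandler and run referenced above are defined elsewhere in the repo;
# what follows is only a hedged sketch of that error-listening pattern, with
# SketchErrorHandler and sketch_run as hypothetical stand-ins rather than the
# library's actual implementation. The idea: a daemon thread blocks on a
# shared error queue, crashing workers push their rank and traceback onto it,
# and the listener then tears down the surviving children so the job fails
# fast instead of hanging in p.join().
import os
import signal
import sys
import threading
import traceback

class SketchErrorHandler:
    def __init__(self, error_queue):
        self.error_queue = error_queue
        self.children_pids = []
        # Daemon thread: it never outlives the parent process.
        self.listener = threading.Thread(target=self._listen, daemon=True)
        self.listener.start()

    def add_child(self, pid):
        self.children_pids.append(pid)

    def _listen(self):
        # Blocks until some worker reports a failure.
        rank, original_trace = self.error_queue.get()
        # Kill the remaining workers, then surface the original traceback.
        for pid in self.children_pids:
            try:
                os.kill(pid, signal.SIGINT)
            except ProcessLookupError:
                pass  # That child already exited.
        sys.stderr.write(f"Worker {rank} died with:\n{original_trace}\n")

def sketch_run(args, error_queue):
    # Worker-side counterpart: report any crash to the parent before exiting.
    try:
        single_process_main(args)
    except Exception:
        error_queue.put((args.distributed_rank, traceback.format_exc()))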
def main(args, trainer_class=Trainer, **train_step_kwargs):
    # We preprocess the data (generating vocab files and binarized data files
    # if needed) outside of the train processes to prevent them from having to
    # wait while the master process is doing this.
    preprocess.preprocess_corpora(args)

    if args.distributed_world_size == 1:
        single_process_main(args, trainer_class, **train_step_kwargs)
    else:
        spawn_context, _ = multi_process_main(args=args, start_rank=0)
        # join() returns True only once every spawned worker has exited, so
        # keep polling until then.
        while not spawn_context.join():
            pass
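# multi_process_main is defined elsewhere in the repo; as a point of
# reference, here is a hedged sketch of what a helper returning a
# (spawn_context, output_queue) pair could look like, built on
# torch.multiprocessing.spawn with join=False. sketch_multi_process_main and
# sketch_worker are hypothetical names, not the repo's actual implementation.
import torch.multiprocessing as torch_mp

def sketch_multi_process_main(args, start_rank=0):
    output_queue = torch_mp.get_context("spawn").Queue()
    spawn_context = torch_mp.spawn(
        fn=sketch_worker,
        args=(args, output_queue, start_rank),
        nprocs=args.distributed_world_size,
        join=False,  # hand control back so the caller can poll join()
    )
    return spawn_context, output_queue

def sketch_worker(device_id, args, output_queue, start_rank):
    # torch.multiprocessing.spawn calls fn with the process index first.
    args.device_id = device_id
    args.distributed_rank = start_rank + device_id
    single_process_main(args)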
def main(args):
    # We preprocess the data (generating vocab files and binarized data files
    # if needed) outside of the train processes to prevent them from having to
    # wait while the master process is doing this.
    preprocess.preprocess_corpora(args)

    if args.distributed_world_size == 1:
        single_process_main(args)
    else:
        processes, error_handler, _ = multi_process_main(
            args=args, use_output_queue=False, start_rank=0
        )
        for p in processes:
            p.join()
def test_preprocess(self):
    """
    This is just a correctness test to make sure no errors are thrown when
    all the required args are passed. Actual parsing code is tested by
    test_data.py
    """
    args = self.get_common_data_args_namespace()
    preprocess.preprocess_corpora(args)
    for file_type in (
        "train_source_binary_path",
        "train_target_binary_path",
        "eval_source_binary_path",
        "eval_target_binary_path",
    ):
        file = getattr(args, file_type)
        assert file and os.path.isfile(file)
        assert file.endswith(".npz")
def test_preprocess_with_monolingual(self):
    """
    This is just a correctness test to make sure no errors are thrown when
    all the required args are passed. Actual parsing code is tested by
    test_data.py
    """
    args = self.get_common_data_args_namespace()
    args.task = "pytorch_translate_semisupervised"
    args.train_mono_source_text_file = self.source_text_file
    args.train_mono_target_text_file = self.target_text_file
    preprocess.preprocess_corpora(args)
    for file_type in (
        "train_source_binary_path",
        "train_target_binary_path",
        "eval_source_binary_path",
        "eval_target_binary_path",
        "train_mono_source_binary_path",
        "train_mono_target_binary_path",
    ):
        file_path = getattr(args, file_type)
        assert file_path and os.path.isfile(file_path)
        assert file_path.endswith(".npz")
def main(args, trainer_class=Trainer, **train_step_kwargs):
    # We preprocess the data (generating vocab files and binarized data files
    # if needed) outside of the train processes to prevent them from having to
    # wait while the master process is doing this.
    preprocess.preprocess_corpora(args)

    if args.distributed_world_size == 1:
        single_process_main(args, trainer_class, **train_step_kwargs)
    else:
        spawn_context, output_queue = multi_process_main(args=args, start_rank=0)
        while not spawn_context.join(timeout=30):
            # Periodically clears the output queue to ensure that the processes
            # don't deadlock due to the queue buffer being full. This is also
            # necessary to ensure that processes join correctly, since a process
            # may not terminate until all items it put on the queue have been
            # consumed (per
            # https://docs.python.org/3/library/multiprocessing.html#all-start-methods).
            try:
                while True:
                    output_queue.get_nowait()
            except queue.Empty:
                pass
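# A standalone illustration (not repo code) of the pitfall the comment above
# cites: a child that has put items on a multiprocessing.Queue may not exit
# until its feeder thread has flushed them, so joining without draining can
# hang once the pipe buffer fills. All names below are hypothetical.
import multiprocessing
import queue

def _producer(q):
    # Put enough items to overflow the queue's internal pipe buffer.
    for i in range(100_000):
        q.put(i)

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=_producer, args=(q,))
    p.start()
    # A bare p.join() here could block forever; drain while polling instead.
    while p.is_alive():
        try:
            while True:
                q.get_nowait()
        except queue.Empty:
            p.join(timeout=0.1)
    p.join()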