def __call__(self):
    serialize_dir = self.config["serialize_dir"]
    num_epoches = self.config["num_epoches"]

    vocabulary_dict = self.build_vocabulary()
    token_vocabulary = vocabulary_dict["token_vocabulary"]
    category_vocabulary = vocabulary_dict["category_vocabulary"]
    label_vocabulary = vocabulary_dict["label_vocabulary"]

    model = self.build_model(token_vocabulary=token_vocabulary,
                             category_vocabulary=category_vocabulary,
                             label_vocabulary=label_vocabulary)
    loss = ACSALoss()
    label_decoder = ACSALabelDecoder(label_vocabulary=label_vocabulary)
    metrics = ACSAModelMetric(label_decoder=label_decoder)
    optimizer_factory = ACSAOptimizerFactory(config=self.config)

    patient = self.config["patient"]
    num_check_point_keep = self.config["num_check_point_keep"]

    trainer = Trainer(serialize_dir=serialize_dir,
                      num_epoch=num_epoches,
                      model=model,
                      loss=loss,
                      metrics=metrics,
                      patient=patient,
                      num_check_point_keep=num_check_point_keep,
                      optimizer_factory=optimizer_factory)

    training_dataset_file_path = self.config["training_dataset_file_path"]
    validation_dataset_file_path = self.config["validation_dataset_file_path"]

    train_dataset = ACSASemEvalDataset(training_dataset_file_path)
    batch_size = self.config["batch_size"]

    model_collate = ACSAModelCollate(token_vocabulary=token_vocabulary,
                                     category_vocabulary=category_vocabulary,
                                     label_vocabulary=label_vocabulary)

    train_data_loader = DataLoader(dataset=train_dataset,
                                   shuffle=True,
                                   batch_size=batch_size,
                                   num_workers=0,
                                   collate_fn=model_collate)

    validation_dataset = ACSASemEvalDataset(validation_dataset_file_path)
    validation_data_loader = DataLoader(dataset=validation_dataset,
                                        shuffle=False,
                                        batch_size=batch_size,
                                        num_workers=0,
                                        collate_fn=model_collate)

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)
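# --- Illustrative usage sketch (not part of the original code) ---
# A minimal config covering only the keys read by __call__ above; every value is a
# placeholder assumption, and ACSAOptimizerFactory may read additional keys.
example_acsa_config = {
    "serialize_dir": "data/acsa/serialize",
    "num_epoches": 10,
    "patient": 5,
    "num_check_point_keep": 3,
    "training_dataset_file_path": "data/acsa/train.xml",
    "validation_dataset_file_path": "data/acsa/dev.xml",
    "batch_size": 32,
}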
def _start(self, rank: Optional[int], world_size: int, device: torch.device) -> None:
    is_distributed = rank is not None

    trainer = Trainer(serialize_dir=self.config.serialize_dir,
                      num_epoch=self.config.num_epoch,
                      model=self.config.model,
                      loss=self.config.loss,
                      metrics=self.config.metric,
                      optimizer_factory=self.config.optimizer,
                      lr_scheduler_factory=self.config.lr_scheduler,
                      grad_rescaled=self.config.grad_rescaled,
                      patient=self.config.patient,
                      num_check_point_keep=self.config.num_check_point_keep,
                      device=device,
                      is_distributed=is_distributed,
                      distributed_data_parallel_parameter=self.config.distributed_data_parallel_parameter)

    train_sampler = None
    if is_distributed:
        train_sampler = DistributedSampler(dataset=self.config.training_dataset,
                                           shuffle=False)

    train_data_loader = DataLoader(dataset=self.config.training_dataset,
                                   batch_size=self.config.train_batch_size,
                                   shuffle=(train_sampler is None),
                                   num_workers=0,
                                   collate_fn=self.config.model_collate,
                                   sampler=train_sampler)

    validation_sampler = None
    if is_distributed:
        # Sample from the validation dataset; the original built this sampler on the
        # training dataset, which looks like a copy-paste slip.
        validation_sampler = DistributedSampler(dataset=self.config.validation_dataset,
                                                shuffle=False)

    validation_data_loader = DataLoader(dataset=self.config.validation_dataset,
                                        batch_size=self.config.test_batch_size,
                                        shuffle=False,
                                        num_workers=0,
                                        collate_fn=self.config.model_collate,
                                        sampler=validation_sampler)

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)
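# --- Hypothetical distributed launch for _start above (an assumption, not original code) ---
# Each spawned worker joins a process group and then calls _start with its own rank.
# "launcher" stands for whatever object exposes _start; the backend and port are illustrative.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _distributed_worker(rank: int, launcher, world_size: int) -> None:
    # One process per rank; NCCL would be the usual backend on GPUs.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:23456",
                            rank=rank,
                            world_size=world_size)
    launcher._start(rank=rank, world_size=world_size, device=torch.device("cpu"))
    dist.destroy_process_group()


# mp.spawn(_distributed_worker, args=(launcher, 2), nprocs=2)
# Single-process path: launcher._start(rank=None, world_size=1, device=torch.device("cpu"))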
def _run_train(cuda_devices: Optional[List[str]] = None):
    serialize_dir = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    if os.path.isdir(serialize_dir):
        shutil.rmtree(serialize_dir)

    os.makedirs(serialize_dir)

    model = ModelDemo()
    optimizer_factory = _DemoOptimizerFactory()
    loss = _DemoLoss()
    metric = _DemoMetric()

    trainer = Trainer(num_epoch=100,
                      model=model,
                      loss=loss,
                      metrics=metric,
                      optimizer_factory=optimizer_factory,
                      serialize_dir=serialize_dir,
                      patient=20,
                      num_check_point_keep=25,
                      cuda_devices=cuda_devices)

    train_dataset = _DemoDataset()

    train_data_loader = DataLoader(dataset=train_dataset,
                                   collate_fn=_DemoCollate(),
                                   batch_size=200,
                                   shuffle=False,
                                   num_workers=0)

    validation_data_loader = DataLoader(dataset=train_dataset,
                                        collate_fn=_DemoCollate(),
                                        batch_size=200,
                                        shuffle=False,
                                        num_workers=0)

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)

    expect_model_state_dict = json.loads(json2str(trainer.model.state_dict()))
    expect_optimizer_state_dict = json.loads(json2str(trainer.optimizer.state_dict()))
    expect_current_epoch = trainer.current_epoch
    expect_num_epoch = trainer.num_epoch
    expect_metric = trainer.metrics.metric[0]
    expect_metric_tracker = json.loads(json2str(trainer.metric_tracker))

    trainer.load_checkpoint(serialize_dir=serialize_dir)

    loaded_model_state_dict = json.loads(json2str(trainer.model.state_dict()))
    loaded_optimizer_state_dict = json.loads(json2str(trainer.optimizer.state_dict()))
    current_epoch = trainer.current_epoch
    num_epoch = trainer.num_epoch
    metric = trainer.metrics.metric[0]
    metric_tracker = json.loads(json2str(trainer.metric_tracker))

    ASSERT.assertDictEqual(expect_model_state_dict, loaded_model_state_dict)
    ASSERT.assertDictEqual(expect_optimizer_state_dict, loaded_optimizer_state_dict)
    ASSERT.assertEqual(expect_current_epoch, current_epoch)
    ASSERT.assertEqual(expect_num_epoch, num_epoch)
    ASSERT.assertDictEqual(expect_metric, metric)
    ASSERT.assertDictEqual(expect_metric_tracker, metric_tracker)
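# Hypothetical invocations of the save/load round-trip test above (an assumption):
# _run_train()                         # CPU only
# _run_train(cuda_devices=["cuda:0"])  # single GPU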
def _run_train(device: torch.device, is_distributed: bool):
    serialize_dir = os.path.join(ROOT_PATH, "data/easytext/tests/trainer/save_and_load")

    if is_distributed:
        if TorchDist.get_rank() == 0:
            if os.path.isdir(serialize_dir):
                shutil.rmtree(serialize_dir)
            os.makedirs(serialize_dir)
        TorchDist.barrier()
    else:
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)
        os.makedirs(serialize_dir)

    model = ModelDemo()
    optimizer_factory = _DemoOptimizerFactory()
    loss = _DemoLoss()
    metric = _DemoMetric()

    tensorboard_log_dir = "data/tensorboard"
    tensorboard_log_dir = os.path.join(ROOT_PATH, tensorboard_log_dir)
    # shutil.rmtree(tensorboard_log_dir)

    trainer = Trainer(num_epoch=100,
                      model=model,
                      loss=loss,
                      metrics=metric,
                      optimizer_factory=optimizer_factory,
                      serialize_dir=serialize_dir,
                      patient=20,
                      num_check_point_keep=25,
                      device=device,
                      trainer_callback=None,
                      is_distributed=is_distributed)

    logging.info(f"test is_distributed: {is_distributed}")

    # trainer_callback = BasicTrainerCallbackComposite(tensorboard_log_dir=tensorboard_log_dir)

    train_dataset = _DemoDataset()

    if is_distributed:
        sampler = DistributedSampler(dataset=train_dataset)
    else:
        sampler = None

    train_data_loader = DataLoader(dataset=train_dataset,
                                   collate_fn=_DemoCollate(),
                                   batch_size=200,
                                   num_workers=0,
                                   sampler=sampler)

    if is_distributed:
        sampler = DistributedSampler(dataset=train_dataset)
    else:
        sampler = None

    validation_data_loader = DataLoader(dataset=train_dataset,
                                        collate_fn=_DemoCollate(),
                                        batch_size=200,
                                        num_workers=0,
                                        sampler=sampler)

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)

    expect_model_state_dict = json.loads(json2str(trainer.model.state_dict()))
    expect_optimizer_state_dict = json.loads(json2str(trainer.optimizer.state_dict()))
    expect_current_epoch = trainer.current_epoch
    expect_num_epoch = trainer.num_epoch
    expect_metric = trainer.metrics.metric[0]
    expect_metric_tracker = json.loads(json2str(trainer.metric_tracker))

    trainer.load_checkpoint(serialize_dir=serialize_dir)

    loaded_model_state_dict = json.loads(json2str(trainer.model.state_dict()))
    loaded_optimizer_state_dict = json.loads(json2str(trainer.optimizer.state_dict()))
    current_epoch = trainer.current_epoch
    num_epoch = trainer.num_epoch
    metric = trainer.metrics.metric[0]
    metric_tracker = json.loads(json2str(trainer.metric_tracker))

    ASSERT.assertDictEqual(expect_model_state_dict, loaded_model_state_dict)
    ASSERT.assertDictEqual(expect_optimizer_state_dict, loaded_optimizer_state_dict)
    ASSERT.assertEqual(expect_current_epoch, current_epoch)
    ASSERT.assertEqual(expect_num_epoch, num_epoch)
    ASSERT.assertDictEqual(expect_metric, metric)
    ASSERT.assertDictEqual(expect_metric_tracker, metric_tracker)
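# Hypothetical single-process invocation (an assumption); the distributed branch
# additionally expects an initialized process group before _run_train is called.
# _run_train(device=torch.device("cpu"), is_distributed=False)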
def __call__(self, config: Dict, train_type: int):
    serialize_dir = config["serialize_dir"]
    vocabulary_dir = config["vocabulary_dir"]
    pretrained_embedding_file_path = config["pretrained_embedding_file_path"]
    word_embedding_dim = config["word_embedding_dim"]
    pretrained_embedding_max_size = config["pretrained_embedding_max_size"]
    is_fine_tuning = config["fine_tuning"]

    word_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "word_vocabulary")
    event_type_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "event_type_vocabulary")
    entity_tag_vocab_dir = os.path.join(vocabulary_dir, "vocabulary", "entity_tag_vocabulary")

    if train_type == Train.NEW_TRAIN:
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)
        os.makedirs(serialize_dir)

        if os.path.isdir(vocabulary_dir):
            shutil.rmtree(vocabulary_dir)
        os.makedirs(vocabulary_dir)
        os.makedirs(word_vocab_dir)
        os.makedirs(event_type_vocab_dir)
        os.makedirs(entity_tag_vocab_dir)
    elif train_type == Train.RECOVERY_TRAIN:
        pass
    else:
        assert False, f"train_type: {train_type} error!"

    train_dataset_file_path = config["train_dataset_file_path"]
    validation_dataset_file_path = config["validation_dataset_file_path"]
    num_epoch = config["epoch"]
    batch_size = config["batch_size"]

    if train_type == Train.NEW_TRAIN:
        # Build the vocabularies
        ace_dataset = ACEDataset(train_dataset_file_path)
        vocab_data_loader = DataLoader(dataset=ace_dataset,
                                       batch_size=10,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=EventVocabularyCollate())

        tokens: List[List[str]] = list()
        event_types: List[List[str]] = list()
        entity_tags: List[List[str]] = list()

        for collate_dict in vocab_data_loader:
            tokens.extend(collate_dict["tokens"])
            event_types.extend(collate_dict["event_types"])
            entity_tags.extend(collate_dict["entity_tags"])

        word_vocabulary = Vocabulary(tokens=tokens,
                                     padding=Vocabulary.PADDING,
                                     unk=Vocabulary.UNK,
                                     special_first=True)

        glove_loader = GloveLoader(embedding_dim=word_embedding_dim,
                                   pretrained_file_path=pretrained_embedding_file_path,
                                   max_size=pretrained_embedding_max_size)

        pretrained_word_vocabulary = PretrainedVocabulary(vocabulary=word_vocabulary,
                                                          pretrained_word_embedding_loader=glove_loader)
        pretrained_word_vocabulary.save_to_file(word_vocab_dir)

        event_type_vocabulary = Vocabulary(tokens=event_types,
                                           padding="",
                                           unk="Negative",
                                           special_first=True)
        event_type_vocabulary.save_to_file(event_type_vocab_dir)

        entity_tag_vocabulary = LabelVocabulary(labels=entity_tags,
                                                padding=LabelVocabulary.PADDING)
        entity_tag_vocabulary.save_to_file(entity_tag_vocab_dir)
    else:
        pretrained_word_vocabulary = PretrainedVocabulary.from_file(word_vocab_dir)
        event_type_vocabulary = Vocabulary.from_file(event_type_vocab_dir)
        entity_tag_vocabulary = Vocabulary.from_file(entity_tag_vocab_dir)

    model = EventModel(alpha=0.5,
                       activate_score=True,
                       sentence_vocab=pretrained_word_vocabulary,
                       sentence_embedding_dim=word_embedding_dim,
                       entity_tag_vocab=entity_tag_vocabulary,
                       entity_tag_embedding_dim=50,
                       event_type_vocab=event_type_vocabulary,
                       event_type_embedding_dim=300,
                       lstm_hidden_size=300,
                       lstm_encoder_num_layer=1,
                       lstm_encoder_droupout=0.4)

    trainer = Trainer(serialize_dir=serialize_dir,
                      num_epoch=num_epoch,
                      model=model,
                      loss=EventLoss(),
                      optimizer_factory=EventOptimizerFactory(is_fine_tuning=is_fine_tuning),
                      metrics=EventF1MetricAdapter(event_type_vocabulary=event_type_vocabulary),
                      patient=10,
                      num_check_point_keep=5,
                      devices=None)

    train_dataset = EventDataset(dataset_file_path=train_dataset_file_path,
                                 event_type_vocabulary=event_type_vocabulary)
    validation_dataset = EventDataset(dataset_file_path=validation_dataset_file_path,
                                      event_type_vocabulary=event_type_vocabulary)

    event_collate = EventCollate(word_vocabulary=pretrained_word_vocabulary,
                                 event_type_vocabulary=event_type_vocabulary,
                                 entity_tag_vocabulary=entity_tag_vocabulary,
                                 sentence_max_len=512)

    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=batch_size,
                                   num_workers=0,
                                   collate_fn=event_collate)
    validation_data_loader = DataLoader(dataset=validation_dataset,
                                        batch_size=batch_size,
                                        num_workers=0,
                                        collate_fn=event_collate)

    if train_type == Train.NEW_TRAIN:
        trainer.train(train_data_loader=train_data_loader,
                      validation_data_loader=validation_data_loader)
    else:
        trainer.recovery_train(train_data_loader=train_data_loader,
                               validation_data_loader=validation_data_loader)
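# --- Illustrative config for the event-extraction trainer above (values are assumptions) ---
# Only the keys that __call__ reads are listed; paths and sizes are placeholders.
example_event_config = {
    "serialize_dir": "data/event/serialize",
    "vocabulary_dir": "data/event",
    "pretrained_embedding_file_path": "data/glove/glove.6B.300d.txt",
    "word_embedding_dim": 300,
    "pretrained_embedding_max_size": 100000,
    "fine_tuning": True,
    "train_dataset_file_path": "data/ace/train.json",
    "validation_dataset_file_path": "data/ace/dev.json",
    "epoch": 30,
    "batch_size": 32,
}
# A fresh run would then be: instance(config=example_event_config, train_type=Train.NEW_TRAIN)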
def __call__(self):
    config = self.config
    serialize_dir = config["serialize_dir"]

    if self._train_type == Train.NEW_TRAIN:
        # Clean up the serialize dir
        if os.path.isdir(serialize_dir):
            shutil.rmtree(serialize_dir)
        os.makedirs(serialize_dir)

    num_epoch = config["num_epoch"]
    patient = config["patient"]
    num_check_point_keep = config["num_check_point_keep"]

    train_dataset = Conll2003Dataset(dataset_file_path=config["train_dataset_file_path"])
    validation_dataset = Conll2003Dataset(dataset_file_path=config["validation_dataset_file_path"])

    # Build the vocabularies
    vocab_dict = self.build_vocabulary(dataset=train_dataset)
    token_vocabulary = vocab_dict["token_vocabulary"]
    label_vocabulary = vocab_dict["label_vocabulary"]

    model = self.build_model(token_vocabulary=token_vocabulary,
                             label_vocabulary=label_vocabulary)
    loss = self.build_loss()
    metric = self.build_model_metric(label_vocabulary=label_vocabulary)

    cuda = config["cuda"]

    trainer = Trainer(serialize_dir=serialize_dir,
                      num_epoch=num_epoch,
                      model=model,
                      loss=loss,
                      metrics=metric,
                      optimizer_factory=NerOptimizerFactory(),
                      lr_scheduler_factory=None,
                      patient=patient,
                      num_check_point_keep=num_check_point_keep,
                      cuda_devices=cuda)

    model_collate = NerModelCollate(token_vocab=token_vocabulary,
                                    sequence_label_vocab=label_vocabulary,
                                    sequence_max_len=512)

    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_size=32,
                                   shuffle=True,
                                   num_workers=0,
                                   collate_fn=model_collate)
    validation_data_loader = DataLoader(dataset=validation_dataset,
                                        batch_size=32,
                                        shuffle=False,
                                        num_workers=0,
                                        collate_fn=model_collate)

    trainer.train(train_data_loader=train_data_loader,
                  validation_data_loader=validation_data_loader)
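# --- Illustrative config for the CoNLL-2003 NER trainer above (values are assumptions) ---
# Only the keys that __call__ reads are listed; "cuda" is forwarded to Trainer as cuda_devices.
example_ner_config = {
    "serialize_dir": "data/ner/serialize",
    "num_epoch": 50,
    "patient": 10,
    "num_check_point_keep": 5,
    "train_dataset_file_path": "data/conll2003/train.txt",
    "validation_dataset_file_path": "data/conll2003/dev.txt",
    "cuda": ["cuda:0"],
}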