def run4(self):
    """Train with a train/dev split, validation metric, and checkpoint saving."""
    set_rng_seed(100)
    data_set, model = prepare_env()
    # Hold out 30% of the data for validation; replace the prepared model
    # with a fresh NaiveClassifier for this run.
    train_set, dev_set = data_set.split(0.3)
    model = NaiveClassifier(2, 1)

    trainer = DistTrainer(
        train_set,
        model,
        optimizer=SGD(lr=0.1),
        loss=BCELoss(pred="predict", target="y"),
        batch_size_per_gpu=32,
        n_epochs=3,
        print_every=50,
        dev_data=dev_set,
        metrics=AccuracyMetric(pred="predict", target="y"),
        validate_every=-1,
        save_path=self.save_path,
    )
    # Expected to run to completion without errors.
    trainer.train()
    # Only the master process cleans up the saved checkpoints.
    if trainer.is_master and os.path.exists(self.save_path):
        shutil.rmtree(self.save_path)
def run1(self):
    """Basic distributed training on the fake dataset with CrossEntropyLoss."""
    print('local rank', get_local_rank())
    set_rng_seed(100)

    data_set = prepare_fake_dataset()
    data_set.set_input("x", flag=True)
    data_set.set_target("y", flag=True)

    model = NaiveClassifier(2, 2)
    trainer = DistTrainer(
        model=model,
        train_data=data_set,
        optimizer=SGD(lr=0.1),
        loss=CrossEntropyLoss(pred="predict", target="y"),
        batch_size_per_gpu=8,
        n_epochs=3,
        print_every=50,
        save_path=self.save_path,
    )
    # Expected to run to completion without errors.
    trainer.train()
    # Only the master process cleans up the saved checkpoints.
    if trainer.is_master and os.path.exists(self.save_path):
        shutil.rmtree(self.save_path)
def run3(self):
    """Exercise callbacks_all vs. callbacks_master with no explicit optimizer."""
    set_rng_seed(100)
    data_set, model = prepare_env()
    # optimizer=None lets DistTrainer fall back to its default; the EchoCallback
    # instances verify that all-process and master-only callbacks both fire.
    trainer = DistTrainer(
        data_set,
        model,
        optimizer=None,
        loss=BCELoss(pred="predict", target="y"),
        n_epochs=3,
        print_every=50,
        callbacks_all=[EchoCallback('callbacks_all')],
        callbacks_master=[EchoCallback('callbacks_master')],
    )
    trainer.train()