    def run4(self):
        set_rng_seed(100)
        data_set, model = prepare_env()

        # split off a dev set (30% of the fake data)
        train_set, dev_set = data_set.split(0.3)

        # fresh single-output model for BCELoss; overrides the one from prepare_env()
        model = NaiveClassifier(2, 1)

        trainer = DistTrainer(
            train_data=train_set,
            model=model,
            optimizer=SGD(lr=0.1),
            loss=BCELoss(pred="predict", target="y"),
            batch_size_per_gpu=32,
            n_epochs=3,
            print_every=50,
            dev_data=dev_set,
            metrics=AccuracyMetric(pred="predict", target="y"),
            validate_every=-1,
            save_path=self.save_path,
        )
        trainer.train()
        """
        # 应该正确运行
        """
        # clean up the checkpoint directory; only the master process saved to it
        if trainer.is_master and os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)

    def run1(self):
        # test distributed training
        print('local rank', get_local_rank())
        set_rng_seed(100)
        data_set = prepare_fake_dataset()
        data_set.set_input("x", flag=True)
        data_set.set_target("y", flag=True)

        # two output logits, since CrossEntropyLoss expects per-class scores
        model = NaiveClassifier(2, 2)

        trainer = DistTrainer(
            model=model,
            train_data=data_set,
            optimizer=SGD(lr=0.1),
            loss=CrossEntropyLoss(pred="predict", target="y"),
            batch_size_per_gpu=8,
            n_epochs=3,
            print_every=50,
            save_path=self.save_path,
        )
        trainer.train()
        """
        # 应该正确运行
        """
        if trainer.is_master and os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)

    def run3(self):
        set_rng_seed(100)
        data_set, model = prepare_env()

        trainer = DistTrainer(
            train_data=data_set,
            model=model,
            optimizer=None,  # exercises the trainer's default-optimizer path
            loss=BCELoss(pred="predict", target="y"),
            n_epochs=3,
            print_every=50,
            # callbacks_all run on every process; callbacks_master only on the master
            callbacks_all=[EchoCallback('callbacks_all')],
            callbacks_master=[EchoCallback('callbacks_master')],
        )
        trainer.train()
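
    # --- Illustrative launcher sketch (not from the original file) ---
    # The run* methods above assume they are executed simultaneously by every
    # process in a torch.distributed group. A minimal way to drive them,
    # assuming this file has a __main__ block that dispatches `--test N` to
    # the matching runN method, is to re-launch the file under
    # torch.distributed.launch. The method name run_dist and the --test flag
    # are assumptions for illustration only.
    def run_dist(self, run_id):
        import subprocess
        import sys

        import torch
        if not torch.cuda.is_available():
            return  # DistTrainer tests need at least one CUDA device
        ngpu = min(2, torch.cuda.device_count())
        cmd = [sys.executable, '-m', 'torch.distributed.launch',
               '--nproc_per_node', str(ngpu),
               os.path.abspath(__file__), '--test', str(run_id)]
        # each spawned process re-enters this file and calls run{run_id}
        retcode = subprocess.call(cmd)
        assert retcode == 0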