Example 1
    def __init__(self,
                 model,
                 loss,
                 optimizer,
                 metrics=None,
                 model_dir=None,
                 bigdl_type="float"):
        from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim
        # Wrap the PyTorch loss function (a bare TorchLoss is used when none
        # is given).
        self.loss = loss
        if self.loss is None:
            self.loss = TorchLoss()
        else:
            self.loss = TorchLoss.from_pytorch(loss)
        # Default to orca's SGD, then convert the optimizer (PyTorch or orca)
        # into a BigDL OptimMethod.
        if optimizer is None:
            from zoo.orca.learn.optimizers.schedule import Default
            optimizer = SGD(learningrate_schedule=Default())
        if isinstance(optimizer, TorchOptimizer):
            optimizer = TorchOptim.from_pytorch(optimizer)
        elif isinstance(optimizer, OrcaOptimizer):
            optimizer = optimizer.get_optimizer()
        else:
            raise ValueError(
                "Only PyTorch optimizer and orca optimizer are supported")
        from zoo.orca.learn.metrics import Metric
        self.metrics = Metric.convert_metrics_list(metrics)
        self.log_dir = None
        self.app_name = None
        self.model_dir = model_dir
        # Convert the PyTorch model and build the underlying Spark estimator.
        self.model = TorchModel.from_pytorch(model)
        self.estimator = SparkEstimator(self.model,
                                        optimizer,
                                        model_dir,
                                        bigdl_type=bigdl_type)
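
A minimal caller sketch for this constructor; the enclosing class name PyTorchSparkEstimator is an assumption, since the snippet does not show it:

import torch
import torch.nn as nn

# Plain PyTorch objects that the constructor above wraps into their
# Zoo/BigDL counterparts (the estimator class name is assumed).
model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1), nn.Sigmoid())
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

est = PyTorchSparkEstimator(model, loss_fn, optimizer, metrics=None)
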
Example 2
    def test_torch_optim(self):
        class SimpleTorchModel(nn.Module):
            def __init__(self):
                super(SimpleTorchModel, self).__init__()
                self.dense1 = nn.Linear(2, 4)
                self.bn1 = torch.nn.BatchNorm1d(4)
                self.dense2 = nn.Linear(4, 1)

            def forward(self, x):
                x = self.dense1(x)
                x = self.bn1(x)
                x = torch.sigmoid(self.dense2(x))
                return x

        # Restart the SparkContext with an NNContext so BigDL/Zoo is initialized.
        self.sc.stop()
        self.sc = init_nncontext()
        torch_model = SimpleTorchModel()
        loss_fn = torch.nn.BCELoss()
        torch_optim = torch.optim.Adam(torch_model.parameters())
        # Wrap the PyTorch model and loss for the Zoo estimator.
        az_model = TorchModel.from_pytorch(torch_model)
        zoo_loss = TorchLoss.from_pytorch(loss_fn)

        def train_dataloader():
            inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9],
                                   [1, 9]])
            targets = torch.Tensor([[0], [0], [0], [1], [1], [1]])
            return DataLoader(TensorDataset(inputs, targets), batch_size=2)

        # Wrap the DataLoader creator functions as distributed FeatureSets.
        train_featureset = FeatureSet.pytorch_dataloader(train_dataloader)
        val_featureset = FeatureSet.pytorch_dataloader(train_dataloader)
        zoo_optimizer = TorchOptim.from_pytorch(torch_optim)
        estimator = Estimator(az_model, optim_methods=zoo_optimizer)
        # Train for 4 epochs, checkpointing and validating after every epoch.
        estimator.train_minibatch(train_featureset,
                                  zoo_loss,
                                  end_trigger=MaxEpoch(4),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=val_featureset,
                                  validation_method=[Accuracy()])

        trained_model = az_model.to_pytorch()
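
Since az_model.to_pytorch() returns a plain PyTorch module, the trained weights can then be used for local inference; a minimal sketch (the input tensor below is illustrative):

# Run the converted model locally on a small batch (illustrative inputs).
trained_model.eval()
with torch.no_grad():
    preds = trained_model(torch.Tensor([[1, 2], [5, 6]]))
print(preds)
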
Example 3
    def __init__(self,
                 model,
                 loss,
                 optimizer,
                 model_dir=None,
                 bigdl_type="float"):
        from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim
        # Wrap the PyTorch loss function (a bare TorchLoss is used when none
        # is given).
        self.loss = loss
        if self.loss is None:
            self.loss = TorchLoss()
        else:
            self.loss = TorchLoss.from_pytorch(loss)
        # Fall back to BigDL's SGD, or convert a PyTorch optimizer.
        if optimizer is None:
            from bigdl.optim.optimizer import SGD
            optimizer = SGD()
        elif isinstance(optimizer, TorchOptimizer):
            optimizer = TorchOptim.from_pytorch(optimizer)
        self.model = TorchModel.from_pytorch(model)
        self.estimator = SparkEstimator(self.model,
                                        optimizer,
                                        model_dir,
                                        bigdl_type=bigdl_type)
Example 4
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dir',
                        default='/tmp/data',
                        metavar='N',
                        help='the folder to store the MNIST data')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=256,
        metavar='N',
        help='input batch size for training per executor (default: 256)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=1000,
        metavar='N',
        help='input batch size for testing per executor (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=2,
                        metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.dir,
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.dir,
                       train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.test_batch_size,
        shuffle=False)

    # Init on YARN when HADOOP_CONF_DIR and ZOO_CONDA_NAME are provided.
    if os.environ.get('HADOOP_CONF_DIR') is None:
        sc = init_spark_on_local(cores=1, conf={"spark.driver.memory": "20g"})
    else:
        num_executors = 2
        num_cores_per_executor = 4
        hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')
        zoo_conda_name = os.environ.get(
            'ZOO_CONDA_NAME')  # The name of the created conda-env
        sc = init_spark_on_yarn(hadoop_conf=hadoop_conf_dir,
                                conda_name=zoo_conda_name,
                                num_executors=num_executors,
                                executor_cores=num_cores_per_executor,
                                executor_memory="2g",
                                driver_memory="10g",
                                driver_cores=1,
                                conf={
                                    "spark.rpc.message.maxSize": "1024",
                                    "spark.task.maxFailures": "1",
                                    "spark.driver.extraJavaOptions":
                                        "-Dbigdl.failure.retryTimes=1"
                                })

    model = Net()
    model.train()
    criterion = nn.NLLLoss()

    adam = torch.optim.Adam(model.parameters(), lr=args.lr)
    # Wrap the PyTorch model, criterion and optimizer for distributed training.
    zoo_model = TorchModel.from_pytorch(model)
    zoo_criterion = TorchLoss.from_pytorch(criterion)
    zoo_optim = TorchOptim.from_pytorch(adam)
    zoo_estimator = Estimator(zoo_model, optim_methods=zoo_optim)
    train_featureset = FeatureSet.pytorch_dataloader(train_loader)
    test_featureset = FeatureSet.pytorch_dataloader(test_loader)
    from bigdl.optim.optimizer import MaxEpoch, EveryEpoch
    # Train for the requested number of epochs, validating after every epoch.
    zoo_estimator.train_minibatch(train_featureset,
                                  zoo_criterion,
                                  end_trigger=MaxEpoch(args.epochs),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=test_featureset,
                                  validation_method=[Accuracy()])
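
The example defines a --save-model flag but stops at the training call; a hedged continuation of main(), assuming the trained weights are pulled back into PyTorch (the output file name is illustrative):

    # Hypothetical continuation: convert the trained Zoo model back to PyTorch
    # and honor the --save-model flag (output file name is illustrative).
    trained_model = zoo_model.to_pytorch()
    if args.save_model:
        torch.save(trained_model.state_dict(), "mnist_cnn.pt")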