Exemplo n.º 1
0
    def fit(self, data, epochs=1, batch_size=32, validation_data=None, validation_methods=None,
            checkpoint_trigger=None):
        """
        Train the model on ``data`` and return this estimator.

        :param data: training data: a SparkXShards, a PyTorch DataLoader, or a callable
               data_creator.
        :param epochs: number of epochs to train; becomes a ``MaxEpoch`` end trigger.
        :param batch_size: batch size, must be greater than 0. Only forwarded to the
               underlying estimator when ``data`` is a SparkXShards.
        :param validation_data: optional validation data. Must be a SparkXShards when
               ``data`` is one, otherwise a DataLoader or a callable data_creator.
        :param validation_methods: validation methods forwarded to the underlying estimator.
        :param checkpoint_trigger: checkpoint trigger forwarded to the underlying estimator.
        :return: self, to allow call chaining.
        :raises ValueError: if ``data`` is none of the supported types.
        """
        from zoo.orca.data.utils import to_sample

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), "validation_data should be a " \
                                                                  "SparkXShards"
                val_feature_set = FeatureSet.sample_rdd(validation_data.rdd.flatMap(to_sample))

            self.estimator.train(train_feature_set, self.loss, end_trigger, checkpoint_trigger,
                                 val_feature_set, validation_methods, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
            if validation_data is None:
                val_feature_set = None
            else:
                # Check validation_data itself: testing callable(data) here would accept
                # any validation object whenever the training data is a data_creator.
                assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                    "validation_data should be a pytorch DataLoader or a callable data_creator"
                val_feature_set = FeatureSet.pytorch_dataloader(validation_data)

            # batch_size is not forwarded on this path; presumably the DataLoader's own
            # batch size applies.
            self.estimator.train_minibatch(train_feature_set, self.loss, end_trigger,
                                           checkpoint_trigger, val_feature_set, validation_methods)
        else:
            raise ValueError("Data and validation data should be SparkXShards, DataLoaders or "
                             "callable data_creators but get " + data.__class__.__name__)
        return self
Exemplo n.º 2
0
    def test_train_model_with_bn(self):
        """
        Smoke-test minibatch training of a model that contains BatchNorm1d.

        A tiny Linear -> BatchNorm1d -> Linear -> sigmoid network is trained for 4 epochs
        with BCE loss and Adam, using Accuracy as the validation method, and is then
        converted back to a PyTorch module.
        """
        class BatchNormNet(nn.Module):
            def __init__(self):
                super(BatchNormNet, self).__init__()
                self.dense1 = nn.Linear(2, 4)
                self.bn1 = torch.nn.BatchNorm1d(4)
                self.dense2 = nn.Linear(4, 1)

            def forward(self, x):
                hidden = self.bn1(self.dense1(x))
                return torch.sigmoid(self.dense2(hidden))

        net = BatchNormNet()
        criterion = torch.nn.BCELoss()
        zoo_model = TorchModel.from_pytorch(net)
        zoo_criterion = TorchLoss.from_pytorch(criterion)

        features = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
        labels = torch.Tensor([[0], [0], [0], [1], [1], [1]])

        def make_featureset():
            # Same six samples, two per batch, wrapped for the Zoo estimator.
            loader = DataLoader(TensorDataset(features, labels), batch_size=2)
            return FeatureSet.pytorch_dataloader(loader)

        train_set = make_featureset()
        val_set = make_featureset()

        estimator = Estimator(zoo_model, optim_methods=Adam())
        estimator.train_minibatch(train_set, zoo_criterion,
                                  end_trigger=MaxEpoch(4),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=val_set,
                                  validation_method=[Accuracy()])

        trained_model = zoo_model.to_pytorch()
Exemplo n.º 3
0
    def _hanle_data_loader(self, data, validation_data):
        """
        Wrap PyTorch training and validation data into FeatureSets.

        :param data: training data passed to ``FeatureSet.pytorch_dataloader``
               (presumably a DataLoader or a callable data_creator; not checked here).
        :param validation_data: None, a PyTorch DataLoader, or a callable data_creator.
        :return: tuple ``(train_feature_set, val_feature_set)``; ``val_feature_set`` is
                 None when ``validation_data`` is None.
        """
        train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
        if validation_data is None:
            val_feature_set = None
        else:
            # Validate validation_data itself, not the training data: checking
            # callable(data) let any validation object through for a callable data.
            assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                "validation_data should be a pytorch DataLoader or a callable data_creator"
            val_feature_set = FeatureSet.pytorch_dataloader(validation_data)

        return train_feature_set, val_feature_set
Exemplo n.º 4
0
    def fit(self,
            data,
            epochs=1,
            batch_size=32,
            feature_cols=None,
            labels_cols=None,
            validation_data=None,
            validation_metrics=None,
            checkpoint_trigger=None):
        """
        Train the model on ``data`` and return this estimator.

        :param data: training data: a SparkXShards, a PyTorch DataLoader, or a callable
               data_creator.
        :param epochs: number of epochs to train; becomes a ``MaxEpoch`` end trigger.
        :param batch_size: batch size, must be greater than 0. Only forwarded to the
               underlying estimator when ``data`` is a SparkXShards.
        :param feature_cols: not used by this method.
        :param labels_cols: not used by this method.
        :param validation_data: optional validation data. Must be a SparkXShards when
               ``data`` is one, otherwise a DataLoader or a callable data_creator.
        :param validation_metrics: Orca metrics, converted with
               ``Metrics.convert_metrics_list`` before use.
        :param checkpoint_trigger: Orca trigger, converted with
               ``Trigger.convert_trigger`` before use.
        :return: self, to allow call chaining.
        :raises ValueError: if ``data`` is none of the supported types.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics
        from zoo.orca.learn.trigger import Trigger

        end_trigger = MaxEpoch(epochs)
        assert batch_size > 0, "batch_size should be greater than 0"
        validation_metrics = Metrics.convert_metrics_list(validation_metrics)
        checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger)

        # TensorBoard logging is enabled only when both settings are present.
        if self.log_dir is not None and self.app_name is not None:
            self.estimator.set_tensorboard(self.log_dir, self.app_name)

        if isinstance(data, SparkXShards):
            train_rdd = data.rdd.flatMap(to_sample)
            train_feature_set = FeatureSet.sample_rdd(train_rdd)
            if validation_data is None:
                val_feature_set = None
            else:
                assert isinstance(validation_data, SparkXShards), "validation_data should be a " \
                                                                  "SparkXShards"
                val_feature_set = FeatureSet.sample_rdd(
                    validation_data.rdd.flatMap(to_sample))

            self.estimator.train(train_feature_set, self.loss, end_trigger,
                                 checkpoint_trigger, val_feature_set,
                                 validation_metrics, batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            train_feature_set = FeatureSet.pytorch_dataloader(data, "", "")
            if validation_data is None:
                val_feature_set = None
            else:
                # Check validation_data itself: testing callable(data) here would accept
                # any validation object whenever the training data is a data_creator.
                assert isinstance(validation_data, DataLoader) or callable(validation_data), \
                    "validation_data should be a pytorch DataLoader or a callable data_creator"
                val_feature_set = FeatureSet.pytorch_dataloader(
                    validation_data)

            self.estimator.train_minibatch(train_feature_set, self.loss,
                                           end_trigger, checkpoint_trigger,
                                           val_feature_set, validation_metrics)
        else:
            raise ValueError(
                "Data and validation data should be SparkXShards, DataLoaders or "
                "callable data_creators but get " + data.__class__.__name__)
        return self
Exemplo n.º 5
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 label_cols=None):
        """
        Evaluate the model with the metrics given when this estimator was created.

        :param data: evaluation data: a SparkXShards, a Spark DataFrame, a PyTorch
               DataLoader, or a callable data_creator.
        :param batch_size: batch size; used for SparkXShards and Spark DataFrame input.
        :param feature_cols: feature column name(s); only used for Spark DataFrame input.
        :param label_cols: label column name(s); only used for Spark DataFrame input.
        :return: the evaluation result converted by ``bigdl_metric_results_to_dict``.
        :raises ValueError: if ``data`` is none of the supported types.
        """
        from zoo.orca.data.utils import xshard_to_sample

        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, SparkXShards):
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.flatMap(xshard_to_sample))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataFrame):
            # Capture the schema once on the driver so the row mapper does not need
            # the DataFrame itself.
            schema = data.schema
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.map(lambda row: row_to_sample(
                    row, schema, feature_cols, label_cols)))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            val_feature_set = FeatureSet.pytorch_dataloader(data)
            result = self.estimator.evaluate_minibatch(val_feature_set,
                                                       self.metrics)
        else:
            # List every accepted type, including the Spark DataFrame branch above.
            raise ValueError(
                "Data should be a SparkXShards, a Spark DataFrame, a DataLoader or a "
                "callable data_creator, but get " + data.__class__.__name__)
        return bigdl_metric_results_to_dict(result)
Exemplo n.º 6
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 labels_cols=None,
                 validation_metrics=None):
        """
        Evaluate the model on ``data`` and return the underlying estimator's result.

        :param data: a SparkXShards, a PyTorch DataLoader, or a callable data_creator.
        :param batch_size: batch size; only used for SparkXShards input.
        :param feature_cols: not used by this method.
        :param labels_cols: not used by this method.
        :param validation_metrics: Orca metrics, converted with
               ``Metrics.convert_metrics_list`` before evaluation.
        :return: whatever ``evaluate`` / ``evaluate_minibatch`` of the wrapped
                 estimator returns.
        :raises ValueError: if ``data`` is none of the supported types.
        """
        from zoo.orca.data.utils import to_sample
        from zoo.orca.learn.metrics import Metrics

        assert data is not None, "validation data shouldn't be None"
        metrics = Metrics.convert_metrics_list(validation_metrics)

        if isinstance(data, SparkXShards):
            sample_rdd = data.rdd.flatMap(to_sample)
            shard_set = FeatureSet.sample_rdd(sample_rdd)
            return self.estimator.evaluate(shard_set, metrics, batch_size)

        if isinstance(data, DataLoader) or callable(data):
            loader_set = FeatureSet.pytorch_dataloader(data)
            return self.estimator.evaluate_minibatch(loader_set, metrics)

        raise ValueError(
            "Data should be a SparkXShards, a DataLoader or a callable "
            "data_creator, but get " + data.__class__.__name__)
Exemplo n.º 7
0
    def evaluate(self,
                 data,
                 batch_size=32,
                 feature_cols=None,
                 label_cols=None,
                 validation_metrics=None):
        """
        Evaluate model.

        :param data: evaluation data. It can be an XShards, Spark Dataframe,
               PyTorch DataLoader and PyTorch DataLoader creator function.
               If data is an XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a list of
               numpy arrays.
        :param batch_size: Batch size used for evaluation. Only used when data is an XShards
               or a Spark DataFrame.
        :param feature_cols: Feature column name(s) of data. Only used when data
               is a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data. Only used when data is
               a Spark DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param validation_metrics: Currently ignored. Evaluation always uses the metrics
               given when this estimator was created (``self.metrics``).
        :return: validation results, converted by ``bigdl_metric_results_to_dict``.
        :raises ValueError: if ``data`` is none of the supported types.
        """
        from zoo.orca.data.utils import xshard_to_sample

        assert data is not None, "validation data shouldn't be None"
        assert self.metrics is not None, "metrics shouldn't be None, please specify the metrics" \
                                         " argument when creating this estimator."

        if isinstance(data, SparkXShards):
            # XShards of pandas DataFrames are first converted using the column names.
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(
                    data, feature_cols, label_cols)
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.flatMap(xshard_to_sample))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataFrame):
            # Capture the schema once on the driver for use in the row mapper.
            schema = data.schema
            val_feature_set = FeatureSet.sample_rdd(
                data.rdd.map(lambda row: row_to_sample(
                    row, schema, feature_cols, label_cols)))
            result = self.estimator.evaluate(val_feature_set, self.metrics,
                                             batch_size)
        elif isinstance(data, DataLoader) or callable(data):
            val_feature_set = FeatureSet.pytorch_dataloader(data)
            result = self.estimator.evaluate_minibatch(val_feature_set,
                                                       self.metrics)
        else:
            # List every accepted type, including the Spark DataFrame branch above.
            raise ValueError(
                "Data should be a SparkXShards, a Spark DataFrame, a DataLoader or a "
                "callable data_creator, but get " + data.__class__.__name__)
        return bigdl_metric_results_to_dict(result)
Exemplo n.º 8
0
def main():
    """
    Train ``Net`` on MNIST through the Analytics Zoo PyTorch Estimator.

    Runs on local Spark unless the HADOOP_CONF_DIR environment variable is set, in
    which case it runs on YARN using the conda environment named by ZOO_CONDA_NAME.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dir',
                        default='/tmp/data',
                        metavar='N',
                        help='the folder store mnist data')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=256,
        metavar='N',
        help='input batch size for training per executor(default: 256)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=1000,
        metavar='N',
        help='input batch size for testing per executor(default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=2,
                        metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Standard MNIST mean/std normalization.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.dir,
                       train=True,
                       download=True,
                       transform=mnist_transform),
        batch_size=args.batch_size,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.dir,
                       train=False,
                       transform=mnist_transform),
        batch_size=args.test_batch_size,
        shuffle=False)

    # Run on YARN when HADOOP_CONF_DIR is set (ZOO_CONDA_NAME names the conda env);
    # otherwise run on local Spark.
    if os.environ.get('HADOOP_CONF_DIR') is None:
        sc = init_spark_on_local(cores=1, conf={"spark.driver.memory": "20g"})
    else:
        num_executors = 2
        num_cores_per_executor = 4
        hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')
        zoo_conda_name = os.environ.get(
            'ZOO_CONDA_NAME')  # The name of the created conda-env
        # Keyword names must be num_executors / conf; num_executor / spark_conf
        # are not accepted and made this branch fail with a TypeError.
        sc = init_spark_on_yarn(hadoop_conf=hadoop_conf_dir,
                                conda_name=zoo_conda_name,
                                num_executors=num_executors,
                                executor_cores=num_cores_per_executor,
                                executor_memory="2g",
                                driver_memory="10g",
                                driver_cores=1,
                                conf={
                                    "spark.rpc.message.maxSize": "1024",
                                    "spark.task.maxFailures": "1",
                                    "spark.driver.extraJavaOptions":
                                    "-Dbigdl.failure.retryTimes=1"
                                })

    model = Net()
    model.train()
    criterion = nn.NLLLoss()

    adam = Adam(args.lr)
    zoo_model = TorchModel.from_pytorch(model)
    zoo_criterion = TorchLoss.from_pytorch(criterion)
    zoo_estimator = Estimator(zoo_model, optim_methods=adam)
    train_featureset = FeatureSet.pytorch_dataloader(train_loader)
    test_featureset = FeatureSet.pytorch_dataloader(test_loader)
    from bigdl.optim.optimizer import MaxEpoch, EveryEpoch
    zoo_estimator.train_minibatch(train_featureset,
                                  zoo_criterion,
                                  end_trigger=MaxEpoch(args.epochs),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=test_featureset,
                                  validation_method=[Accuracy()])
Exemplo n.º 9
0
def main():
    """
    Train a ResNet on an ImageNet-style folder dataset through the Zoo Estimator.

    ``args.data`` must contain ``train`` and ``val`` sub-folders readable by
    ``datasets.ImageFolder``. Runs on local Spark unless HADOOP_CONF_DIR is set,
    in which case it runs on YARN.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('-a',
                        '--arch',
                        metavar='ARCH',
                        default='resnet18',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                        ' (default: resnet18)')
    parser.add_argument('--epochs',
                        default=90,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument(
        '-b',
        '--batch-size',
        default=256,
        type=int,
        metavar='N',
        help='mini-batch size (default: 256), this is the total '
        'batch size of all GPUs on the current node when '
        'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr',
                        '--learning-rate',
                        default=0.1,
                        type=float,
                        metavar='LR',
                        help='initial learning rate',
                        dest='lr')
    parser.add_argument('--momentum',
                        default=0.9,
                        type=float,
                        metavar='M',
                        help='momentum')
    parser.add_argument('--wd',
                        '--weight-decay',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('-p',
                        '--print-freq',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained',
                        dest='pretrained',
                        action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--world-size',
                        default=-1,
                        type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank',
                        default=-1,
                        type=int,
                        help='node rank for distributed training')
    parser.add_argument('--seed',
                        default=None,
                        type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--cores',
                        default=4,
                        type=int,
                        help='num of CPUs to use.')
    parser.add_argument('--nodes',
                        default=1,
                        type=int,
                        help='num of nodes to use.')
    parser.add_argument('--executor_memory',
                        default='20g',
                        type=str,
                        help='size of executor memory.')
    parser.add_argument('--driver_memory',
                        default='20g',
                        type=str,
                        help='size of driver memory.')
    parser.add_argument('--driver_cores',
                        default=1,
                        type=int,
                        help='num of driver cores to use.')
    args = parser.parse_args()

    if os.environ.get('HADOOP_CONF_DIR') is None:
        sc = init_spark_on_local(cores=args.cores,
                                 conf={"spark.driver.memory": "20g"})
    else:
        hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')
        num_executors = args.nodes
        executor_memory = args.executor_memory
        driver_memory = args.driver_memory
        driver_cores = args.driver_cores
        num_cores_per_executor = args.cores
        # Pin the MKL/OpenMP thread pools to the per-executor core count.
        os.environ['ZOO_MKL_NUMTHREADS'] = str(num_cores_per_executor)
        os.environ['OMP_NUM_THREADS'] = str(num_cores_per_executor)
        sc = init_spark_on_yarn(
            hadoop_conf=hadoop_conf_dir,
            conda_name=detect_python_location().split("/")
            [-3],  # The name of the created conda-env
            num_executors=num_executors,
            executor_cores=num_cores_per_executor,
            executor_memory=executor_memory,
            driver_memory=driver_memory,
            driver_cores=driver_cores,
            conf={
                "spark.rpc.message.maxSize": "1024",
                "spark.task.maxFailures": "1",
                "spark.driver.extraJavaOptions": "-Dbigdl.failure.retryTimes=1"
            })

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # NOTE(review): --arch (and --pretrained) are parsed but ignored; resnet50 is
    # always used. Kept as-is so the default model does not change.
    model = torchvision.models.resnet50()
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False)

    # Number of minibatches per epoch of the actual training set; this equals
    # ceil(1281167 / batch_size) for the full ImageNet train split.
    iterationPerEpoch = len(train_loader)
    # Decay the learning rate by 10x every 30 epochs.
    step = Step(iterationPerEpoch * 30, 0.1)
    zooOptimizer = SGD(args.lr,
                       momentum=args.momentum,
                       dampening=0.0,
                       leaningrate_schedule=step,
                       weightdecay=args.weight_decay)
    zooModel = TorchModel.from_pytorch(model)
    criterion = torch.nn.CrossEntropyLoss()
    zooCriterion = TorchLoss.from_pytorch(criterion)
    estimator = Estimator(zooModel, optim_methods=zooOptimizer)
    train_featureSet = FeatureSet.pytorch_dataloader(train_loader)
    test_featureSet = FeatureSet.pytorch_dataloader(val_loader)
    # Honour --epochs (default 90) instead of a hard-coded epoch count.
    estimator.train_minibatch(train_featureSet,
                              zooCriterion,
                              end_trigger=MaxEpoch(args.epochs),
                              checkpoint_trigger=EveryEpoch(),
                              validation_set=test_featureSet,
                              validation_method=[Accuracy(),
                                                 Top5Accuracy()])
Exemplo n.º 10
0
def main():
    """
    Train ResNet18 on CIFAR10 through the Analytics Zoo PyTorch Estimator.

    Runs on local Spark unless the HADOOP_CONF_DIR environment variable is set, in
    which case it runs on YARN using the conda environment named by ZOO_CONDA_NAME.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example')
    parser.add_argument('--dir',
                        default='/tmp/data',
                        metavar='N',
                        help='the folder store cifar10 data')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=128,
        metavar='N',
        help='input batch size for training per executor(default: 128)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=1000,
        metavar='N',
        help='input batch size for testing per executor(default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=135,
                        metavar='N',
                        help='number of epochs to train (default: 135)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--lrd',
                        type=float,
                        default=0.0,
                        metavar='LRD',
                        help='learning rate decay(default: 0.0)')
    parser.add_argument('--wd',
                        type=float,
                        default=5e-4,
                        metavar='WD',
                        help='weight decay(default: 5e-4)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='momentum',
                        help='momentum (default: 0.9)')
    parser.add_argument('--dampening',
                        type=float,
                        default=0.0,
                        metavar='dampening',
                        help='dampening (default: 0.0)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Prepare and preprocess the data.
    transform_train = transforms.Compose([
        # Zero-pad 4 pixels on each side, then randomly crop back to 32x32.
        transforms.RandomCrop(32, padding=4),
        # Flip horizontally with probability 0.5.
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Per-channel (R, G, B) mean and std used for normalization.
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    train_set = datasets.CIFAR10(args.dir,
                                 train=True,
                                 download=True,
                                 transform=transform_train)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2)

    test_set = datasets.CIFAR10(args.dir,
                                train=False,
                                transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              num_workers=2)

    # Run on YARN when HADOOP_CONF_DIR is set (ZOO_CONDA_NAME names the conda env);
    # otherwise run on local Spark.
    if os.environ.get('HADOOP_CONF_DIR') is None:
        sc = init_spark_on_local(cores=1, conf={"spark.driver.memory": "20g"})
    else:
        num_executors = 2
        num_cores_per_executor = 4
        hadoop_conf_dir = os.environ.get('HADOOP_CONF_DIR')
        zoo_conda_name = os.environ.get(
            'ZOO_CONDA_NAME')  # The name of the created conda-env
        # Keyword names must be num_executors / conf; num_executor / spark_conf
        # are not accepted and made this branch fail with a TypeError.
        sc = init_spark_on_yarn(hadoop_conf=hadoop_conf_dir,
                                conda_name=zoo_conda_name,
                                num_executors=num_executors,
                                executor_cores=num_cores_per_executor,
                                executor_memory="2g",
                                driver_memory="10g",
                                driver_cores=1,
                                conf={
                                    "spark.rpc.message.maxSize": "1024",
                                    "spark.task.maxFailures": "1",
                                    "spark.driver.extraJavaOptions":
                                    "-Dbigdl.failure.retryTimes=1"
                                })

    model = ResNet18()
    model.train()
    criterion = nn.CrossEntropyLoss()

    # Keyword arguments make the mapping of CLI flags to SGD parameters explicit.
    optimizer = SGD(learningrate=args.lr,
                    learningrate_decay=args.lrd,
                    weightdecay=args.wd,
                    momentum=args.momentum,
                    dampening=args.dampening)
    zoo_model = TorchModel.from_pytorch(model)
    zoo_criterion = TorchLoss.from_pytorch(criterion)
    zoo_estimator = Estimator(zoo_model, optim_methods=optimizer)
    train_featureset = FeatureSet.pytorch_dataloader(train_loader)
    test_featureset = FeatureSet.pytorch_dataloader(test_loader)
    from bigdl.optim.optimizer import MaxEpoch, EveryEpoch
    zoo_estimator.train_minibatch(train_featureset,
                                  zoo_criterion,
                                  end_trigger=MaxEpoch(args.epochs),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=test_featureset,
                                  validation_method=[Accuracy()])
Exemplo n.º 11
0
                                })

    iter_per_epoch = len(cifar100_training_loader)
    warmup_delta = args.lr / (iter_per_epoch * args.warm)
    # iteration_per_epoch = int(math.ceil(float(len(cifar100_training_loader)) / args.b))
    zoo_lrSchedule = SequentialSchedule(iter_per_epoch)
    zoo_lrSchedule.add(Warmup(warmup_delta), iter_per_epoch * args.warm)
    zoo_lrSchedule.add(
        MultiStep(
            [iter_per_epoch * 60, iter_per_epoch * 120, iter_per_epoch * 160],
            0.2), iter_per_epoch * 200)
    zoo_optim = SGD(learningrate=0.0,
                    learningrate_decay=0.0,
                    weightdecay=5e-4,
                    momentum=0.9,
                    dampening=0.0,
                    nesterov=False,
                    leaningrate_schedule=zoo_lrSchedule)

    zoo_model = TorchModel.from_pytorch(net)
    zoo_loss = TorchLoss.from_pytorch(loss_function)
    zoo_estimator = Estimator(zoo_model, optim_methods=zoo_optim)
    train_featureset = FeatureSet.pytorch_dataloader(cifar100_training_loader)
    test_featureset = FeatureSet.pytorch_dataloader(cifar100_test_loader)
    zoo_estimator.train_minibatch(train_featureset,
                                  zoo_loss,
                                  end_trigger=MaxEpoch(settings.EPOCH),
                                  checkpoint_trigger=EveryEpoch(),
                                  validation_set=test_featureset,
                                  validation_method=[Accuracy()])