Esempio n. 1
0
    def compile(self, optimizer, loss, metrics=None):
        """
        Configure the learning process. It MUST be called before fit or evaluate.

        # Arguments
        optimizer: Optimization method to be used. One can alternatively pass in the corresponding
                   string representation, such as 'sgd'.
        loss: Criterion to be used. One can alternatively pass in the corresponding string
              representation, such as 'mse', or a Python callable, which is wrapped in a
              CustomLoss built from this model's output shape without its first dimension.
        metrics: List of validation methods to be used. Default is None if no validation is needed.
                 For convenience, string representations are supported: 'accuracy' (or 'acc'),
                 'top5accuracy' (or 'top5acc'), 'mae', 'auc', 'treennaccuracy' and 'loss'.
                 For example, you can either use [Accuracy()] or ['accuracy'].
                 A single metric string such as 'accuracy' is treated as ['accuracy'].
        """
        if isinstance(optimizer, six.string_types):
            optimizer = to_bigdl_optim_method(optimizer)

        # Resolve the criterion into a separate name: the original `loss` is still
        # needed below, because to_bigdl_metrics receives it.
        criterion = loss
        if isinstance(loss, six.string_types):
            criterion = to_bigdl_criterion(loss)
        elif callable(loss):
            from zoo.pipeline.api.autograd import CustomLoss
            # [1:] drops the leading dimension of the output shape
            # (presumably the batch dimension -- confirm against get_output_shape).
            criterion = CustomLoss(loss, self.get_output_shape()[1:])

        # A bare string would otherwise be iterated character by character: every
        # character is itself a string, so the all(...) check would pass and
        # to_bigdl_metrics would receive a plain string instead of a list.
        if isinstance(metrics, six.string_types):
            metrics = [metrics]
        if metrics and all(isinstance(metric, six.string_types) for metric in metrics):
            metrics = to_bigdl_metrics(metrics, loss)

        callBigDlFunc(self.bigdl_type, "zooCompile",
                      self.value,
                      optimizer,
                      criterion,
                      metrics)
Esempio n. 2
0
    def compile(self, optimizer, loss, metrics=None):
        """
        Configure the learning process. It MUST be called before fit or evaluate.

        # Arguments
        optimizer: Optimization method to be used. One can alternatively pass in the corresponding
                   string representation, such as 'sgd'.
        loss: Criterion to be used. One can alternatively pass in the corresponding string
              representation, such as 'mse', or a Python callable, which is wrapped in a
              CustomLoss built from this model's output shape without its first dimension.
        metrics: List of validation methods to be used. Default is None if no validation is needed.
                 One can alternatively use ['accuracy']. A single metric string such as
                 'accuracy' is treated as ['accuracy'].
        """
        if isinstance(optimizer, six.string_types):
            optimizer = to_bigdl_optim_method(optimizer)

        # Keep the resolved criterion in its own name and make the branches exclusive.
        # Otherwise, the criterion returned by to_bigdl_criterion would itself be
        # re-tested by callable() and could be wrapped in CustomLoss by mistake.
        criterion = loss
        if isinstance(loss, six.string_types):
            criterion = to_bigdl_criterion(loss)
        elif callable(loss):
            from zoo.pipeline.api.autograd import CustomLoss
            # [1:] drops the leading dimension of the output shape
            # (presumably the batch dimension -- confirm against get_output_shape).
            criterion = CustomLoss(loss, self.get_output_shape()[1:])

        # A bare string would otherwise be iterated character by character: every
        # character is itself a string, so the all(...) check would pass and
        # to_bigdl_metrics would receive a plain string instead of a list.
        if isinstance(metrics, six.string_types):
            metrics = [metrics]
        if metrics and all(isinstance(metric, six.string_types) for metric in metrics):
            metrics = to_bigdl_metrics(metrics)

        callBigDlFunc(self.bigdl_type, "zooCompile",
                      self.value,
                      optimizer,
                      criterion,
                      metrics)