def compile(self, optimizer, loss, metrics=None):
    """
    Set up how the model is trained; call this before fit or evaluate.

    # Arguments
    optimizer: The optimization method, or its string name such as 'sgd'.
    loss: The training criterion. May be its string name such as 'mse',
          a criterion object, or a Python callable; a callable is wrapped
          in a CustomLoss sized by this model's output shape (minus the
          first dimension).
    metrics: Validation methods to run, or None when no validation is
             needed. A list made up only of strings is translated through
             to_bigdl_metrics; recognised names include 'accuracy' (or
             'acc'), 'top5accuracy' (or 'top5acc'), 'mae', 'auc',
             'treennaccuracy' and 'loss'. For example, [Accuracy()] and
             ['accuracy'] are both accepted.
    """
    optim_method = (to_bigdl_optim_method(optimizer)
                    if isinstance(optimizer, six.string_types)
                    else optimizer)

    # A string is never callable, so these branches are mutually exclusive.
    if callable(loss):
        from zoo.pipeline.api.autograd import CustomLoss
        criterion = CustomLoss(loss, self.get_output_shape()[1:])
    elif isinstance(loss, six.string_types):
        criterion = to_bigdl_criterion(loss)
    else:
        criterion = loss

    # Only translate when every entry is a string name; mixed or object
    # lists are forwarded unchanged. The original `loss` (not the resolved
    # criterion) is handed to the metric translator.
    names_only = bool(metrics) and all(
        isinstance(entry, six.string_types) for entry in metrics)
    if names_only:
        metrics = to_bigdl_metrics(metrics, loss)

    callBigDlFunc(self.bigdl_type, "zooCompile", self.value,
                  optim_method, criterion, metrics)
def compile(self, optimizer, loss, metrics=None):
    """
    Configure the learning process. It MUST be called before fit or evaluate.

    # Arguments
    optimizer: Optimization method to be used. One can alternatively pass in
               the corresponding string representation, such as 'sgd'.
    loss: Criterion to be used. One can alternatively pass in the
          corresponding string representation, such as 'mse', or a Python
          callable, which is wrapped in a CustomLoss using this model's
          output shape without its first dimension.
    metrics: List of validation methods to be used. Default is None if no
             validation is needed. One can alternatively use ['accuracy'];
             a list consisting only of strings is converted with
             to_bigdl_metrics.
    """
    if isinstance(optimizer, six.string_types):
        optimizer = to_bigdl_optim_method(optimizer)

    # Resolve the caller's `loss` into a separate `criterion`. Previously the
    # string branch rebound `loss`, so the callable check that followed
    # inspected the converted criterion instead of the user's argument; if
    # that object is callable it would have been wrapped in CustomLoss a
    # second time. The branches are now mutually exclusive.
    if isinstance(loss, six.string_types):
        criterion = to_bigdl_criterion(loss)
    elif callable(loss):
        from zoo.pipeline.api.autograd import CustomLoss
        criterion = CustomLoss(loss, self.get_output_shape()[1:])
    else:
        criterion = loss

    # Translate only when every entry is a string name; otherwise the list
    # (or None) is forwarded unchanged.
    if metrics and all(isinstance(metric, six.string_types)
                       for metric in metrics):
        metrics = to_bigdl_metrics(metrics)

    callBigDlFunc(self.bigdl_type, "zooCompile", self.value,
                  optimizer, criterion, metrics)