def test_legacy_calculate_loss_without_sample_weight(self):
    """Check the legacy ``_calculate_loss`` path when no sample-weight column is set.

    First exercises a single loss function on a single label column, then a
    weighted combination of two loss functions over two label columns.
    """
    def _reduce(losses, reduction):
        # reduction='none' returns the element-wise losses unchanged; any other
        # value (including the default None) averages them into a scalar.
        return losses if reduction == 'none' else losses.mean()

    def fn_minus(output, label, reduction=None):
        return _reduce(label - output, reduction)

    def fn_add(output, label, reduction=None):
        return _reduce(label + output, reduction)

    # Built once and shared by both models below; no sample-weight column.
    shared_kwargs = {
        'model': Mock(),
        'optimizer': Mock(),
        'feature_cols': [],
        'sample_weights_col': None,
        'validation': 0,
    }

    # One loss fn, one label column:
    # mean([1, 2, 3] - [1, 0, 2]) == mean([0, 2, 1]) == 1.0
    single_labels = torch.tensor([[1.0, 2.0, 3.0]])
    single_outputs = torch.tensor([[1.0, 0.0, 2.0]])
    single_model = to_lightning_module(loss_fns=[fn_minus], loss_weights=[1],
                                       label_cols=['a'], **shared_kwargs)
    assert single_model._calculate_loss(single_outputs, single_labels) == 1.0

    # Two loss fns weighted 0.2 / 0.8. The expected value assumes each loss fn
    # is applied to the matching row of outputs/labels (TODO confirm against
    # _calculate_loss): 0.2 * mean([0, 2, 1]) + 0.8 * mean([1, 2, 6]) == 2.6
    multi_labels = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0]])
    multi_outputs = torch.tensor([[1.0, 0.0, 2.0], [0.0, 0.0, 2.0]])
    multi_model = to_lightning_module(
        loss_fns=[fn_minus, fn_add],
        loss_weights=[0.2, 0.8],
        label_cols=['a', 'b'],
        **shared_kwargs,
    )
    multi_loss = multi_model._calculate_loss(multi_outputs, multi_labels)
    assert torch.isclose(multi_loss, torch.tensor(2.6))
def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row_size, dataset_idx=None):
    """Train this estimator's model on already-prepared data via ``backend``.

    Args:
        backend: Execution backend. Its ``run(trainer, args=..., env=...)`` is
            invoked with the remote trainer and the serialized model.
        train_rows: Forwarded unchanged to ``remote.RemoteTrainer``.
        val_rows: Forwarded unchanged to ``remote.RemoteTrainer``.
        metadata: Validated with ``self._check_params`` and forwarded to both
            ``remote.RemoteTrainer`` and ``self._create_model``.
        avg_row_size: Forwarded unchanged to ``remote.RemoteTrainer``.
        dataset_idx: Optional; forwarded unchanged to ``remote.RemoteTrainer``.

    Returns:
        Whatever ``self._create_model(handle, run_id, metadata)`` returns, where
        ``handle`` is the value returned by ``backend.run``.
    """
    self._check_params(metadata)

    run_id = self.getRunId()
    if run_id is None:
        # No explicit run id configured: derive one from the current Unix time
        # (whole seconds).
        run_id = f'pytorch_{int(time.time())}'

    model = self.getModel()
    is_legacy = not isinstance(model, LightningModule)
    if is_legacy:
        # Legacy: the model is not a LightningModule, so wrap it together with
        # the estimator-level optimizer/loss/column params. Reuse the model
        # fetched above instead of calling getModel() a second time.
        model = to_lightning_module(model=model,
                                    optimizer=self._get_optimizer(),
                                    loss_fns=self.getLoss(),
                                    loss_weights=self.getLossWeights(),
                                    feature_cols=self.getFeatureCols(),
                                    label_cols=self.getLabelCols(),
                                    sample_weights_col=self.getSampleWeightCol(),
                                    validation=self.getValidation())

    serialized_model = serialize_fn()(model)

    # FIXME: checkpoint bytes should be loaded into serialized_model, same as Keras Estimator.
    ckpt_bytes = self._read_checkpoint(run_id) if self._has_checkpoint(run_id) else None

    # is_legacy is forwarded so the remote side can tell whether the model was
    # wrapped above.
    trainer = remote.RemoteTrainer(self,
                                   metadata=metadata,
                                   ckpt_bytes=ckpt_bytes,
                                   run_id=run_id,
                                   dataset_idx=dataset_idx,
                                   train_rows=train_rows,
                                   val_rows=val_rows,
                                   avg_row_size=avg_row_size,
                                   is_legacy=is_legacy)

    # The serialized model is the trainer's only positional argument; no extra
    # environment variables are passed.
    handle = backend.run(trainer, args=(serialized_model,), env={})
    return self._create_model(handle, run_id, metadata)