コード例 #1
0
    def test_legacy_calculate_loss_without_sample_weight(self):
        """Check the legacy ``_calculate_loss`` path when no sample-weight column is set.

        Two models are built with ``to_lightning_module``: one with a single
        loss function, and one with two loss functions weighted 0.2 / 0.8.
        The losses they compute are compared against hand-derived values.
        """

        def fn_minus(output, label, reduction=None):
            # Element-wise ``label - output``; mean-reduced unless reduction == 'none'.
            delta = label - output
            return delta if reduction == 'none' else delta.mean()

        def fn_add(output, label, reduction=None):
            # Element-wise ``label + output``; mean-reduced unless reduction == 'none'.
            summed = label + output
            return summed if reduction == 'none' else summed.mean()

        # Shared constructor arguments; both models below receive the same Mock objects.
        shared_kwargs = {
            'model': Mock(),
            'optimizer': Mock(),
            'feature_cols': [],
            'sample_weights_col': None,
            'validation': 0,
        }

        # Single output: mean([1, 2, 3] - [1, 0, 2]) == mean([0, 2, 1]) == 1.0
        single_labels = torch.tensor([[1.0, 2.0, 3.0]])
        single_outputs = torch.tensor([[1.0, 0.0, 2.0]])
        single_model = to_lightning_module(
            loss_fns=[fn_minus], loss_weights=[1], label_cols=['a'], **shared_kwargs)
        single_loss = single_model._calculate_loss(single_outputs, single_labels)
        assert single_loss == 1.0

        # Two outputs, assuming row i is scored by loss_fns[i]:
        # 0.2 * mean([0, 2, 1]) + 0.8 * mean([1, 2, 6]) == 0.2 + 2.4 == 2.6
        pair_labels = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 4.0]])
        pair_outputs = torch.tensor([[1.0, 0.0, 2.0], [0.0, 0.0, 2.0]])
        pair_model = to_lightning_module(
            loss_fns=[fn_minus, fn_add],
            loss_weights=[0.2, 0.8],
            label_cols=['a', 'b'],
            **shared_kwargs)
        pair_loss = pair_model._calculate_loss(pair_outputs, pair_labels)
        assert torch.isclose(pair_loss, torch.tensor(2.6))
コード例 #2
0
ファイル: estimator.py プロジェクト: lakersdf/horovod
    def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row_size, dataset_idx=None):
        """Train the configured model on already-prepared data through ``backend``.

        If the configured model is not a ``LightningModule``, it is treated as
        legacy. In that case it is wrapped with ``to_lightning_module`` using this
        estimator's optimizer, loss, loss-weight, column and validation params.
        The (possibly wrapped) model is serialized and handed to a
        ``remote.RemoteTrainer``, which ``backend.run`` executes. The handle that
        call returns is turned into the result by ``self._create_model``.

        Args:
            backend: Object whose ``run(trainer, args=..., env=...)`` executes training.
            train_rows: Forwarded unchanged to ``remote.RemoteTrainer``.
            val_rows: Forwarded unchanged to ``remote.RemoteTrainer``.
            metadata: Validated by ``self._check_params``, then passed to
                ``remote.RemoteTrainer`` and ``self._create_model``.
            avg_row_size: Forwarded unchanged to ``remote.RemoteTrainer``.
            dataset_idx: Optional; forwarded unchanged to ``remote.RemoteTrainer``.

        Returns:
            Whatever ``self._create_model(handle, run_id, metadata)`` returns.
        """
        self._check_params(metadata)

        run_id = self.getRunId()
        if run_id is None:
            # No explicit run id: derive one from the current Unix time in whole seconds.
            run_id = 'pytorch_' + str(int(time.time()))

        model = self.getModel()
        is_legacy = not isinstance(model, LightningModule)
        if is_legacy:
            # Legacy: wrap the plain model plus estimator params in a LightningModule.
            # Reuse the model fetched above rather than looking it up a second time.
            model = to_lightning_module(model=model,
                                        optimizer=self._get_optimizer(),
                                        loss_fns=self.getLoss(),
                                        loss_weights=self.getLossWeights(),
                                        feature_cols=self.getFeatureCols(),
                                        label_cols=self.getLabelCols(),
                                        sample_weights_col=self.getSampleWeightCol(),
                                        validation=self.getValidation())

        serialized_model = serialize_fn()(model)
        # FIXME: checkpoint bytes should be loaded into serialized_model, same as Keras Estimator.
        # Resume from an existing checkpoint for this run id when one is present.
        ckpt_bytes = self._read_checkpoint(run_id) if self._has_checkpoint(run_id) else None
        trainer = remote.RemoteTrainer(self,
                                       metadata=metadata,
                                       ckpt_bytes=ckpt_bytes,
                                       run_id=run_id,
                                       dataset_idx=dataset_idx,
                                       train_rows=train_rows,
                                       val_rows=val_rows,
                                       avg_row_size=avg_row_size,
                                       is_legacy=is_legacy)
        handle = backend.run(trainer, args=(serialized_model,), env={})
        return self._create_model(handle, run_id, metadata)