def test_soft_update_not_work():
    """Check that the soft update is skipped when the hook does not match its scope.

    The callback is configured with ``scope="on_batch_start"`` but only
    ``on_batch_end`` is invoked here, so the zero-filled target weights are
    expected to still be all zeros afterwards.
    """
    # Create "target" before "source" so layer construction order is kept.
    layers = {
        name: nn.Linear(10, 10, bias=False) for name in ("target", "source")
    }
    model = nn.ModuleDict(layers)
    set_requires_grad(model, False)
    # Start the target from all zeros; the final assertion checks it stays so.
    model["target"].weight.data.fill_(0)

    runner = DummyRunner(model=model)
    runner.is_train_loader = True

    callback = dl.SoftUpdateCallaback(
        target_model="target",
        source_model="source",
        tau=0.1,
        scope="on_batch_start",
    )
    # Deliberately fire a hook other than the configured scope.
    callback.on_batch_end(runner)

    target_weight = runner.model["target"].weight.data
    still_zero = (target_weight == 0).flatten().tolist()

    assert all(still_zero)
Example #2
0
    # define criterion
    criterion = NTXentLoss(tau=args.temperature)

    # and callbacks
    callbacks = [
        dl.CriterionCallback(
            input_key="online_projection_left",
            target_key="target_projection_right",
            metric_key="loss",
        ),
        dl.BackwardCallback(metric_key="loss"),
        dl.OptimizerCallback(metric_key="loss"),
        dl.ControlFlowCallbackWrapper(
            dl.SoftUpdateCallaback(
                target_model="target",
                source_model="online",
                tau=0.1,
                scope="on_batch_end",
            ),
            loaders="train",
        ),
        dl.SklearnModelCallback(
            feature_key="online_embedding_origin",
            target_key="target",
            train_loader="train",
            valid_loaders="valid",
            model_fn=LogisticRegression,
            predict_key="sklearn_predict",
            predict_method="predict_proba",
            C=0.1,
            solver="saga",
            max_iter=200,