コード例 #1
0
ファイル: run.py プロジェクト: sbuschjaeger/gncl
                # Queue one GNCLClassifier configuration built from the values of
                # m, l_reg, s and t found in the enclosing scope. FashionMNIST
                # provides both the training and the test split.
                models.append(
                    {
                        "model": GNCLClassifier,
                        "n_estimators": m,
                        "mode": "upper",
                        "l_reg": l_reg,
                        "combination_type": "average",
                        # Factory (not an instance): called later to build each base model.
                        "base_estimator": partial(simpleresnet, size=s, model_type=t),
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "loader": loader,
                        "eval_every": 5,
                        "store_every": 0,
                        # reduction="none" keeps the per-sample losses un-aggregated.
                        "loss_function": nn.CrossEntropyLoss(reduction="none"),
                        "use_amp": True,
                        "device": "cuda",
                        "train_data": torchvision.datasets.FashionMNIST(
                            ".", train=True, transform=train_transformation()
                        ),
                        "test_data": torchvision.datasets.FashionMNIST(
                            ".", train=False, transform=test_transformation()
                        ),
                        "verbose": True,
                    }
                )

# Best-effort sanity check: instantiate the base estimator of the first
# configuration on the GPU and print its summary for one random input of
# shape (1, 1, 28, 28) before the experiments are launched.
try:
    base = models[0]["base_estimator"]().cuda()
    rnd_input = torch.rand((1, 1, 28, 28)).cuda()
    print(summary(base, rnd_input))
except Exception as err:
    # The summary is purely informational, so a failure here (for example an
    # empty `models` list or no usable CUDA device) must not stop the
    # experiments. `Exception` is caught instead of using a bare `except:` so
    # that KeyboardInterrupt and SystemExit still propagate, and the reason is
    # reported rather than silently discarded.
    print("Skipping model summary:", repr(err))

run_experiments(basecfg, models)
コード例 #2
0
                    "sigma":s,
                    "scale":1,
                    "epsilon":e
                })
            )

            # Queue a SieveStreaming++ run using the K, s (sigma) and e (epsilon)
            # values from the enclosing scope.
            runs.append(
                {
                    "method": "SieveStreaming++",
                    "K": K,
                    "sigma": s,
                    "scale": 1,
                    "epsilon": e,
                }
            )

            # Queue one ThreeSieves run per value in Ts; each entry carries the
            # same K / sigma / scale / epsilon settings plus its own "T".
            for T in Ts:
                runs.append(
                    {
                        "method": "ThreeSieves",
                        "K": K,
                        "sigma": s,
                        "scale": 1,
                        "epsilon": e,
                        "T": T,
                    }
                )

# Shuffling is currently disabled, so `runs` is handed over in the order in
# which the configurations were appended above.
# random.shuffle(runs)
run_experiments(basecfg, runs)