Exemple #1
0
 def evaluate_and_log(st_x: StX, st_it: StIt):
     """Score the model on the train and validation loaders and log a summary.

     Calls ``.eval()`` on ``st_x["model"]``, stores the pairs returned by
     ``score`` in ``st_it`` under ``train_mean_ce``/``train_acc`` and
     ``val_mean_ce``/``val_acc``, then emits one ``logger.info`` line.

     Args:
         st_x: State mapping; read here for its ``"model"`` and ``"dev"`` entries.
         st_it: Per-iteration state, updated in place. Must provide
             ``"num_iters_done"``; may provide a precomputed ``"reg_term"``.
     """
     net = st_x["model"]
     dev = st_x["dev"]
     net.eval()

     train_ce, train_acc = score(net, train_dl, dev)
     st_it["train_mean_ce"] = train_ce
     st_it["train_acc"] = train_acc

     val_ce, val_acc = score(net, val_dl, dev)
     st_it["val_mean_ce"] = val_ce
     st_it["val_acc"] = val_acc

     # Reuse the regularizer value when the caller already stored one;
     # otherwise compute it without building an autograd graph.
     # NOTE: the local must stay named ``reg_term`` -- the ``{reg_term=...}``
     # f-string below prints the variable name itself.
     if "reg_term" in st_it:
         reg_term = st_it["reg_term"]
     else:
         with torch.no_grad():
             reg_term = calc_regularizer(net)

     summary = (
         f"After {st_it['num_iters_done']:07} iters: "
         f"train/val mean_ce={st_it['train_mean_ce']:.5f}/{st_it['val_mean_ce']:.5f} "
         f"acc={st_it['train_acc']:.2%}/{st_it['val_acc']:.2%} "
         f"{reg_term=:.2e}"
     )
     logger.info(summary)
# Read the checkpoint tensors straight onto the evaluation device.
sd = torch.load(MODEL_PATH, map_location=device)

# Checkpoint key -> key used by `model`. A checkpoint key that is not listed
# here raises KeyError while the remapped dict is being built.
key_renames = {
    "0.core": "epses.0",
    "2.core": "epses.1",
    "4.weight": "linear.weight",
    "4.bias": "linear.bias",
}
remapped_sd = {key_renames[old_key]: tensor for old_key, tensor in sd.items()}

# strict=False: key mismatches between `remapped_sd` and the model are
# tolerated instead of raising.
model.load_state_dict(remapped_sd, strict=False)

model.eval()
model.to(device)

train_dl, val_dl, test_dl = get_fashionmnist_data_loaders(
    "/mnt/hdd_1tb/datasets/fashionmnist",
    32,
    device,
    (
        lambda X: 1.45646 * (X * pi / 2.0).sin() ** 2,
        lambda X: 1.45646 * (X * pi / 2.0).cos() ** 2,
    ),
)

# Scoring train_dl is left out: per the original note it takes too long on CPU.
# Previously recorded results: val 88.2%, test 87.65%.
for label, loader in (("val:", val_dl), ("test:", test_dl)):
    print(label, score(model, loader, device))
# `loaded` is expected to be the already-deserialized checkpoint object
# (defined earlier in the script); only its "state_dict" entry is used here.
sd = loaded["state_dict"]

print(sd.keys())

device = torch.device("cpu")

model = EPSesPlusLinear(((4, 32),), UnitTheoreticalOutputStd(), 1.0, device, torch.float32)

# Each checkpoint key paired with the model tensor it overwrites.
checkpoint_to_model = (
    ("eps.core", model.epses[0]),
    ("linear.weight", model.linear.weight),
    ("linear.bias", model.linear.bias),
)

# Validate every shape before any weight is modified. This is an explicit
# check rather than `assert` because asserts are stripped under `python -O`,
# and `Tensor.copy_` broadcasts its source, so a mismatched checkpoint could
# otherwise be copied in silently.
for sd_key, target in checkpoint_to_model:
    if sd[sd_key].shape != target.shape:
        raise ValueError(
            f"Shape mismatch for {sd_key!r}: checkpoint has "
            f"{tuple(sd[sd_key].shape)}, model expects {tuple(target.shape)}"
        )

# All shapes agree; overwrite the model tensors in place.
for sd_key, target in checkpoint_to_model:
    target.data.copy_(sd[sd_key])

model.eval()
model.to(device)

train_dl, val_dl, test_dl = get_mnist_data_loaders(
    "/mnt/hdd_1tb/datasets/mnist",
    128,
    device,
    (
        lambda X: 0.5 * (X * pi / 2.0).sin() ** 2,
        lambda X: 0.5 * (X * pi / 2.0).cos() ** 2,
    ),
)

# Report the result of `score` for each split, in train/val/test order.
for label, loader in (("train:", train_dl), ("val:", val_dl), ("test:", test_dl)):
    print(label, score(model, loader, device))
from math import pi
from pprint import pprint

import torch

from dctn.eps_plus_linear import EPSesPlusLinear, UnitTheoreticalOutputStd
from dctn.dataset_loading import get_fashionmnist_data_loaders
from dctn.evaluation import score

# Checkpoint to evaluate.
MODEL_PATH = "/mnt/important/experiments/eps_plus_linear_fashionmnist/replicate_90.19_vacc/2020-05-04T23:13:52_stopped_manually/model_best_val_acc_nitd=0580000_tracc=0.9456_vacc=0.9025_trmce=0.1624_vmce=0.2738.pth"

device = torch.device("cpu")

model = EPSesPlusLinear(
    ((4, 4),),
    UnitTheoreticalOutputStd(),
    1.0,
    device,
    torch.float32,
)

# The checkpoint is loaded as-is: its keys must match the model's own
# state-dict keys (default strict loading).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

model.eval()
model.to(device)

train_dl, val_dl, test_dl = get_fashionmnist_data_loaders(
    "/mnt/hdd_1tb/datasets/fashionmnist",
    128,
    device,
    (
        lambda X: 0.5 * (X * pi / 2.0).sin() ** 2,
        lambda X: 0.5 * (X * pi / 2.0).cos() ** 2,
    ),
)

# Report the result of `score` for each split, in train/val/test order.
for label, loader in (("train:", train_dl), ("val:", val_dl), ("test:", test_dl)):
    print(label, score(model, loader, device))
from dctn.eps_plus_linear import EPSesPlusLinear, UnitTheoreticalOutputStd
from dctn.dataset_loading import get_fashionmnist_data_loaders
from dctn.evaluation import score

# Checkpoint to evaluate.
MODEL_PATH = "/mnt/important/experiments/3_epses_plus_linear_fashionmnist/2020-05-12T19:33:11_vacc=0.7708_manually_stopped/model_best_val_acc_nitd=0430000_tracc=0.8088_vacc=0.7708_trmce=0.8494_vmce=145.3116.pth"

device = torch.device("cuda:0")

model = EPSesPlusLinear(
    ((4, 4), (3, 12), (2, 24)),
    UnitTheoreticalOutputStd(),
    1.0,
    device,
    torch.float32,
)

# Load the checkpoint tensors onto `device`; keys must match the model's own
# state-dict keys (default strict loading).
sd = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(sd)

model.eval()
model.to(device)

train_dl, val_dl, test_dl = get_fashionmnist_data_loaders(
    "/mnt/hdd_1tb/datasets/fashionmnist",
    32,
    device,
    (
        lambda X: 1.45646 * (X * pi / 2.0).sin() ** 2,
        lambda X: 1.45646 * (X * pi / 2.0).cos() ** 2,
    ),
)

# Scoring train_dl is left out; the original note says it takes too long on CPU.
# NOTE(review): `device` here is cuda:0, so that note may be stale -- confirm.
# Previously recorded results: val 77.08%, test 75.94%.
for label, loader in (("val:", val_dl), ("test:", test_dl)):
    print(label, score(model, loader, device))