Example #1
from spaghettini import load, quick_register

@quick_register
class Composite:
    def __init__(self, a, b):
        self.a = a
        self.b = b(-3, 2)

    def __call__(self, x):
        return self.b(self.a(x))

@quick_register
class Linear:
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return self.w * x + self.b

if __name__ == "__main__":
    m = load("assets/test.yaml")
    print(m(2))
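
    # A pure-Python sketch of the object graph that assets/test.yaml presumably
    # describes (the YAML file is not shown, so the numbers below are
    # illustrative assumptions, not values from the config). Composite receives
    # the Linear *class* for `b` and instantiates it itself with the fixed
    # arguments (-3, 2), while `a` arrives already constructed.
    m_manual = Composite(a=Linear(w=2, b=1), b=Linear)
    print(m_manual(2))  # Linear(-3, 2)(Linear(2, 1)(2)) = -3 * 5 + 2 = -13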
Example #2
from spaghettini import load
import os
import argparse

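# Hypothetical sketch only: neither this class nor its registration appears in
# the original example. It illustrates the interface the script below expects
# from the object loaded out of cfg.yaml: an `exp_dir` attribute, a
# `register_exp_dir(...)` method, and a `launch(resume, cfg_path_source=...,
# test_only=...)` method. The real experiment class is presumably registered
# elsewhere in the project.
from spaghettini import quick_register

@quick_register
class DummyExperiment:
    def __init__(self, exp_dir=None):
        self.exp_dir = exp_dir

    def register_exp_dir(self, exp_dir):
        self.exp_dir = exp_dir

    def launch(self, resume, cfg_path_source=None, test_only=False):
        print("launching: resume={}, test_only={}, cfg={}".format(
            resume, test_only, cfg_path_source))
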
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch CIFAR-10 Training")
    parser.add_argument("--cfg",
                        type=str,
                        help="the path to the configuration file")
    parser.add_argument("--resume",
                        "-r",
                        action="store_true",
                        help="resume from checkpoint")
    parser.add_argument("--test",
                        "-t",
                        action="store_true",
                        help="test the model only")
    args = parser.parse_args()

    assert os.path.basename(args.cfg) == "cfg.yaml"
    exp = load(args.cfg)
    if exp.exp_dir is None:
        exp_dir = os.path.dirname(args.cfg)
        print("generating default exp_dir: {}".format(exp_dir))
        exp.register_exp_dir(exp_dir)
    exp.launch(args.resume, cfg_path_source=args.cfg, test_only=args.test)
    print("finished")
    os._exit(0)
Example #3
import torch.nn as nn
from spaghettini import register, quick_register, load, check, check_registered

quick_register(nn.Linear)
register("relu")(nn.ReLU)
quick_register(nn.Sequential)
print(check())
print(check_registered())

net = load("./examples/assets/pytorch.yaml", verbose=False)
print(net)
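
# For comparison, a hand-built network using the same ingredients registered
# above (examples/assets/pytorch.yaml is not shown, so the layer sizes here
# are illustrative assumptions, not the actual configuration):
reference = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
print(reference)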
Example #4
import torch.nn as nn
from spaghettini import register, quick_register, load, check

quick_register(nn.Linear)
register("relu")(nn.ReLU)
quick_register(nn.Sequential)
print(check())

net = load("assets/pytorch.yaml")
print(net)
Example #5
import torch.nn as nn
from spaghettini import register, quick_register, load, check

quick_register(nn.Linear)
quick_register(nn.ReLU)

@quick_register
class MLP(nn.Module):
    def __init__(self, units, activation, linear_module):
        super().__init__()
        model = []
        for index, (in_units, out_units) in enumerate(zip(units[:-1], units[1:])):
            if index != 0:
                model.append(activation)
            model.append(linear_module(in_units, out_units))
        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        return self.model(x)

print(check())
net = load("assets/pytorch_mlp.yaml")
print(net)
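
# Equivalent direct construction of the registered MLP, bypassing the YAML
# (assets/pytorch_mlp.yaml is not shown, so the units and activation below are
# illustrative assumptions): the config presumably supplies `units`, an
# activation module and a linear-module class, matching the constructor above.
manual_net = MLP(units=[784, 128, 10], activation=nn.ReLU(), linear_module=nn.Linear)
print(manual_net)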
Example #6
import torch.nn as nn
from spaghettini import register, quick_register, load, check, check_registered

quick_register(nn.Linear)
quick_register(nn.ReLU)


@quick_register
class MLP(nn.Module):
    def __init__(self, units, activation, linear_module):
        super().__init__()
        model = []
        for index, (in_units,
                    out_units) in enumerate(zip(units[:-1], units[1:])):
            if index != 0:
                model.append(activation)
            model.append(linear_module(in_units, out_units))
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


print(check())
check_registered()
net = load("examples/assets/pytorch_mlp.yaml", verbose=True)
print("Printing loaded network. ")
print(net)
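
# Quick smoke test of the MLP class itself with small, illustrative sizes (the
# real sizes live in examples/assets/pytorch_mlp.yaml, which is not shown):
import torch

toy = MLP(units=[32, 16, 8], activation=nn.ReLU(), linear_module=nn.Linear)
print(toy(torch.randn(4, 32)).shape)  # expected: torch.Size([4, 8])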