def main(
    args,
    epochs: int = 14,
    no_cuda: bool = False,
    seed: int = 1,
    save_model: bool = False,
):
    """Runs an MNIST classification experiment.

    Builds the train and test loaders, trains ``Net`` for ``epochs`` epochs
    (evaluating after each epoch and then stepping the scheduler), and
    optionally saves the final weights.

    Parameters
    ----------
    args
        Parsed argbind arguments. The train and test loaders are each built
        with ``dataset(device)`` inside the ``'train'`` and ``'test'`` scopes
        of these arguments, respectively.
    epochs : int, optional
        number of epochs to train, by default 14
    no_cuda : bool, optional
        disables CUDA training, by default False
    seed : int, optional
        random seed passed to ``torch.manual_seed``, by default 1
    save_model : bool, optional
        if True, save the final ``state_dict`` to ``mnist_cnn.pt`` (relative
        to the current working directory), by default False
    """
    # Short-circuit: the CUDA availability probe only runs when CUDA has
    # not been disabled explicitly.
    cuda_enabled = (not no_cuda) and torch.cuda.is_available()
    torch.manual_seed(seed)
    device = torch.device("cuda") if cuda_enabled else torch.device("cpu")

    # One argbind scope per split, presumably so dataset() can receive
    # split-specific settings. Train is built before test, as before.
    loaders = {}
    for split in ('train', 'test'):
        with argbind.scope(args, split):
            loaders[split] = dataset(device)
    train_loader, test_loader = loaders['train'], loaders['test']

    net = Net().to(device)
    opt = optimizer(net)
    sched = scheduler(opt)

    for epoch_num in range(1, epochs + 1):  # epochs are numbered from 1
        train(net, device, train_loader, opt, epoch_num)
        test(net, device, test_loader)
        sched.step()  # stepped once per epoch, after evaluation

    if save_model:
        torch.save(net.state_dict(), "mnist_cnn.pt")
    results_folder : str, optional
        Folder where results are, by default './results'
    """
    print("STAGE: ANALYZE")
    print(f"Generating plots for {results_folder}")
    (Path(results_folder) / 'example.png').touch()
    print()


@argbind.bind(without_prefix=True, positional=True)
def run(stage: str):
    """Run a single named stage.

    Parameters
    ----------
    stage : str
        Name of the stage to run. It must be listed in ``STAGES``; it is
        resolved to the module-level function of the same name, which is
        then called with no arguments.

    Raises
    ------
    ValueError
        If ``stage`` is not one of ``STAGES``.
    """
    with output():
        if stage not in STAGES:
            raise ValueError(
                f"Requested stage {stage} not in known stages {STAGES}")
        # Stages are dispatched by name through the module namespace.
        stage_fn = globals()[stage]
        stage_fn()


if __name__ == "__main__":
    # ``run`` is bound positionally, so presumably the stage name is taken
    # from the command line when ``run()`` is called inside this scope.
    args = argbind.parse_args()
    with argbind.scope(args):
        run()
# Example #3
# 0
def forward():
    """Bind ``open_link`` with argbind and call it with the parsed CLI args.

    ``open_link`` is bound positionally and without a prefix, and the bound
    callable is invoked (with no explicit arguments) inside the scope of
    the parsed arguments.
    """
    # Bind before parse_args(). Presumably this is so the bound parameters
    # are registered with the parser; keep this order.
    bound_open_link = argbind.bind(
        open_link, positional=True, without_prefix=True)
    cli_args = argbind.parse_args()
    with argbind.scope(cli_args):
        bound_open_link()
# Example #4
# 0
import argbind


@argbind.bind()
def func(arg: str = 'default'):
    """Print ``arg``.

    Because the function is bound with argbind, calling it inside
    ``argbind.scope(...)`` takes ``arg`` from the ``'func.arg'`` key of the
    scoped arguments. An explicit keyword argument is printed as given.

    Parameters
    ----------
    arg : str, optional
        Value to print, by default 'default'
    """
    value = arg
    print(value)


# Scoped argument dictionaries; keys have the form "<function>.<parameter>".
dict1 = {'func.arg': 1}
dict2 = {'func.arg': 2}

# Each scope supplies its own value for func's ``arg``.
for _scoped_args in (dict1, dict2):
    with argbind.scope(_scoped_args):
        func()  # prints 1, then 2
del _scoped_args

func(arg=3)  # prints 3.

if __name__ == "__main__":
    args = argbind.parse_args()
    # Presumably the parsed command-line values apply the same way as
    # dict1 / dict2 above.
    with argbind.scope(args):
        func()

    # Explicit keyword argument inside a scope.
    with argbind.scope(args):
        func(arg=3)
# Example #5
# 0
import argbind

@argbind.bind('train', 'val', 'test')
def dataset(
    some_positional_arg,
    folder: str = 'default',
):
    """Example dataset factory; it only prints the folder it would use.

    Bound under the ``'train'``, ``'val'`` and ``'test'`` scoping patterns,
    so ``folder`` can be set separately for each of them.

    Parameters
    ----------
    some_positional_arg
        Positional argument passed in by the calling script; not used here.
    folder : str, optional
        Folder for the dataset, by default 'default'
    """
    chosen_folder = folder
    print(chosen_folder)

if __name__ == "__main__":
    args = argbind.parse_args()
    some_positional_arg = None

    # Call once with no scoping pattern, then once per bound pattern, in
    # the order train -> val -> test.
    with argbind.scope(args):
        dataset(some_positional_arg)
    for pattern in ('train', 'val', 'test'):
        with argbind.scope(args, pattern):
            dataset(some_positional_arg)
# Example #6
# 0
def main():
    """Parse CLI arguments, install a SIGINT handler, and start ``app``.

    ``Args()`` is constructed inside the parsed-argument scope (presumably
    so it picks up bound values). Its ``host`` and ``port`` are passed to
    ``app.run`` together with ``debug=True``.
    """
    parsed = argbind.parse_args()
    with argbind.scope(parsed):
        settings = Args()
        # Ctrl-C (SIGINT) is routed to clean_up; registered before app.run.
        signal.signal(signal.SIGINT, clean_up)
        # NOTE(review): debug=True is hard-coded; confirm this entry point
        # is never used in production.
        app.run(debug=True, host=settings.host, port=settings.port)
# Example #7
# 0
    
    BoundClass = argbind.bind(MyClass, 'pattern')
    bound_fn = argbind.bind(my_func)

    argbind.parse_args() # add for help text, though it isn't used here.

    args = {
      'MyClass.x': 'from binding',
      'pattern/MyClass.x': 'from binding in scoping pattern',
      'my_func.x': 123,
      'args.debug': True # for printing arguments passed to each function
    }

    # Original objects are not affected by ArgBind
    print("Original object output")
    with argbind.scope(args):
        MyClass() # prints "from default"
        my_func() # prints 100
    print()
    
    # Bound objects ARE affected by ArgBind
    print("Bound objects output")
    with argbind.scope(args):
        BoundClass() # prints "from binding"
        bound_fn() # prints 123
    print()
    
    # Scoping patterns can be used
    print("Bound objects inside scoping pattern output")
    with argbind.scope(args, 'pattern'):
        BoundClass() # prints "from binding in scoping pattern"