Ejemplo n.º 1
0
def test_value_error_on_unknown_model_getter():
    """``init`` must reject a module that lacks a ``get_model`` function.

    A throwaway module ``foo.py`` is written to the working directory whose
    only factory is named ``init`` rather than ``get_model``.  Executing the
    ``init`` command against it is expected to raise ``ValueError`` with a
    message mentioning ``get_model``.  Generated files are removed in a
    ``finally`` block so a failing assertion does not leak them into later
    tests.
    """
    application = Application()
    application.add(InitCommand())

    module_name = "foo"
    module_file_path = module_name + ".py"
    save_path = "foo-config/model.py"
    content = """
import torch

def init(**args):
    return torch.nn.Conv2d(10, 10, 10)

"""
    # "w" rather than "a": appending to a stale foo.py from an aborted run
    # could keep an old get_model definition and hide the expected error.
    with open(module_file_path, "w", encoding="utf-8") as the_file:
        the_file.write(content)

    try:
        command = application.find("init")
        command_tester = CommandTester(command)

        with pytest.raises(ValueError,
                           match=r".*{fn_name}.*".format(fn_name="get_model")):
            command_tester.execute(module_name + " --save=" + save_path)
    finally:
        if os.path.exists(module_file_path):
            os.remove(module_file_path)
        # The command is expected to fail, so the save directory may never
        # have been created; ignore_errors avoids a spurious
        # FileNotFoundError masking the real test outcome.
        shutil.rmtree(os.path.dirname(save_path), ignore_errors=True)
Ejemplo n.º 2
0
def test_value_error_on_unknown_module():
    """Running ``init`` on a module that cannot be found raises ValueError.

    The error message is expected to contain the missing module's name.
    """
    unknown_module_name = "some_module"

    application = Application()
    application.add(InitCommand())
    tester = CommandTester(application.find("init"))

    expected_pattern = r".*{}.*".format(unknown_module_name)
    with pytest.raises(ValueError, match=expected_pattern):
        tester.execute(unknown_module_name)
Ejemplo n.º 3
0
def test_init_error_custom_config_with_list():
    """A ``--config`` JSON whose top level is a list, not an object, is rejected."""
    module_name = "foo"
    invalid_config = '[{"conf_key": "value"}]'
    # The JSON is wrapped in single quotes so it reaches the command as one
    # argument.
    arguments = "{} --config '{}'".format(module_name, invalid_config)

    application = Application()
    application.add(InitCommand())
    tester = CommandTester(application.find("init"))

    with pytest.raises(ValueError):
        tester.execute(arguments)
Ejemplo n.º 4
0
def test_value_error_on_invalid_config_with_decoding():
    """Malformed JSON passed to ``--config`` raises a ValueError.

    The config below has an unterminated key string, so it cannot be
    decoded; the error message is expected to contain the word "invalid".
    """
    module_name = "foo"
    invalid_config = '{"unclosed_key: 15}'
    arguments = "{} --config '{}'".format(module_name, invalid_config)

    application = Application()
    application.add(InitCommand())
    tester = CommandTester(application.find("init"))

    with pytest.raises(ValueError, match=r".*invalid.*"):
        tester.execute(arguments)
Ejemplo n.º 5
0
def test_value_error_on_invalid_config_with_single_quotes():
    """A Python-style dict literal (single-quoted keys) is not valid JSON.

    ``init`` should raise a ValueError whose message points the user toward
    using double quotes.
    """
    module_name = "foo"
    invalid_config = "{'output_size': 10}"
    # Double quotes on the outside, so the inner single quotes survive as-is.
    arguments = '{} --config "{}"'.format(module_name, invalid_config)

    application = Application()
    application.add(InitCommand())
    tester = CommandTester(application.find("init"))

    with pytest.raises(ValueError, match=r".*double quotes.*"):
        tester.execute(arguments)
Ejemplo n.º 6
0
def test_success_load_custom_user_config():
    """``init`` accepts a user-supplied JSON config passed via ``--config``.

    Writes a module exposing ``get_model`` to the working directory, runs
    ``init`` with the serialized ``custom_config`` and a ``--save`` path,
    and asserts that the model file was written.  All generated files are
    removed in a ``finally`` block, even when the test fails.
    """
    application = Application()
    application.add(InitCommand())
    custom_config = {
        "batch_size": 512,
        "device": "cuda:0",
        "dataset": "torch.blah",
        "learning_rate": 0.2,
    }
    # set up file path variables
    module_name = "tmp_module1_init"
    module_file_path = module_name + ".py"
    save_path = "tmp-test-save/model.py"
    save_dir = os.path.dirname(save_path)

    # Clean up leftovers of a previous (possibly aborted) run.  Remove the
    # directory unconditionally: it can exist even when model.py does not.
    if os.path.exists(module_file_path):
        os.remove(module_file_path)
    shutil.rmtree(save_dir, ignore_errors=True)

    # write model to file path from which we want to import from
    content = """
import torch

def get_model(**args):
    return torch.nn.Conv2d(10, 10, 10)

"""
    with open(module_file_path, "w", encoding="utf-8") as the_file:
        the_file.write(content)

    try:
        command = application.find("init")
        command_tester = CommandTester(command)

        # Act: shlex.quote wraps the JSON so it is passed as one argument.
        command_tester.execute(
            module_name
            + " --config {json_config}".format(
                json_config=shlex.quote(json.dumps(custom_config)))
            + " --save=" + save_path)

        assert os.path.exists(save_path)

        # TODO schema validation: read config.json from save_dir
        # (e.g. with pkmd.MetadataReader) and check it against custom_config.
    finally:
        if os.path.exists(module_file_path):
            os.remove(module_file_path)
        shutil.rmtree(save_dir, ignore_errors=True)
Ejemplo n.º 7
0
def test_success_init_simple_model():
    """``init`` saves a model built by a module's ``get_model`` factory.

    Writes ``bar.py`` exposing ``get_model``, runs ``init`` with a
    ``--save`` path, and asserts that the model file was produced.
    Generated files are removed in a ``finally`` block.
    """
    # Arrange
    # set up application with command
    application = Application()
    application.add(InitCommand())

    # set up file path variables
    module_name = "bar"
    module_file_path = module_name + ".py"
    save_path = "bar-config/model.py"
    save_dir = os.path.dirname(save_path)

    # Clean up leftovers of a previous (possibly aborted) run.  Remove the
    # directory unconditionally: it can exist even when model.py does not.
    if os.path.exists(module_file_path):
        os.remove(module_file_path)
    shutil.rmtree(save_dir, ignore_errors=True)

    # write model to file path from which we want to import from
    content = """
import torch

def get_model(**args):
    return torch.nn.Conv2d(10, 10, 10)

"""
    with open(module_file_path, "w", encoding="utf-8") as the_file:
        the_file.write(content)

    try:
        # load command and build tester object to act on
        command = application.find("init")
        command_tester = CommandTester(command)

        # Act
        command_tester.execute(module_name + " --save=" + save_path)

        # Assert: previously the test only checked that no exception escaped.
        assert os.path.exists(save_path)
    finally:
        if os.path.exists(module_file_path):
            os.remove(module_file_path)
        shutil.rmtree(save_dir, ignore_errors=True)
Ejemplo n.º 8
0
def test_success_init_simple_model():
    """Run ``init`` and then ``train`` end-to-end on a tiny CNN and dataset.

    Writes two throwaway modules to the working directory: one exposing
    ``get_model`` (a small conv + linear classifier for 3x32x32 inputs) and
    one exposing ``get_dataset`` (200 random samples with integer labels in
    [0, 10)).  ``init`` saves the model to ``save_path``, and ``train`` is
    then run on it for a few epochs via ``--config``.  All generated files
    are removed in a ``finally`` block, even when the test fails.

    NOTE(review): this shares its name with the ``init``-only
    ``test_success_init_simple_model``.  If both live in one module, only
    the later definition is collected by pytest, so consider renaming.
    """
    # Arrange
    # set up application with command
    application = Application()
    application.add(InitCommand())
    application.add(TrainCommand())

    # set up file path variables
    module_name = "module_name_for_training"
    module_file_path = module_name + ".py"
    save_path = "tmp-train-save-path/model.py"
    save_dir = os.path.dirname(save_path)
    dataset_module = "tmp_training_dataset"
    dataset_module_file_path = dataset_module + ".py"

    user_config = {"num_epochs": 3}  # Use only few epochs for test

    # Clean up leftovers of a previous (possibly aborted) run.  Remove the
    # directory unconditionally: it can exist even when model.py does not.
    for stale_file in (module_file_path, dataset_module_file_path):
        if os.path.exists(stale_file):
            os.remove(stale_file)
    shutil.rmtree(save_dir, ignore_errors=True)

    # write model to file path from which we want to import from
    content_model = """
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self, width: int, height: int):
        super(MyModel, self).__init__()
        c = 6  # intermediate_channels
        self.conv = nn.Conv2d(in_channels=3, out_channels=c, kernel_size=5)
        self.fc = nn.Linear((width-5+1)*(height-5+1)*c, 10)

    def forward(self, x):
         out = F.relu(self.conv(x))
         out = out.view(out.size(0), -1)
         return F.relu(self.fc(out))


def get_model(**args):
    return MyModel(width=32, height=32)

"""
    # write dataset to file path from which we want to import from
    content_dataset = """
import torch
import numpy as np
from torch.utils import data


class MyDataset(data.Dataset):
    def __len__(self):
        return 200

    def __getitem__(self, index):
        return torch.rand((3, 32, 32)), np.random.randint(0, 10)


def get_dataset(**args):
    return MyDataset()

"""
    try:
        with open(module_file_path, "w", encoding="utf-8") as model_handle:
            model_handle.write(content_model)
        with open(dataset_module_file_path, "w",
                  encoding="utf-8") as dataset_handle:
            dataset_handle.write(content_dataset)

        command_init = application.find("init")
        init_tester = CommandTester(command_init)
        init_tester.execute(module_name + " --save=" + save_path)

        command_train = application.find("train")
        train_tester = CommandTester(command_train)
        # Arguments: <saved model path> <module>.get_dataset --config <json>;
        # shlex.quote wraps the JSON so it is passed as one argument.
        train_tester.execute(
            save_path
            + " "
            + dataset_module
            + ".get_dataset"
            + " --config {json_config}".format(
                json_config=shlex.quote(json.dumps(user_config))
            )
        )
    finally:
        for generated_file in (module_file_path, dataset_module_file_path):
            if os.path.exists(generated_file):
                os.remove(generated_file)
        shutil.rmtree(save_dir, ignore_errors=True)