Esempio n. 1
0
def _create_monitor(datasets, name, network, dataset_names):
    m = nnabla_pb2.Monitor()
    m.name = name
    m.network_name = network.name
    if isinstance(dataset_names, tuple):
        dataset_names = list(dataset_names)
    if isinstance(dataset_names, list):
        for dataset_name in dataset_names:
            if dataset_name in datasets:
                m.dataset_name.append(dataset_name)
                dataset = datasets[dataset_name]
            else:
                raise ValueError(
                    "Invalid dataset name is found in monitor definition: {}".
                    format(dataset_name))
    elif isinstance(dataset_names, str):
        dataset_name = dataset_names
        if dataset_name in datasets:
            m.dataset_name.append(dataset_name)
            dataset = datasets[dataset_name]
    if dataset is None:
        raise ValueError("Dataset is not defined in monitor definition.")
    inputs, outputs, params = _get_net_variables(network)
    for n, inp in enumerate(inputs):
        d = m.data_variable.add()
        d.variable_name = inp
        d.data_name = dataset.variable[n]
    for out in outputs:
        d = m.monitor_variable.add()
        d.type = 'Error'
        d.variable_name = out
    return m
Esempio n. 2
0
def _create_monitor(name, monitor, network, dataset):
    """Return an ``nnabla_pb2.Monitor`` linking *network* to *dataset*.

    Args:
        name: Name stored in the monitor.
        monitor: Accepted for interface compatibility; not read here.
        network: Network whose ``name`` is recorded and whose input/output
            variables come from ``_get_net_variables``.
        dataset: Object exposing ``name`` and an indexable ``variable``
            sequence; its i-th entry is used as the data name of the i-th
            network input.

    Returns:
        The populated ``nnabla_pb2.Monitor``.

    NOTE(review): ``dataset_name`` is assigned directly. That only works if
    ``Monitor.dataset_name`` is a singular field in the proto version in use
    (protobuf forbids assignment to repeated fields) - confirm.
    """
    mon = nnabla_pb2.Monitor()
    mon.name = name
    mon.network_name = network.name
    mon.dataset_name = dataset.name

    inputs, outputs, _ = _get_net_variables(network)
    source_vars = dataset.variable
    for position, input_name in enumerate(inputs):
        feed = mon.data_variable.add()
        feed.variable_name = input_name
        # Indexing (not zip) so a short dataset still raises IndexError.
        feed.data_name = source_vars[position]

    for output_name in outputs:
        watched = mon.monitor_variable.add()
        watched.type = 'Error'
        watched.variable_name = output_name

    return mon
Esempio n. 3
0
File: save.py Progetto: sony/nnabla
def _create_monitor(ctx, monitor):
    datasets = ctx.datasets
    if monitor['network'] not in ctx.networks:
        raise ValueError("{} is not found in networks.".format(
            monitor['network']))
    proto_network = ctx.proto_graphs[monitor['network']].default_graph()
    m = nnabla_pb2.Monitor()
    m.name = monitor['name']
    m.network_name = monitor['network']
    if isinstance(monitor['dataset'], (list, tuple)):
        for dataset_name in monitor['dataset']:
            if dataset_name in datasets:
                m.dataset_name.append(dataset_name)
                dataset = datasets[dataset_name]
            else:
                raise ValueError(
                    "Invalid dataset name is found in monitor definition: {}".
                    format(dataset_name))
    elif isinstance(monitor['dataset'], str):
        dataset_name = monitor['dataset']
        if dataset_name in datasets:
            m.dataset_name.append(dataset_name)
            dataset = datasets[dataset_name]
    if dataset is None:
        raise ValueError("Dataset is not defined in monitor definition.")
    for var_name, data_name in monitor.get('data_variables', {}).items():
        d = m.data_variable.add()
        d.variable_name = var_name
        d.data_name = data_name
    for out in monitor.get('monitor_variables', proto_network.outputs):
        d = m.monitor_variable.add()
        d.type = 'Error'
        d.variable_name = out
    for g_var in monitor.get('generator_variables', []):
        d = m.generator_variable.add()
        d.variable_name = g_var
        d.type = 'Constant'
        d.multiplier = 0
    return m