Code example #1
0
File: misc.py  Project: haowen-xu/tensorkit
def print_experiment_summary(exp: mltk.Experiment,
                             train_data: Any,  # anything that has '__len__'
                             val_data: Optional[Any] = None,
                             test_data: Optional[Any] = None):
    """
    Print an overview of an experiment, one section at a time.

    Sections, each followed by an empty line:

    1. the experiment config, emitted via ``mltk.print_config``;
    2. the number of items in every supplied data set (omitted entirely
       when no data set is given);
    3. the current device, plus the list of GPU devices when that list is
       non-empty.

    Args:
        exp: The experiment whose ``config`` is shown.
        train_data: The training data; anything supporting ``len()``.
        val_data: Optional validation data; skipped when None.
        test_data: Optional test data; skipped when None.
    """
    # section 1: experiment configuration
    mltk.print_config(exp.config)
    print('')

    # section 2: sizes of whichever data sets were actually provided
    candidates = (
        ('Train', train_data),
        ('Validation', val_data),
        ('Test', test_data),
    )
    sizes = [(label, len(dataset))
             for label, dataset in candidates
             if dataset is not None]
    if sizes:
        print(mltk.format_key_values(sizes, 'Number of Data'))
        print('')

    # section 3: current device, and the GPU list only if there is one
    devices = [('Current', T.current_device())]
    available_gpus = T.gpu_device_list()
    if available_gpus:
        devices.append(('Available', available_gpus))
    print(mltk.format_key_values(devices, 'Device Info'))
    print('')
Code example #2
0
File: misc.py  Project: lizeyan/tensorkit
def print_parameters_summary(params: List[T.Variable],
                             names: List[str],
                             printer: Optional[Callable[[str], Any]] = print
                             ) -> Optional[str]:
    """
    Format (and optionally print) a table of parameter shapes and sizes.

    Each row shows a parameter name, its shape and its number of elements.
    A final "Total" row, preceded by a dashed separator line, shows the sum
    of all sizes.  Sizes use thousands separators (e.g. ``1,024``).

    Args:
        params: The parameters to summarize.
        names: Display name of each parameter, in the same order as
            `params`.
        printer: Callable that receives the formatted table.  If None,
            nothing is printed and the table is only returned.

    Returns:
        The formatted table, or None if `params` is empty (in which case
        `printer` is not called).

    Raises:
        ValueError: If `params` and `names` have different lengths.
    """
    if len(params) != len(names):
        # `zip` below would silently drop rows while the total still
        # counted every parameter, producing an inconsistent table.
        raise ValueError(f'`params` and `names` must have the same length: '
                         f'got {len(params)} vs {len(names)}')

    right_pad = ' ' * 3

    shapes = []
    sizes = []
    total_size = 0
    for param in params:
        shape = T.shape(param)
        # Force an exact Python int: `np.prod([])` (a scalar parameter) is
        # the float 1.0, which the ',d' format spec below would reject.
        size = int(np.prod(shape, dtype=np.int64))
        total_size += size
        shapes.append(str(shape))
        sizes.append(f'{size:,d}')

    if not shapes:
        return None

    max_shape_len = max(len(s) for s in shapes)
    max_size_len = max(len(s) for s in sizes)
    total_str = f'{total_size:,d}'
    # The right-hand column must fit both a "shape   size" cell and the total.
    right_len = max(max_shape_len + len(right_pad) + max_size_len,
                    len(total_str))

    param_info = []
    for name, shape_str, size_str in zip(names, shapes, sizes):
        right = (f'{shape_str:<{max_shape_len}s}'
                 f'{right_pad}'
                 f'{size_str:>{max_size_len}s}')
        param_info.append((name, f'{right:>{right_len}s}'))
    param_info.append(('Total', f'{total_str:>{right_len}s}'))

    lines = mltk.format_key_values(param_info,
                                   title='Parameters',
                                   formatter=str).strip().split('\n')
    # Insert a dashed rule, as wide as the "Total" row, just above that row.
    lines.insert(-1, '-' * len(lines[-1]))

    summary = '\n'.join(lines)
    if printer is not None:
        printer(summary)
    return summary
Code example #3
0
 def test_format_key_values(self):
     """A multi-character `delimiter_char` must be rejected with ValueError."""
     expected_message = ('`delimiter_char` must be one character: '
                         'got \'xx\'')
     bad_delimiter = 'xx'
     with pytest.raises(ValueError, match=expected_message):
         format_key_values({'a': 1}, delimiter_char=bad_delimiter)