Пример #1
0
def collect_torch_weights(output_dir, config, num_layers):
    """Run one forward pass with print hooks attached and gather the dumped arrays.

    Args:
        output_dir: Directory handed to ``register_print_hooks``; created if it
            does not exist yet.
        config: Model configuration passed to ``load_torch_model``. Must contain
            the key ``'input_sample_size'``, which is passed to ``randn`` to
            build the random input for the forward pass.
        num_layers: Forwarded to ``register_print_hooks``.

    Returns:
        OrderedDict mapping each dump file's stem (basename without extension)
        to the array loaded from it with ``np.load``. The same dict is handed to
        ``register_print_hooks`` as ``data_to_compare``, so it may also hold
        entries added by the hooks; values loaded from files overwrite any
        existing entry with the same key.
    """
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(output_dir, exist_ok=True)

    model = load_torch_model(config)
    model_e = model.eval()

    data_to_compare = OrderedDict()

    register_print_hooks(output_dir,
                         model_e,
                         num_layers=num_layers,
                         data_to_compare=data_to_compare,
                         dump_activations=False)

    # A single forward pass on random input; presumably this is what fires the
    # hooks registered above and writes the dump files — confirm in the helper.
    input_ = randn(config['input_sample_size'])
    model_e(input_)

    for module in model_e.modules():
        paths = get_full_dump_paths(module)
        if paths is None:
            continue
        for dump_path in paths:
            # Only load dump files that actually exist on disk.
            if os.path.isfile(dump_path):
                key = os.path.splitext(os.path.basename(dump_path))[0]
                data_to_compare[key] = np.load(dump_path)
    return data_to_compare
Пример #2
0
def validate_torch_model(output_dir, config, num_layers, dump, val_loader=None, cuda=False):
    """Load a PyTorch model, optionally attach activation-dump hooks, and validate it.

    Args:
        output_dir: Directory passed to ``register_print_hooks``; created if it
            does not exist yet (even when ``dump`` is false).
        config: Model configuration passed to ``load_torch_model``.
        num_layers: Forwarded to ``register_print_hooks`` when ``dump`` is true.
        dump: If true, register print hooks with ``dump_activations=True`` and
            ``data_to_compare=None`` before validation.
        val_loader: Forwarded to ``validate_general`` as its first argument.
            NOTE(review): defaults to None — confirm validate_general accepts that.
        cuda: Forwarded to both ``load_torch_model`` and ``validate_general``.
    """
    from tools.debug.common import load_torch_model, register_print_hooks

    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(output_dir, exist_ok=True)

    model = load_torch_model(config, cuda)
    model_e = model.eval()

    if dump:
        register_print_hooks(output_dir, model_e, num_layers=num_layers, data_to_compare=None, dump_activations=True)

    validate_general(val_loader, model_e, infer_pytorch_model, cuda)