コード例 #1
0
ファイル: run_pt.py プロジェクト: slyubomirsky/relay-bench-1
    return (net, image_shape)


def cnn_setup(network, dev, batch_size):
    """Build a PyTorch CNN and a random input batch for one benchmark config.

    Args:
        network: network identifier, forwarded to ``instantiate_network``.
        dev: requested device. ``'gpu'`` selects CUDA when
            ``torch.cuda.is_available()`` is true; any other value, or a GPU
            request on a machine without CUDA, falls back to CPU.
        batch_size: forwarded to ``instantiate_network``.

    Returns:
        A two-element list ``[model, input_tensor]``, both placed on the
        selected device, in the argument order ``cnn_trial(target, input)``
        expects.
    """
    net, image_shape = instantiate_network(network, batch_size, dev)
    use_cuda = dev == 'gpu' and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    model = net.to(device)
    # Random float32 data of the shape reported by instantiate_network
    # (presumably (batch, channels, height, width) -- TODO confirm).
    input_array = np.random.randn(*image_shape).astype(np.float32)
    # torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4;
    # a plain tensor (requires_grad=False) behaves identically. The name
    # `input_tensor` also avoids shadowing the builtin `input`.
    input_tensor = torch.from_numpy(input_array).to(device)
    return [model, input_tensor]


def cnn_trial(target, input):
    """Execute a single trial: apply ``target`` to ``input``.

    Returns whatever ``target(input)`` produces (for a model, its output).
    """
    model_output = target(input)
    return model_output


def cnn_teardown(target, input):
    """Per-trial teardown hook; a no-op for this benchmark."""
    return None


if __name__ == '__main__':
    # Early-exit predicate keyed on the 'pt' framework (semantics defined by
    # common_early_exit -- presumably skips when PyTorch is not configured).
    early_exit_check = common_early_exit({'frameworks': 'pt'})
    # Trial arguments paired with the config keys they are presumably drawn
    # from.
    trial_params = common_trial_params(
        'pt', 'cnn_comp',
        cnn_trial, cnn_setup, cnn_teardown,
        ['network', 'device', 'batch_size'],
        ['networks', 'devices', 'batch_sizes'])
    run_template(validate_config=validate,
                 check_early_exit=early_exit_check,
                 gen_trial_params=trial_params)
コード例 #2
0
from validate_config import validate
from exp_templates import common_trial_params, run_template
from relay_util import cnn_setup, cnn_trial, cnn_teardown

if __name__ == '__main__':
    # Trial argument names and the config keys they are presumably swept
    # over; opt_level lets Relay optimization levels be compared.
    arg_names = ['network', 'device', 'batch_size', 'opt_level']
    config_keys = ['networks', 'devices', 'batch_sizes', 'opt_levels']
    run_template(
        validate_config=validate,
        gen_trial_params=common_trial_params(
            'relay', 'opt_comparison',
            cnn_trial, cnn_setup, cnn_teardown,
            arg_names, config_keys))
コード例 #3
0
from validate_config import validate
from exp_templates import (common_trial_params, common_early_exit,
                           run_template)

from language_data import N_LETTERS
from pt_rnn import RNN, samples


def rnn_setup(device, hidden_size, lang, letters):
    """Construct a char-RNN and wrap one sampling call in a zero-arg thunk.

    Args:
        device: accepted for interface compatibility but not referenced here.
            NOTE(review): the model is never moved to a device in this
            function -- confirm whether GPU runs are intended.
        hidden_size: hidden-layer width passed to ``RNN``.
        lang: forwarded to ``samples``.
        letters: forwarded to ``samples``.

    Returns:
        A one-element list holding a callable that returns
        ``samples(model, lang, letters)`` each time it is invoked.
    """
    model = RNN(N_LETTERS, hidden_size, N_LETTERS)

    def sample_once():
        return samples(model, lang, letters)

    return [sample_once]


def rnn_trial(thunk):
    """Run one trial by invoking ``thunk`` and hand back its result."""
    result = thunk()
    return result


def rnn_teardown(thunk):
    """Per-trial teardown hook; a no-op for this benchmark."""
    return None


if __name__ == '__main__':
    # Early-exit predicate keyed on the 'pt' framework.
    framework_check = common_early_exit({'frameworks': 'pt'})
    # Trial argument names (passed positionally to rnn_setup, presumably)
    # and the config keys they are presumably drawn from.
    sweep = common_trial_params(
        'pt', 'char_rnn',
        rnn_trial, rnn_setup, rnn_teardown,
        ['device', 'hidden_size', 'language', 'input'],
        ['devices', 'hidden_sizes', 'languages', 'inputs'])
    run_template(validate_config=validate,
                 check_early_exit=framework_check,
                 gen_trial_params=sweep)
コード例 #4
0
    net, num_states, shapes = import_gluon_rnn(network)
    net.initialize(ctx=context)
    net.hybridize()

    shape_list = [shapes['data']
                  ] + [shapes['state%s' % i] for i in range(num_states)]
    mx_inputs = [
        mx.nd.array(np.random.rand(*shape).astype('float32'), ctx=context)
        for shape in shape_list
    ]
    return [lambda: net(*mx_inputs)[0].asnumpy()]


def rnn_trial(thunk):
    """Run one trial by calling ``thunk``.

    The thunk's return value is discarded; this function always returns
    ``None``.
    """
    thunk()
    return None


def rnn_teardown(thunk):
    """Per-trial teardown hook; a no-op for this benchmark."""
    return None


if __name__ == '__main__':
    # Early-exit predicate keyed on the 'mxnet' framework.
    mxnet_check = common_early_exit({'frameworks': 'mxnet'})
    # Sweep over (device, network); config keys are presumably the plural
    # forms listed second.
    gluon_params = common_trial_params(
        'mxnet', 'gluon_rnns',
        rnn_trial, rnn_setup, rnn_teardown,
        ['device', 'network'],
        ['devices', 'networks'])
    run_template(validate_config=validate,
                 check_early_exit=mxnet_check,
                 gen_trial_params=gluon_params)
コード例 #5
0
    return cnn_setup(network,
                     dev,
                     batch_size,
                     opt_level,
                     use_passes=True,
                     passes=pass_list)


def gen_trial_params(config):
    """Build the positional parameter list for this pass-comparison experiment.

    Each entry of ``config['passes']`` is indexed as ``spec[0]`` (the opt
    level, stringified with ``str``) and ``spec[1]`` (an iterable of pass-name
    strings). Trial parameters must be writable to CSV and consumable by
    ``passes_setup``, so every spec is flattened into one string of the form
    ``"<opt_level>;<pass_a>|<pass_b>|..."``.

    Returns:
        ``['relay', 'pass_comparison', dry_run, n_times_per_input, n_inputs,
        cnn_trial, passes_setup, cnn_teardown, arg_names, arg_value_lists]``
        where the last two pair each trial argument name with the list of
        values it takes.
    """
    serialized_passes = []
    for pass_spec in config['passes']:
        opt_level_text = str(pass_spec[0])
        pass_list_text = '|'.join(pass_spec[1])
        serialized_passes.append(opt_level_text + ';' + pass_list_text)

    return [
        'relay', 'pass_comparison',
        config['dry_run'], config['n_times_per_input'], config['n_inputs'],
        cnn_trial, passes_setup, cnn_teardown,
        ['network', 'device', 'batch_size', 'pass_spec'],
        [config['networks'], config['devices'], config['batch_sizes'],
         serialized_passes],
    ]


if __name__ == '__main__':
    # This experiment builds its own trial parameters via gen_trial_params
    # rather than common_trial_params.
    run_template(
        validate_config=validate,
        gen_trial_params=gen_trial_params,
    )