# MobileNetV2, dataParallel_converter, read_model and counting are assumed to be
# importable from the surrounding MicroNet-style counting project.
def main():
    # model = VGG(depth=16, init_weights=True, cfg=None)
    model = MobileNetV2()
    # dataParallel_converter (defined in the surrounding project) loads the pruned
    # CIFAR-100 MobileNetV2 checkpoint into the model; a sketch of such a helper
    # follows this example.
    model = dataParallel_converter(model, "./cifar100_mobilenetv217_retrained_acc_80.170_config_mobile_v2_0.7_threshold.pt")

    # sanity check: the converted model still exposes its first conv layer
    aa = getattr(model, 'conv1')

    # CIFAR-sized input: a single 3x32x32 image
    input_size = 32
    input_shape = (1, 3, input_size, input_size)

    # enumerate the model's layers/operations and print them
    all_ops = read_model(model, input_shape)
    print('\n'.join(map(str, all_ops)))

    # count adds and multiplies relative to a 32-bit baseline
    counter = counting.MicroNetCounter(all_ops, add_bits_base=32, mul_bits_base=32)

    # bit widths used for the summary: 16-bit parameters and activations,
    # 32-bit accumulators
    INPUT_BITS = 16
    ACCUMULATOR_BITS = 32
    PARAMETER_BITS = 16
    SUMMARIZE_BLOCKS = True

    # print parameter and math-operation summaries at 0% and 50% sparsity
    counter.print_summary(0, PARAMETER_BITS, ACCUMULATOR_BITS, INPUT_BITS, summarize_blocks=SUMMARIZE_BLOCKS)
    counter.print_summary(0.5, PARAMETER_BITS, ACCUMULATOR_BITS, INPUT_BITS, summarize_blocks=SUMMARIZE_BLOCKS)


if __name__ == '__main__':
    main()
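
The example above relies on a project-local helper, dataParallel_converter, which is not shown here. A minimal sketch of what such a helper typically does, assuming the checkpoint was saved from an nn.DataParallel run (the body below is an illustration, not the project's actual implementation):

import torch


def dataParallel_converter(model, checkpoint_path):
    # Hypothetical helper: load a checkpoint saved from nn.DataParallel into a
    # plain model by stripping the 'module.' prefix from parameter names.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint.get('state_dict', checkpoint)
    cleaned = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    model.load_state_dict(cleaned)
    model.eval()
    return model
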
Example #2

import collections
import math

import torch

# model_load, batchify, evaluate, read_model and counting are assumed to be
# provided by the surrounding language-model project.

# minimal stand-in for the training script's argument namespace
Args = collections.namedtuple('Args', ['bptt', 'cuda'])
args = Args(bptt=140, cuda=True)

# load the pretrained WikiText-103 QRNN checkpoint
# (model_load is expected to populate the global `model` used below)
model_load('WT103.12hr.QRNN.pt')

# load test data: read vocab, process test text
corpus = torch.load('corpus-wikitext-103.vocab-only.data')
test_data = corpus.tokenize('data/wikitext-103/test.txt')
test_data = batchify(test_data, 1, args)  # evaluation batch size 1 (see the batchify sketch below)
# run the model over the test set
test_loss = evaluate(test_data, 1)
print('=' * 89)
# perplexity is exp of the average per-token test loss
print('Test ppl {:8.2f}'.format(math.exp(test_loss)))
print('=' * 89)

# read model ops
ops = read_model(model)
# print model MFLOPS and #Parameters
counter = counting.MicroNetCounter(ops, add_bits_base=32, mul_bits_base=32)
INPUT_BITS = 16
ACCUMULATOR_BITS = 32
PARAMETER_BITS = INPUT_BITS
SUMMARIZE_BLOCKS = True
# summary at 0% sparsity
counter.print_summary(0,
                      PARAMETER_BITS,
                      ACCUMULATOR_BITS,
                      INPUT_BITS,
                      summarize_blocks=SUMMARIZE_BLOCKS)
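
batchify above comes from the surrounding language-model code; it is assumed to follow the usual convention of trimming the token stream to a multiple of the batch size and reshaping it into one column per batch element. A minimal sketch under that assumption:

import torch


def batchify(data, bsz, args):
    # trim the token stream so it divides evenly into bsz columns
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    # shape (nbatch, bsz): each column is a contiguous slice of the corpus
    data = data.view(bsz, -1).t().contiguous()
    if args.cuda:
        data = data.cuda()
    return data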