def main():
    """Load the MobileNetV2 checkpoint and print its operation-count summaries.

    Steps:
      1. Build a ``MobileNetV2`` and pass it, together with the checkpoint
         path, through ``dataParallel_converter``.
      2. Extract the list of operations with ``read_model`` for an input of
         shape ``(1, 3, 32, 32)`` and print one operation per line.
      3. Feed the operations to ``counting.MicroNetCounter`` and print its
         summary twice, with first argument 0 and then 0.5.
    """
    # Alternative model kept from the original script for reference:
    #   VGG(depth=16, init_weights=True, cfg=None)
    model = MobileNetV2()
    model = dataParallel_converter(
        model,
        "./cifar100_mobilenetv217_retrained_acc_80.170_config_mobile_v2_0.7_threshold.pt",
    )

    # presumably (batch, channels, height, width) -- confirm against read_model
    input_size = 32
    input_shape = (1, 3, input_size, input_size)

    all_ops = read_model(model, input_shape)
    print('\n'.join(map(str, all_ops)))

    counter = counting.MicroNetCounter(all_ops, add_bits_base=32, mul_bits_base=32)

    # Bit widths handed to the counter's summary.
    INPUT_BITS = 16
    ACCUMULATOR_BITS = 32
    PARAMETER_BITS = 16
    SUMMARIZE_BLOCKS = True

    # NOTE(review): the first positional argument of print_summary is assumed
    # to be the sparsity level -- confirm against counting.MicroNetCounter.
    for sparsity in (0, 0.5):
        counter.print_summary(
            sparsity,
            PARAMETER_BITS,
            ACCUMULATOR_BITS,
            INPUT_BITS,
            summarize_blocks=SUMMARIZE_BLOCKS,
        )
# Stand-in for parsed command-line options; only `bptt` and `cuda` are set.
_ArgsTuple = collections.namedtuple('Args', ['bptt', 'cuda'])
args = _ArgsTuple(bptt=140, cuda=True)

# Load the trained model.
# NOTE(review): model_load presumably binds the module-level `model` that is
# read further down -- confirm in its definition.
model_load('WT103.12hr.QRNN.pt')

# Load test data: read the vocab-only corpus object, tokenize the test text
# with it, then batchify with a batch size of 1.
corpus = torch.load('corpus-wikitext-103.vocab-only.data')
_test_tokens = corpus.tokenize('data/wikitext-103/test.txt')
test_data = batchify(_test_tokens, 1, args)

# Run on the test data and report perplexity (exp of the returned loss)
# between two separator lines.
test_loss = evaluate(test_data, 1)
_separator = '=' * 89
print(_separator)
print('Test ppl {:8.2f} '.format(math.exp(test_loss)))
print(_separator)

# Bit widths used for the counter summary; parameters use the input width.
INPUT_BITS = 16
ACCUMULATOR_BITS = 32
PARAMETER_BITS = INPUT_BITS
SUMMARIZE_BLOCKS = True

# Read the model's operations and print the counter summary
# (per the original comment: model MFLOPS and number of parameters).
ops = read_model(model)
counter = counting.MicroNetCounter(ops, add_bits_base=32, mul_bits_base=32)
counter.print_summary(
    0,
    PARAMETER_BITS,
    ACCUMULATOR_BITS,
    INPUT_BITS,
    summarize_blocks=SUMMARIZE_BLOCKS,
)