예제 #1
0
 def test_get_model_flops(self):
     """Check the FLOP count reported for _CoverageNetwork on one 1x3x28x28 input."""
     reference_count = 4786051
     # Per the original note, the upsampling in this network raises a
     # UserWarning; silence it only while building and measuring the model.
     with _warnings.catch_warnings():
         _warnings.filterwarnings("ignore", category=UserWarning)
         net = _CoverageNetwork()
         dummy_input = _torch.rand((1, 3, 28, 28))
         measured = _tu.get_model_flops(net, dummy_input)
     self.assertAlmostEqual(measured, reference_count)
예제 #2
0
# Build the captioning model from the parsed options, then report its FLOP
# count and parameter count measured on random stand-in inputs.
vocab = infos['vocab']  # ix -> word mapping

opt.vocab = vocab
model = models.setup(opt)

# Random stand-in inputs for a batch of 10. The shapes are assumptions about
# what the model's forward pass expects -- NOTE(review): confirm against the
# model definition.
cocotest_bu_fc = torch.randn(10, 2048)
# NOTE(review): zero-sized attention features -- confirm this is intentional
# (e.g. an fc-only model) rather than a placeholder.
cocotest_bu_att = torch.randn(10, 0, 0)
labels = torch.randint(5200, (10, 5, 18))  # token ids sampled from [0, 5200)
# NOTE(review): randint(1, ...) samples from [0, 1), so every mask entry is 0.
# If a non-trivial mask is intended, use torch.ones(...) or torch.randint(2, ...).
masks = torch.randint(1, (10, 5, 18))

# eval() is exactly train(False); a single call puts any mode-dependent
# layers (e.g. dropout) into inference behaviour before counting.
model.eval()

# Calculate model FLOPs; the value is scaled by 10**9, so it is printed as GFLOPs.
total_flops = tu.get_model_flops(model, cocotest_bu_fc, cocotest_bu_att, labels, masks)
print('Total model GFLOPs: {:,}'.format(total_flops / pow(10, 9)))

# Calculate total model parameters; scaled by 10**6, so printed in millions.
total_params = tu.get_model_param_count(model)
print('Total model params (millions): {:,}'.format(total_params / pow(10, 6)))





예제 #3
0
 def test_get_model_flops_rnn(self):
     """Check the FLOP count reported for the RNN-mode _SequenceNetwork."""
     net = _SequenceNetwork(mode='rnn')
     expected_flops = 12479200
     # A single 200-token sequence with ids drawn from [1, 10), passed as a
     # one-element list together with a matching list of lengths.
     tokens = _torch.randint(1, 10, (200, )).long()
     lengths = [len(tokens)]
     measured = _tu.get_model_flops(net, [tokens], lengths)
     self.assertAlmostEqual(measured, expected_flops)