class TestModel(unittest.TestCase):
    """Check the output shape of each residual attention model variant.

    Fixtures are created in ``setUp`` rather than in an overridden
    ``__init__``: unittest instantiates test cases while *loading* tests, so
    building networks in the constructor would run at collection time and
    turn any construction failure into a loader error instead of a test
    error.
    """

    def setUp(self):
        """Build the random inputs and the initialized networks under test."""
        # 32x32 inputs, fed to the ``ResidualAttentionModel_32input`` variants.
        self.cifar_data = nd.random.normal(shape=(1, 3, 32, 32))
        self.cifar_att56 = ResidualAttentionModel_32input()
        self.cifar_att56.initialize()
        self.cifar_att92 = ResidualAttentionModel_32input(
            additional_stage=True)
        self.cifar_att92.initialize()

        # 224x224 inputs, fed to the ``ResidualAttentionModel`` variants.
        self.imgnet_data = nd.random.normal(shape=(1, 3, 224, 224))
        self.att56 = ResidualAttentionModel()
        self.att56.initialize()
        self.att92 = ResidualAttentionModel(additional_stage=True)
        self.att92.initialize()

    def test_model(self):
        """Each network maps a single-sample batch to the expected shape."""
        cases = (
            ('cifar_att56', self.cifar_att56, self.cifar_data, (1, 10)),
            ('cifar_att92', self.cifar_att92, self.cifar_data, (1, 10)),
            ('att56', self.att56, self.imgnet_data, (1, 1000)),
            ('att92', self.att92, self.imgnet_data, (1, 1000)),
        )
        for label, model, data, expected_shape in cases:
            # subTest keeps checking the remaining variants after a failure
            # and reports which variant failed.
            with self.subTest(model=label):
                self.assertEqual(expected_shape, model(data).shape)
if __name__ == '__main__':
    # ---- Hyper-parameters -------------------------------------------------
    batch_size = 64
    iterations = 530e3
    lr = 0.1
    lr_decay = 0.1
    wd = 1e-4
    # Iteration counts at 30%, 60% and 90% of the total iteration budget.
    lr_period = tuple(iterations * fraction for fraction in (0.3, 0.6, 0.9))
    # NOTE(review): the consumer of cat_interval is not visible in this
    # section -- confirm its meaning against the code that follows.
    cat_interval = 10e3

    # ---- Devices ----------------------------------------------------------
    num_workers = 12
    num_gpus = 2
    ctx = [mx.gpu(device_id) for device_id in range(num_gpus)]

    # ---- Network ----------------------------------------------------------
    net = ResidualAttentionModel()
    net.hybridize(static_alloc=True)
    net.initialize(init=mx.init.MSRAPrelu(), ctx=ctx)

    # ---- Optimizer --------------------------------------------------------
    optimizer_params = {
        'learning_rate': lr,
        'momentum': 0.9,
        'wd': wd
    }
    trainer = gluon.Trainer(net.collect_params(), 'nag', optimizer_params)

    # ---- Training data ----------------------------------------------------
    train_dataset = ImageFolderDataset(
        '/system1/Dataset/ImageNet/ILSVRC2012_img_train',
        transform=transformer)
    train_data = gluon.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        last_batch='discard')