# Train the model selected by args.model with a GluonLearner on one GPU.
# NOTE(review): args, batch_size, kvstore, train_data and valid_data are
# expected to be defined earlier in the script.

# Value forwarded to learner.fit() as its `dtype` argument.
dtype = 'float32'
# Common multiplier applied to every learning rate of the 'step' schedule.
mult = 1.
if args.lr_schedule == 'step':
    # Swap the 'step' keyword for an explicit {key: learning rate} mapping.
    # Keys are presumably epoch indices -- confirm against GluonLearner.fit.
    args.lr_schedule = {0: 0.01*mult, 5: 0.1*mult, 95: 0.01*mult, 105: 0.001*mult}

if args.model == 'wrn':
    model = wrn(num_classes=10)
elif args.model == 'resnet18':
    model = resnet18Basic(num_classes=10)
else:
    logging.error("Model not currently supported.")
    # Exit with a non-zero status so shells and schedulers see the failure
    # (the previous exit code of 0 reported success).
    sys.exit(1)

learner = GluonLearner(model, hybridize=False, ctx=[mx.gpu(0)])
learner.fit(
    train_data=train_data,
    valid_data=valid_data,
    epochs=args.epochs,
    lr_schedule=args.lr_schedule,
    initializer=mx.init.Xavier(rnd_type='gaussian', factor_type='out', magnitude=2),
    optimizer=mx.optimizer.NAG(learning_rate=0.1, rescale_grad=1.0/batch_size,
                               momentum=0.9, wd=0.0005),
    # Stop early once the value passed to the callback reaches 0.94.
    early_stopping_criteria=lambda e: e >= 0.94,
    kvstore=kvstore,
    dtype=dtype,
)

# Disabled prediction step, kept for reference:
# _, test_data = Cifar10(batch_size=1, data_shape=(3, 32, 32),
#                        normalization_type="channel").return_dataloaders()
# learner.predict(test_data=test_data, log_frequency=100)
# Train resnet164Basic on CIFAR-10 with a fixed learning-rate schedule.
# NOTE(review): args and run_id are expected to be defined earlier in the script.
mx.random.seed(args.seed)

batch_size = 128
cifar10 = Cifar10(
    batch_size=batch_size,
    data_shape=(3, 32, 32),
    padding=4,
    padding_value=0,
    normalization_type="channel",
)
train_data, valid_data = cifar10.return_dataloaders()

# Mapping of schedule key -> learning rate; the entry at key 0 also seeds the
# optimizer's initial learning rate below.
lr_schedule = {0: 0.01, 5: 0.1, 95: 0.01, 140: 0.001}

model = resnet164Basic(num_classes=10)
learner = GluonLearner(model, run_id, gpu_idxs=args.gpu_idxs, hybridize=True)

xavier_init = mx.init.Xavier(rnd_type='gaussian', factor_type='out', magnitude=2)
sgd = mx.optimizer.SGD(
    learning_rate=lr_schedule[0],
    rescale_grad=1.0 / batch_size,
    momentum=0.9,
    wd=0.0005,
)


def reached_target(e):
    """Return True once the monitored value meets the 0.94 threshold."""
    return e >= 0.94  # DAWNBench CIFAR-10 criteria


learner.fit(
    train_data=train_data,
    valid_data=valid_data,
    epochs=185,
    lr_schedule=lr_schedule,
    initializer=xavier_init,
    optimizer=sgd,
    early_stopping_criteria=reached_target,
)
if __name__ == "__main__":
    # Run prediction with a saved resnet164Basic checkpoint, fetching the
    # parameter file from S3 first when it is not present locally.

    # Local import: only this entry point needs subprocess.
    import subprocess

    run_id = construct_run_id(__file__)
    configure_root_logger(run_id)
    logging.info(__file__)

    args = process_args()
    mx.random.seed(args.seed)

    # Only the second dataloader returned by Cifar10 is used here.
    _, test_data = Cifar10(batch_size=1, data_shape=(3, 32, 32),
                           normalization_type="channel").return_dataloaders()

    # download model symbol and params (if doesn't already exist)
    filename = "resnet164_basic_gluon.params"
    folder = os.path.realpath(
        os.path.join(os.path.dirname(os.path.realpath(__file__)),
                     "../logs/checkpoints/"))
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):
        # Announce the download before it starts, not after it has finished.
        logging.info("Downloading {} to {}".format(filename, folder))
        # Argument list instead of a shell string, so a checkout path with
        # spaces or shell metacharacters is passed through intact.
        command = ["aws", "s3", "cp",
                   "s3://benchmark-ai-models/{}".format(filename), folder]
        try:
            status = subprocess.call(command)
        except OSError as err:
            # Raised when the aws executable cannot be started at all.
            logging.error("Could not run the aws CLI: %s", err)
        else:
            if status != 0:
                # Surface the failure rather than ignoring the exit status;
                # execution continues as before (best-effort download).
                logging.error("Download of %s failed with exit status %d",
                              filename, status)

    model = resnet164Basic(num_classes=10)
    learner = GluonLearner(model, run_id, gpu_idxs=args.gpu_idxs, hybridize=False)
    # Reuse the filename defined above instead of repeating the literal.
    learner.load(filename=filename)
    learner.predict(test_data=test_data, log_frequency=100)