Example #1
    dtype = 'float32'
    mult = 1.0  # learning-rate multiplier applied to each step of the schedule

    # Replace the 'step' keyword with an explicit epoch -> learning-rate map.
    if args.lr_schedule == 'step':
        args.lr_schedule = {0: 0.01 * mult, 5: 0.1 * mult, 95: 0.01 * mult, 105: 0.001 * mult}

    
    if args.model == 'wrn':
        model = wrn(num_classes=10)
    elif args.model == 'resnet18':
        model = resnet18Basic(num_classes=10)
    else:
        logging.error("Model not currently supported.")
        sys.exit(1)
        
    learner = GluonLearner(model, hybridize=False, ctx=[mx.gpu(0)])
    # batch_size, kvstore, train_data and valid_data are defined earlier in the
    # (truncated) script.
    learner.fit(train_data=train_data,
                valid_data=valid_data,
                epochs=args.epochs,
                lr_schedule=args.lr_schedule,
                initializer=mx.init.Xavier(rnd_type='gaussian', factor_type='out', magnitude=2),
                optimizer=mx.optimizer.NAG(learning_rate=0.1, rescale_grad=1.0/batch_size, momentum=0.9, wd=0.0005),
                early_stopping_criteria=lambda e: e >= 0.94,  # DAWNBench CIFAR-10 criterion: stop at 94% accuracy
                kvstore=kvstore,
                dtype=dtype)

    
#     # Optional: run inference on the test set with the trained model.
#     _, test_data = Cifar10(batch_size=1, data_shape=(3, 32, 32),
#                            normalization_type="channel").return_dataloaders()
#     learner.predict(test_data=test_data, log_frequency=100)
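
How the early_stopping_criteria callback is consumed is not shown above. A minimal sketch of the presumed contract (an assumption: GluonLearner.fit calls it once per epoch with the current validation accuracy and stops once it returns True; should_stop is a hypothetical name, not part of the library):

def should_stop(criteria, valid_accuracy):
    # Hypothetical equivalent of the per-epoch check inside GluonLearner.fit.
    return criteria is not None and criteria(valid_accuracy)

assert should_stop(lambda e: e >= 0.94, 0.95)      # target reached: stop
assert not should_stop(lambda e: e >= 0.94, 0.90)  # keep training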
Example #2
    mx.random.seed(args.seed)

    batch_size = 128
    train_data, valid_data = Cifar10(
        batch_size=batch_size,
        data_shape=(3, 32, 32),
        padding=4,
        padding_value=0,
        normalization_type="channel").return_dataloaders()

    # Epoch-keyed step schedule: warm up at 0.01, train at 0.1, then decay.
    lr_schedule = {0: 0.01, 5: 0.1, 95: 0.01, 140: 0.001}

    model = resnet164Basic(num_classes=10)

    # run_id and args come from the setup code earlier in the (truncated) script.
    learner = GluonLearner(model,
                           run_id,
                           gpu_idxs=args.gpu_idxs,
                           hybridize=True)
    learner.fit(train_data=train_data,
                valid_data=valid_data,
                epochs=185,
                lr_schedule=lr_schedule,
                initializer=mx.init.Xavier(rnd_type='gaussian',
                                           factor_type='out',
                                           magnitude=2),
                optimizer=mx.optimizer.SGD(learning_rate=lr_schedule[0],
                                           rescale_grad=1.0 / batch_size,
                                           momentum=0.9,
                                           wd=0.0005),
                early_stopping_criteria=lambda e: e >= 0.94
                )  # DAWNBench CIFAR-10 criterion: stop at 94% accuracy
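
The epoch-keyed dict reads as a step schedule: each rate holds from its key epoch until the next boundary. A minimal sketch of that interpretation (an assumption about how GluonLearner consumes the dict; lr_for_epoch is a hypothetical helper, not part of the library):

def lr_for_epoch(schedule, epoch):
    # Use the rate attached to the largest boundary <= epoch.
    return schedule[max(k for k in schedule if k <= epoch)]

schedule = {0: 0.01, 5: 0.1, 95: 0.01, 140: 0.001}
assert lr_for_epoch(schedule, 0) == 0.01     # warm-up, epochs 0-4
assert lr_for_epoch(schedule, 50) == 0.1     # main phase, epochs 5-94
assert lr_for_epoch(schedule, 150) == 0.001  # final decay from epoch 140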
Example #3
if __name__ == "__main__":
    run_id = construct_run_id(__file__)
    configure_root_logger(run_id)
    logging.info(__file__)

    args = process_args()
    mx.random.seed(args.seed)

    _, test_data = Cifar10(batch_size=1,
                           data_shape=(3, 32, 32),
                           normalization_type="channel").return_dataloaders()

    # download the model parameters from S3 (if not already present locally)
    filename = "resnet164_basic_gluon.params"
    folder = os.path.realpath(
        os.path.join(os.path.dirname(os.path.realpath(__file__)),
                     "../logs/checkpoints/"))
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):
        logging.info("Downloading {} to {}".format(filename, folder))
        os.system("aws s3 cp s3://benchmark-ai-models/{} {}".format(
            filename, folder))

    model = resnet164Basic(num_classes=10)
    learner = GluonLearner(model,
                           run_id,
                           gpu_idxs=args.gpu_idxs,
                           hybridize=False)
    learner.load(filename=filename)
    learner.predict(test_data=test_data, log_frequency=100)
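
learner.load and learner.predict are wrappers from this repo; a minimal sketch of what they presumably reduce to with the stock Gluon API, reusing model, filepath and test_data from the example above (assumes the checkpoint was written by save_parameters):

ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
model.load_parameters(filepath, ctx=ctx)  # stock Gluon checkpoint loading
for i, (data, label) in enumerate(test_data):
    pred = model(data.as_in_context(ctx)).argmax(axis=1)
    if i % 100 == 0:  # mirrors log_frequency=100
        logging.info("sample {}: predicted class {}".format(i, int(pred.asscalar())))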