コード例 #1
0
from data_loaders.cifar10 import Cifar10
from models.resnet164_basic import resnet164Basic
from learners.gluon import GluonLearner

# Training entry point: trains a ResNet-164 (basic-block variant) classifier on
# CIFAR-10 using the Gluon-based learner.
# NOTE(review): logging, mx, construct_run_id, configure_root_logger and
# process_args are used below but are not imported in the visible lines —
# confirm they are imported elsewhere in this file.
if __name__ == "__main__":
    # Derive a run id from this script's path and configure the root logger
    # for that run; the script path is logged first for traceability.
    run_id = construct_run_id(__file__)
    configure_root_logger(run_id)
    logging.info(__file__)

    # Read the script arguments (args.seed and args.gpu_idxs are used below)
    # and seed mx's global RNG so runs are repeatable.
    args = process_args()
    mx.random.seed(args.seed)

    batch_size = 128
    # Build train/validation loaders for 3x32x32 inputs with per-channel
    # normalization. padding=4 / padding_value=0 presumably drive a
    # zero-pad-then-crop augmentation — confirm in data_loaders.cifar10.
    train_data, valid_data = Cifar10(
        batch_size=batch_size,
        data_shape=(3, 32, 32),
        padding=4,
        padding_value=0,
        normalization_type="channel").return_dataloaders()

    # Presumably maps epoch index -> learning rate: a low 0.01 warm-up for the
    # first 5 epochs, then 0.1, divided by 10 at epochs 95 and 140.
    # TODO(review): confirm the key semantics against GluonLearner.fit.
    lr_schedule = {0: 0.01, 5: 0.1, 95: 0.01, 140: 0.001}

    model = resnet164Basic(num_classes=10)

    # hybridize=True presumably asks the learner to hybridize (compile) the
    # Gluon model — confirm in learners.gluon.
    learner = GluonLearner(model,
                           run_id,
                           gpu_idxs=args.gpu_idxs,
                           hybridize=True)
    # Train for 185 epochs, evaluating on the validation split.
    # NOTE(review): this call is truncated in the visible source — the
    # remaining arguments and the closing parenthesis are missing.
    learner.fit(train_data=train_data,
                valid_data=valid_data,
                epochs=185,
                lr_schedule=lr_schedule,
コード例 #2
0
from arg_parsing import process_args
from logger import construct_run_id, configure_root_logger
from data_loaders.cifar10 import Cifar10
from models.resnet164_basic import resnet164Basic
from learners.module import ModuleLearner


# Inference entry point: loads a pretrained ResNet-164 (basic-block variant)
# Module checkpoint, fetching it from S3 when not cached locally, and runs
# prediction over the CIFAR-10 test iterator.
if __name__ == "__main__":
    import subprocess

    # Derive a run id from this script's path and configure the root logger
    # for that run; the script path is logged first for traceability.
    run_id = construct_run_id(__file__)
    configure_root_logger(run_id)
    logging.info(__file__)

    # Read the script arguments (args.seed and args.gpu_idxs are used below)
    # and seed mx's global RNG so runs are repeatable.
    args = process_args()
    mx.random.seed(args.seed)

    # Only the second iterator returned (the test data) is needed here; it is
    # batched one sample at a time.
    _, test_data = Cifar10(batch_size=1, data_shape=(3, 32, 32),
                           normalization_type="channel").return_dataiters()

    # Download the model symbol and params into ../logs/checkpoints/ (relative
    # to this script) unless they are already there. The folder is the same
    # for every file, so it is computed once.
    checkpoint_folder = os.path.realpath(os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "../logs/checkpoints/"))
    os.makedirs(checkpoint_folder, exist_ok=True)
    for filename in ["resnet164_basic_module-0000.params",
                     "resnet164_basic_module-symbol.json"]:
        filepath = os.path.join(checkpoint_folder, filename)
        if os.path.exists(filepath):
            continue
        # Log before starting the (possibly slow) download, not after it.
        logging.info("Downloading {} to {}".format(filename, checkpoint_folder))
        # Pass the full destination file path: given a bare folder path that
        # is not an existing directory, `aws s3 cp` writes the object to a
        # file named after the folder. An argument list (no shell) keeps paths
        # with spaces or metacharacters intact, and check=True raises
        # CalledProcessError rather than continuing without the weights.
        subprocess.run(
            ["aws", "s3", "cp",
             "s3://benchmark-ai-models/{}".format(filename), filepath],
            check=True)

    model = resnet164Basic(num_classes=10)
    learner = ModuleLearner(model, run_id, gpu_idxs=args.gpu_idxs)
    # Presumably resolves "<prefix>-symbol.json" / "<prefix>-0000.params" in the
    # checkpoints folder populated above — confirm in learners.module.
    learner.load(prefix="resnet164_basic_module", data_iter=test_data)
    learner.predict(test_data=test_data, log_frequency=100)