def main():
    """Train/evaluate a ResNet model and print its test-set accuracy.

    Steps:
      1. Print the experiment parameters via ``params.print_params()``.
      2. Parse the boolean command-line flags ``--local_test`` and ``--aug``
         (both parsed with ``str2bool``, default ``False``).
      3. Load the train/validation/test splits from ``params.data_folder``
         and print the shape of each split.
      4. Build the model with ``ResNet(...)``, predict on the test split and
         print the accuracy computed by ``evaluate``.
    """
    params.print_params()

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_test',
                        type=str2bool,
                        default=False,
                        help='local test verbose')
    # NOTE(review): the flag name suggests it toggles data augmentation;
    # confirm how ResNet consumes args.aug.
    parser.add_argument('--aug',
                        type=str2bool,
                        default=False,
                        help='whether to apply data augmentation')
    args = parser.parse_args()

    X_train, Y_train, X_val, Y_val, X_test, Y_test = loadData(
        params.data_folder)
    print(X_train.shape, Y_train.shape)
    print(X_val.shape, Y_val.shape)
    print(X_test.shape, Y_test.shape)

    # NOTE(review): the test split is passed to ResNet together with the
    # training data, while X_val/Y_val are only printed. If ResNet uses its
    # 3rd/4th arguments for validation or early stopping, the accuracy
    # printed below is optimistically biased -- confirm against ResNet's
    # signature and consider passing X_val/Y_val instead.
    # The second return value (previously bound to `avg_layer`) is unused.
    model, _ = ResNet(X_train, Y_train, X_test, Y_test, args)
    pred = model.predict(X_test)
    acc = evaluate(pred, Y_test)
    print('\nAccuracy: {:.4f}'.format(acc))
        else:
            from VGG_predict import make_predictions
            from models.VGG_16 import VGG_16_test, VGG_19_test
            
            if args.version == '16':
                make_predictions(args.img_dim, VGG_16_test)
            elif args.version =='19':
                make_predictions(args.img_dim, VGG_19_test)
            else:
                sys.exit('cannot find model you have specified')
    
    elif args.model == 'ResNet':
        import models.ResNet as ResNet

        train_data, train_labels, valid_data, valid_labels, test_data, test_ids = preprocess.get_roof_data(augmented=True, shape=(64, 64))

        if args.train:

            ResNet.train(train_data, train_labels, valid_data, valid_labels, dropout=0.62, num_blocks=3, lr=0.007, weight_decay=0.004)

        else:
            model_vargs = dict(dropout=0.62, num_blocks=3)
            fn = 'results/best.model'
            valid_predictions = ResNet.predict(fn, model_vargs, valid_data)
            test_predictions = ResNet.predict(fn, model_vargs, test_data)
            make_prediction_file.make_prediction_file(test_ids, test_predictions, 'Resnet805_64_64', valid_labels=valid_labels, valid_predictions=valid_predictions)