# Imports assumed by the functions below; project-local helpers used here
# (createSplitGenerators, generatorOptions, test_auc_roc, return_pruning_params,
# convert_model_1D_2D, saveModel, exportForCalibration, makeTrainingPlots,
# evaluateOvRPredictions, plot_confusion_matrix, flatten, transformerDefinition,
# catNetwork, catNetworkFlat, TRACK_SHAPE, nTrackCategories) are defined
# elsewhere in the repository.
import os
import tempfile

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow_model_optimization as tfmot
from sklearn.metrics import roc_auc_score


def prune(args, model, sparsity):

    # Assumed project helper: builds the tfmot pruning config (schedule and
    # target sparsity) passed to prune_low_magnitude
    pruning_params = return_pruning_params(sparsity, args.pruning_epochs,
                                           args.sparsity_var,
                                           args.fixed_sparsity)

    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
        model, **pruning_params)
    genTrain, genValidation, genTest = createSplitGenerators(
        args.inputFiles, generatorOptions, shuffle=False, shuffleChunks=False)
    model_for_pruning.compile(
        optimizer=keras.optimizers.Adam(learning_rate=5e-3, amsgrad=True),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["accuracy"])

    # UpdatePruningStep is required by tfmot to apply the pruning schedule at
    # each step; PruningSummaries logs sparsity to TensorBoard
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=tempfile.mkdtemp())
    ]

    history = model_for_pruning.fit(x=genTrain,
                                    validation_data=genValidation,
                                    epochs=args.pruning_epochs,
                                    callbacks=callbacks,
                                    verbose=args.verbose)

    test_auc_roc(model_for_pruning, genTrain,
                 "Pruning {}% Training".format(sparsity))
    test_auc_roc(model_for_pruning, genTest,
                 "Pruning {}% Test".format(sparsity))
def train(args):
    #physical_devices = tf.config.list_physical_devices("GPU")
    #tf.config.set_visible_devices(physical_devices, "GPU")

    # The batch size also controls host RAM usage, so tune it jointly for GPU
    # memory, host RAM and execution speed
    generatorOptions['batchSize'] = args.batchSize

    if args.nEvents: generatorOptions['dataSize'] = args.nEvents

    # Whether to use multiprocessing for parallel data loading
    generatorOptions['useMultiprocessing'] = args.useMultiprocessing

    network = getattr(transformerDefinition, args.network)

    model = network(TRACK_SHAPE, args.numHidden, args.numHeads)
    if args.verbose: model.summary()

    adam = Adam(learning_rate=args.learningRate, amsgrad=True)
    earlyStopping = EarlyStopping(patience=args.patience)

    model.compile(optimizer=adam,
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    genTrain, genValidation, genTest = createSplitGenerators(
        args.inputFiles,
        generatorOptions,
        shuffle=args.shuffle or args.shuffleChunks,
        shuffleChunks=args.shuffleChunks)

    history = model.fit(x=genTrain,
                        validation_data=genValidation,
                        use_multiprocessing=args.useMultiprocessing,
                        workers=args.nWorkers,
                        callbacks=[earlyStopping],
                        epochs=args.epochs,
                        verbose=2)

    # Get the tags for the full training sample, so that these can be used to calculate the ROC
    y_train = genTrain.getTags()
    y_test = genTest.getTags()

    # Can use the generators for prediction too, but need to ensure that there is no shuffling wrt the above
    y_out_train = model.predict(genTrain)
    y_out_test = model.predict(genTest)

    rocAUC_train = roc_auc_score(y_train, y_out_train)
    rocAUC_test = roc_auc_score(y_test, y_out_test)

    print('ROC Train:', rocAUC_train)
    print('ROC Test:', rocAUC_test)

    print(args.modelName)
    model.summary()
    #makeTrainingPlots(model, plotdir = args.outputDir, modelName = args.modelName)
    #makeTrainingPlotsTF2(history, plotdir = args.outputDir, modelName = args.modelName)
    saveModel(model, os.path.join(args.outputDir, args.modelName))
    exportForCalibration(y_test, y_out_test, args.outputDir)
    return model
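
# A hypothetical sketch of what the exportForCalibration helper used above is
# assumed to do (it is defined elsewhere in the repository): persist the
# test-set labels and raw scores so a probability calibration can be fitted
# offline later. The file name is illustrative only.
def exportForCalibration_sketch(y_test, y_out_test, outputDir):
    np.savez(os.path.join(outputDir, 'calibration_inputs.npz'),
             y_true=y_test, y_score=y_out_test)
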
def train_model(args):
    # Look up the sequential model builder by name: args.network = "lstm"
    # would select a function named lstm_seq in this module's globals
    model = globals()["{}_seq".format(args.network)](
        args.hidden_units, args.dropout, args.rnn_dropout)
    model.summary()
    genTrain, genValidation, genTest = createSplitGenerators(
        args.inputFiles, generatorOptions, shuffle = False, shuffleChunks = False)
    model.compile(optimizer = keras.optimizers.Adam(learning_rate = 5e-3, amsgrad = True),
                  loss = keras.losses.BinaryCrossentropy(from_logits = True),
                  metrics = ["accuracy"])
    callbacks = [keras.callbacks.EarlyStopping(patience = 50)]
    history = model.fit(x = genTrain, validation_data = genValidation,
                        epochs = args.training_epochs, callbacks = callbacks,
                        verbose = args.verbose)
    test_auc_roc(model, genTrain, "Training")
    test_auc_roc(model, genTest, "Test")

    return model
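
# A minimal sketch of the assumed project helper test_auc_roc (defined
# elsewhere in the repository): predict on a generator and report the ROC AUC
# against its tags. This assumes the generator is unshuffled, so predictions
# stay aligned with the tags; AUC is rank-based, so raw logits work as scores.
def test_auc_roc_sketch(model, generator, label):
    y_true = generator.getTags()
    y_score = model.predict(generator)
    print('{} ROC AUC: {:.4f}'.format(label, roc_auc_score(y_true, y_score)))
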
def qat(args, model):
    genTrain, genValidation, genTest = createSplitGenerators(
        args.inputFiles, generatorOptions, shuffle=False, shuffleChunks=False)
    # Assumed project helper (defined elsewhere): rebuild the model with
    # 2D-equivalent layers, presumably because tfmot's default quantisation
    # scheme does not support some 1D layer variants
    model_2D = convert_model_1D_2D(model, args)

    qat_model = tfmot.quantization.keras.quantize_model(model_2D)
    qat_model.summary()
    qat_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4, amsgrad=True),
        loss=keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["accuracy"])

    qat_history = qat_model.fit(x=genTrain,
                                validation_data=genValidation,
                                epochs=args.quant_epochs,
                                verbose=args.verbose)

    test_auc_roc(qat_model, genTrain, "Quantised Training")
    test_auc_roc(qat_model, genTest, "Quantised Test")
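
# A follow-up sketch, not part of the original flow: the QAT model still runs
# in float during fit(); converting it through TFLite with default
# optimisations yields the actually-quantised model that the fake-quant nodes
# were trained for.
def export_quantised(qat_model):
    converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    return converter.convert()  # bytes of the quantised .tflite flatbuffer
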
def train(args):

    # The batch size also controls host RAM usage, so tune it jointly for GPU
    # memory, host RAM and execution speed
    generatorOptions['batchSize'] = args.batchSize

    generatorOptions['dataSize'] = args.nEvents
    generatorOptions['useWeights'] = args.useWeights

    generatorOptions['trainingType'] = 'category_flat' if args.flat else 'category'
    generatorOptions['featureName'] = 'featureArrayFlat' if args.flat else 'featureArray'
    generatorOptions['catName'] = 'catArrayFlat' if args.flat else 'catArray'

    if not args.flat:
        model = catNetwork(TRACK_SHAPE, nTrackCategories)
    else:
        model = catNetworkFlat(TRACK_SHAPE[1:], nTrackCategories)

    if args.verbose: model.summary()

    adam = Adam(learning_rate = args.learningRate, amsgrad = True)
    earlyStopping = EarlyStopping(patience = args.patience)

    if not args.flat:
        model.compile(optimizer = adam, loss = 'categorical_crossentropy',
                      metrics=['categorical_accuracy'], sample_weight_mode = "temporal")
    else:
        model.compile(optimizer = adam, loss = 'categorical_crossentropy', metrics=['categorical_accuracy'])

    genTrain, genValidation, genTest = createSplitGenerators(args.inputFiles,
                                                             generatorOptions,
                                                             shuffle = args.shuffle or args.shuffleChunks,
                                                             shuffleChunks = args.shuffleChunks)

    model.fit(x = genTrain,
              validation_data = genValidation,
              callbacks = [earlyStopping],
              epochs = args.epochs, verbose = args.verbose)

    y_train = genTrain.getCats()
    y_test = genTest.getCats()

    # Can use the generators for prediction too, but need to ensure that there is no shuffling wrt the above
    y_out_train = model.predict(genTrain)
    y_out_test = model.predict(genTest)

    y_train = flatten(y_train)
    y_test = flatten(y_test)

    y_out_train = flatten(y_out_train)
    y_out_test = flatten(y_out_test)

    y_train_sparse = np.argmax(y_train, axis = 1)
    y_out_train_sparse = np.argmax(y_out_train, axis = 1)

    evaluateOvRPredictions(y_train, y_out_train, 'Train')
    evaluateOvRPredictions(y_test, y_out_test, 'Test')

    plot_confusion_matrix(y_train_sparse, y_out_train_sparse,
                          np.unique(y_train_sparse), normalize = True)
    plt.savefig('confusion_matrix_' + args.modelName + '.pdf')
    plt.clf()

    makeTrainingPlots(model, accName = 'categorical_accuracy', modelName = args.modelName)
    saveModel(model, args.modelName)
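
# A hypothetical sketch of the flatten helper used above (it is defined
# elsewhere in the repository): collapse the per-event track dimension so that
# (events, tracks, categories) arrays become (events * tracks, categories),
# the shape the sklearn-style metrics above expect.
def flatten_sketch(y):
    y = np.asarray(y)
    return y.reshape(-1, y.shape[-1])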