# Example #1
0
def trainDROModel(dro_type, epochs, steps_adv, budget, activation, batch_size, loss_criterion, cost_function=None):
    """
    Train an ``MNISTClassifier`` with a DRO method and save its weights.

    Supported values of ``dro_type``:
        - 'PGD': PGD-based training via ``ProjetcedDRO``.
        - 'Lag': Lagrangian relaxation based method developed by Sinha et al.
          (also called WRM), via ``LagrangianDRO``; needs ``cost_function``.
        - 'FW': the Frank-Wolfe method based approach developed by Staib et
          al., via ``FrankWolfeDRO`` with ``p=2, q=2``.

    Args:
        dro_type: One of 'PGD', 'Lag' or 'FW'.
        epochs: Forwarded to ``train_module.train``.
        steps_adv: Forwarded to ``train_module.train``.
        budget: Forwarded to ``train_module.train``; also written into the
            saved file name as ``epsilon=<budget>``.
        activation: Passed to ``MNISTClassifier``; also written into the
            saved file name.
        batch_size: Forwarded to ``train_module.train``.
        loss_criterion: Loss passed to the DRO training module.
        cost_function: Passed to ``LagrangianDRO``; required when
            ``dro_type == 'Lag'``, ignored otherwise.

    Raises:
        ValueError: If ``dro_type`` is not one of the supported values, or if
            ``dro_type == 'Lag'`` and ``cost_function`` is None.

    Side effects:
        Writes ``model.state_dict()`` to
        ``./DRO_models/{dro_type}_DRO_activation={activation}_epsilon={budget}.pt``
        (creating ``./DRO_models/`` if needed) and prints the saved path.
    """
    import os

    model = MNISTClassifier(activation=activation)
    if dro_type == 'PGD':
        train_module = ProjetcedDRO(model, loss_criterion)
    elif dro_type == 'Lag':
        # Explicit check rather than ``assert`` so the validation survives
        # ``python -O``.
        if cost_function is None:
            raise ValueError(
                "cost_function is required when dro_type is 'Lag'.")
        train_module = LagrangianDRO(model, loss_criterion, cost_function)
    elif dro_type == 'FW':
        train_module = FrankWolfeDRO(model, loss_criterion, p=2, q=2)
    else:
        raise ValueError("The type of DRO is not valid.")

    train_module.train(budget=budget, batch_size=batch_size,
                       epochs=epochs, steps_adv=steps_adv)

    folderpath = "./DRO_models/"
    # Make sure the output directory exists; otherwise torch.save would fail
    # only after the (potentially long) training run has completed.
    os.makedirs(folderpath, exist_ok=True)
    filename = "{}_DRO_activation={}_epsilon={}.pt".format(
        dro_type, activation, budget)
    # folderpath ends with a separator, so join() yields the same string as
    # plain concatenation.
    filepath = os.path.join(folderpath, filename)
    # NOTE(review): this saves ``model`` itself, which assumes the training
    # module updated it in place during ``train`` -- confirm.
    torch.save(model.state_dict(), filepath)
    print("A neural network adversarially trained using {} is now saved at {}.".format(
        dro_type, filepath))
# Example #2
0
def trainModelLoss(dro_type,
                   epochs,
                   steps_adv,
                   budget,
                   activation,
                   batch_size,
                   loss_criterion,
                   cost_function=None):
    """
    Train a neural network with a specified loss function and save its weights.

    An ``MNISTClassifier`` is trained with the DRO method selected by
    ``dro_type`` ('PGD' -> ``ProjetcedDRO``, 'Lag' -> ``LagrangianDRO``,
    'FW' -> ``FrankWolfeDRO`` with ``p=2, q=2``) using ``loss_criterion``.

    Args:
        dro_type: One of 'PGD', 'Lag' or 'FW'.
        epochs: Forwarded to ``train_module.train``.
        steps_adv: Forwarded to ``train_module.train``.
        budget: Forwarded to ``train_module.train``; also written into the
            saved file name as ``epsilon=<budget>``.
        activation: Passed to ``MNISTClassifier``; also written into the
            saved file name.
        batch_size: Forwarded to ``train_module.train``.
        loss_criterion: Loss passed to the DRO training module. Its
            ``__name__`` (or, if it has none, its class name) is written into
            the saved file name.
        cost_function: Passed to ``LagrangianDRO``; required when
            ``dro_type == 'Lag'``, ignored otherwise.

    Raises:
        ValueError: If ``dro_type`` is not one of the supported values, or if
            ``dro_type == 'Lag'`` and ``cost_function`` is None.

    Side effects:
        Writes ``model.state_dict()`` to
        ``./Loss_models/{dro_type}_DRO_activation={activation}_epsilon={budget}_loss={loss_name}.pt``
        (creating ``./Loss_models/`` if needed) and prints the saved path.
    """
    import os

    # Resolve the loss name before training: plain functions expose
    # ``__name__``, but callable objects (e.g. loss-module instances or
    # functools.partial) may not, and failing here would otherwise happen
    # only after the whole training run.
    loss_name = getattr(loss_criterion, "__name__",
                        type(loss_criterion).__name__)

    model = MNISTClassifier(activation=activation)
    if dro_type == 'PGD':
        train_module = ProjetcedDRO(model, loss_criterion)
    elif dro_type == 'Lag':
        # Explicit check rather than ``assert`` so the validation survives
        # ``python -O``.
        if cost_function is None:
            raise ValueError(
                "cost_function is required when dro_type is 'Lag'.")
        train_module = LagrangianDRO(model, loss_criterion, cost_function)
    elif dro_type == 'FW':
        train_module = FrankWolfeDRO(model, loss_criterion, p=2, q=2)
    else:
        raise ValueError("The type of DRO is not valid.")

    train_module.train(budget=budget,
                       batch_size=batch_size,
                       epochs=epochs,
                       steps_adv=steps_adv)

    folderpath = "./Loss_models/"
    # Make sure the output directory exists; otherwise torch.save would fail
    # only after training has completed.
    os.makedirs(folderpath, exist_ok=True)
    filename = "{}_DRO_activation={}_epsilon={}_loss={}.pt".format(
        dro_type, activation, budget, loss_name)
    # folderpath ends with a separator, so join() yields the same string as
    # plain concatenation.
    filepath = os.path.join(folderpath, filename)
    # NOTE(review): this saves ``model`` itself, which assumes the training
    # module updated it in place during ``train`` -- confirm.
    torch.save(model.state_dict(), filepath)
    print("A neural network adversarially trained using {} now saved at: {}".
          format(dro_type, filepath))