Example #1
def deep_linear_model_loss_regularized_dw0_lp_da1_lq(w, x, y, reg_coeff,
                                                     norm_type1, norm_type2):
    """Penalize by ||dL/dw0||+||dL/da1||, where w0, a1 are first layer w, o."""
    dloss_da = jax.grad(deep_linear_model_loss, argnums=1)
    loss = deep_linear_model_loss(w, x, y)
    a1 = w[0].T @ x
    reg = norm_f(w[0], norm_type1) + norm_f(dloss_da(w[1:], a1, y), norm_type2)
    return loss + reg_coeff * reg
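
A minimal, self-contained sketch of the helpers this function assumes
(deep_linear_model_loss, norm_f); the logistic loss and the Lp norm below are
stand-ins for illustration, not the repository's actual definitions:

import jax
import jax.numpy as jnp

def norm_f(v, norm_type):
    """Stand-in Lp-norm helper: 'l1', 'l2', or 'linf'."""
    p = {'l1': 1, 'l2': 2, 'linf': jnp.inf}[norm_type]
    return jnp.linalg.norm(jnp.ravel(v), ord=p)

def deep_linear_model_loss(w, x, y):
    """Stand-in: logistic loss of a product-of-layers linear model."""
    out = x
    for wi in w:
        out = wi.T @ out
    return jnp.mean(jnp.logaddexp(0., -y * out))

# Toy usage: two layers, x is (dim, num_examples), y in {-1, +1}.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 8))
y = jnp.sign(jax.random.normal(key, (1, 8)) + 0.5)
w = [jnp.eye(4), jnp.ones((4, 1))]
loss = deep_linear_model_loss_regularized_dw0_lp_da1_lq(
    w, x, y, reg_coeff=0.1, norm_type1='l2', norm_type2='l1')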
Example #2
def compute_min_norm_solution(x, y, norm_type):
    """Compute the min-norm solution using a convex-program solver."""
    w = cp.Variable((x.shape[0], 1))
    if norm_type == 'linf':
        # compute minimal L_infinity solution
        constraints = [cp.multiply(y, (w.T @ x)) >= 1]
        prob = cp.Problem(cp.Minimize(cp.norm_inf(w)), constraints)
    elif norm_type == 'l2':
        # compute minimal L_2 solution
        constraints = [cp.multiply(y, (w.T @ x)) >= 1]
        prob = cp.Problem(cp.Minimize(cp.norm2(w)), constraints)
    elif norm_type == 'l1':
        # compute minimal L_1 solution
        constraints = [cp.multiply(y, (w.T @ x)) >= 1]
        prob = cp.Problem(cp.Minimize(cp.norm1(w)), constraints)
    elif norm_type[0] == 'l':
        # compute minimal Lp solution
        p = float(norm_type[1:])
        constraints = [cp.multiply(y, (w.T @ x)) >= 1]
        prob = cp.Problem(cp.Minimize(cp.pnorm(w, p)), constraints)
    elif norm_type == 'dft1':
        w = cp.Variable((x.shape[0], 1), complex=True)
        # compute minimal Fourier L1 norm (||F(w)||_1) solution
        dft = scipy.linalg.dft(x.shape[0]) / np.sqrt(x.shape[0])
        constraints = [cp.multiply(y, (cp.real(w).T @ x)) >= 1]
        prob = cp.Problem(cp.Minimize(cp.norm1(dft @ w)), constraints)
    else:
        raise ValueError('Unknown norm type: %s' % norm_type)
    prob.solve(verbose=True)
    logging.info('Min %s-norm solution found (norm=%.4f)', norm_type,
                 float(norm_f(w.value, norm_type)))
    return cp.real(w).value
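
By duality, the minimum-norm separator subject to y_i * (w.T @ x_i) >= 1 is
the maximum-margin classifier for the corresponding geometry, with margin
1 / ||w*||. A toy call, assuming x is (dim, num_examples), y in {-1, +1} with
shape (1, num_examples), and the norm_f/logging helpers in scope:

import numpy as np

x = np.array([[1., -1., 2.],
              [0.5, -2., 1.]])
y = np.array([[1., -1., 1.]])
w_star = compute_min_norm_solution(x, y, 'l2')  # max-L2-margin separator
print(1. / np.linalg.norm(w_star))              # the achieved L2 margin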
Example #3
def two_linear_model_loss_regularized_dy0dx_lp(w0, w1, x, y, reg_coeff,
                                               norm_type):
    """Penalize by ||dy0/dx||, optimal when first layer is fixed."""
    dy0_dx_f = jax.grad(two_linear_model_y0_mean, argnums=1)
    dy0_dx = dy0_dx_f(w0, x)
    loss = two_linear_model_loss(w0, w1, x, y)
    reg = norm_f(dy0_dx, norm_type)
    return loss + reg_coeff * reg
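
two_linear_model_y0_mean must return a scalar for jax.grad to apply; a
plausible stand-in consistent with the call above (an assumption, not the
repository's definition):

import jax.numpy as jnp

def two_linear_model_y0_mean(w0, x):
    """Stand-in: scalar mean of the first layer's output y0 = w0.T @ x."""
    return jnp.mean(w0.T @ x)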
Example #4
def deep_linear_model_normalize_param(params, norm_type):
    """Normalizes the last layer weights by the norm of the product of weights."""
    norm_p = norm_f(deep_linear_model_linearize_param(params), norm_type)
    params_new = list(params)  # copy so the caller's params are not mutated
    params_new[-1] = params[-1] / jnp.maximum(1e-7, norm_p)
    return params_new
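
A quick check of the intended invariant, assuming the section's norm_f and
deep_linear_model_linearize_param are in scope: dividing the last layer by the
norm of the layer product rescales the end-to-end map to unit norm without
moving the decision boundary.

params = deep_linear_model_normalize_param(params, 'l2')
w_lin = deep_linear_model_linearize_param(params)
# norm_f(w_lin, 'l2') is now ~1.0; predictions changed only by a positive
# rescaling, so classification decisions are unaffected.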
Example #5
def linear_model_loss_regularized_dw_lp(w, x, y, reg_coeff, norm_type):
    """Penalize by ||dL/dx||."""
    loss_and_dloss_dw_f = jax.value_and_grad(linear_model_loss, argnums=0)
    loss, dloss_dw = loss_and_dloss_dw_f(w, x, y)
    reg = norm_f(dloss_dw, norm_type)
    return loss + reg_coeff * reg
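
A stand-in for linear_model_loss consistent with the other sketches here
(logistic loss; an assumption, not the repository's definition). Note that,
unlike a plain ||w|| penalty, this gradient penalty vanishes wherever the loss
is flat in w:

import jax
import jax.numpy as jnp

def linear_model_loss(w, x, y):
    """Stand-in: logistic loss of a single linear layer, w of shape (dim, 1)."""
    return jnp.mean(jnp.logaddexp(0., -y * (w.T @ x)))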
Example #6
def linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
    """Penalize by ||w||."""
    loss = linear_model_loss(w, x, y)
    reg = norm_f(w, norm_type)
    return loss + reg_coeff * reg
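
Example #7 pairs this penalty with a proximal operator for FISTA via
get_prox_op. A hypothetical sketch of that helper, using the closed-form
proximal maps of lam * ||.||_1 (soft-thresholding) and lam * ||.||_2 (block
shrinkage); the repository's actual get_prox_op may differ:

import jax.numpy as jnp

def get_prox_op(norm_type):
    """Hypothetical prox table for lam * ||.||."""
    if norm_type == 'l1':
        # Soft-thresholding.
        return lambda v, lam: jnp.sign(v) * jnp.maximum(jnp.abs(v) - lam, 0.)
    if norm_type == 'l2':
        # Shrink the whole vector toward the origin.
        def prox(v, lam):
            n = jnp.maximum(jnp.linalg.norm(v), 1e-12)
            return jnp.maximum(1. - lam / n, 0.) * v
        return prox
    raise ValueError('No prox implemented for %s' % norm_type)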
Example #7
def get_model_functions(rng_key, dim, arch, nlayers, regularizer, reg_coeff,
                        r):
    """Returns model init/predict/loss functions given the model name."""
    loss_and_prox_op = None
    if arch == 'linear':
        init_f, predict_f = linear_model_init_param, linear_model_predict
        # Identity prox by default so loss_and_prox_op below is always valid.
        prox_op = lambda x, _: x
        if regularizer == 'none':
            loss_f = linear_model_loss
        elif re.match('w_l.*', regularizer) or re.match(
                'w_dft.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = loss_f_with_args(linear_model_loss_regularized_w_lp,
                                      reg_coeff, norm_type)
            prox_op = lambda v, lam: get_prox_op(norm_type)(v, lam * reg_coeff)
        elif re.match('dx_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = loss_f_with_args(linear_model_loss_regularized_dx_lp,
                                      reg_coeff, norm_type)
        elif re.match('dw_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = loss_f_with_args(linear_model_loss_regularized_dw_lp,
                                      reg_coeff, norm_type)
        model_param = init_f(dim)
        loss_adv_f = linear_model_loss
        linearize_param = lambda p: p
        normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(p, n))
        loss_and_prox_op = (linear_model_loss, prox_op)
    elif arch == 'deep_linear':
        init_f = deep_linear_model_init_param
        predict_f = deep_linear_model_predict
        if regularizer == 'none':
            loss_f = deep_linear_model_loss
        elif re.match('w_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = loss_f_with_args(deep_linear_model_loss_regularized_w_lp,
                                      reg_coeff, norm_type)
        elif re.match('dx_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dx_lp,
                                      reg_coeff, norm_type)
        elif re.match('dw_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = loss_f_with_args(deep_linear_model_loss_regularized_dw_lp,
                                      reg_coeff, norm_type)
        elif re.match('dw0_l.*_da1_l.*', regularizer):
            norm_type = regularizer.split('_')
            norm_type1, norm_type2 = norm_type[1], norm_type[3]
            loss_f = loss_f_with_args(
                deep_linear_model_loss_regularized_dw0_lp_da1_lq, reg_coeff,
                norm_type1, norm_type2)
        model_param = init_f(dim, nlayers, r, rng_key)
        loss_adv_f = deep_linear_model_loss
        linearize_param = deep_linear_model_linearize_param
        normalize_param = deep_linear_model_normalize_param
    elif arch == 'two_linear_fixed_w0':
        w0, model_param = two_linear_w0fixed_init_param(dim, r, rng_key)
        predict_f = lambda w, x: two_linear_model_predict(w0, w, x)
        if regularizer == 'none':
            loss_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
        elif re.match('dy1dx_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(  # pylint: disable=g-long-lambda
                w0, w, x, y, reg_coeff, norm_type)
        elif re.match('w0w1_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(  # pylint: disable=g-long-lambda
                w0, w, x, y, reg_coeff, norm_type)
        elif re.match('w1_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = lambda w, x, y: two_linear_model_loss_regularized_w1_lp(  # pylint: disable=g-long-lambda
                w0, w, x, y, reg_coeff, norm_type)
        loss_adv_f = lambda w, x, y: two_linear_model_loss(w0, w, x, y)
        linearize_param = lambda p: w0.T @ p
        normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(
            w0.T @ p, n))
    elif arch == 'two_linear_fixed_w1' or arch == 'two_linear_fixed_w1_noniso':
        non_isotropic = arch == 'two_linear_fixed_w1_noniso'
        model_param, w1 = two_linear_w1fixed_init_param(
            dim, r, rng_key, non_isotropic)
        predict_f = lambda w, x: two_linear_model_predict(w, w1, x)
        if regularizer == 'none':
            loss_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
        elif re.match('dy0dx_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy0dx_lp(  # pylint: disable=g-long-lambda
                w, w1, x, y, reg_coeff, norm_type)
        elif re.match('w0w1_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0w1_lp(  # pylint: disable=g-long-lambda
                w, w1, x, y, reg_coeff, norm_type)
        elif re.match('w0_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = lambda w, x, y: two_linear_model_loss_regularized_w0_lp(  # pylint: disable=g-long-lambda
                w, w1, x, y, reg_coeff, norm_type)
        elif re.match('dy1dx_l.*', regularizer):
            norm_type = regularizer.split('_')[1]
            loss_f = lambda w, x, y: two_linear_model_loss_regularized_dy1dx_lp(  # pylint: disable=g-long-lambda
                w, w1, x, y, reg_coeff, norm_type)
        loss_adv_f = lambda w, x, y: two_linear_model_loss(w, w1, x, y)
        linearize_param = lambda p: p.T @ w1
        normalize_param = lambda p, n: p / jnp.maximum(1e-7, norm_f(
            p.T @ w1, n))
    elif arch == 'conv_linear':
        init_f = conv_linear_model_init_param
        predict_f = conv_linear_model_predict
        # Identity prox by default so loss_and_prox_op below is always valid.
        prox_op = lambda x, _: x
        if regularizer == 'none':
            loss_f = conv_linear_model_loss
        model_param = init_f(dim, nlayers, r, rng_key)
        loss_adv_f = conv_linear_model_loss
        linearize_param = conv_linear_model_linearize_param
        normalize_param = conv_linear_model_normalize_param
        loss_and_prox_op = (conv_linear_model_loss, prox_op)
    else:
        raise ValueError('Unknown arch: %s' % arch)

    return (model_param, predict_f, loss_f, loss_adv_f, linearize_param,
            normalize_param, loss_and_prox_op)
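
Typical wiring of the returned tuple; the argument values below are
illustrative assumptions:

import jax

rng_key = jax.random.PRNGKey(0)
(model_param, predict_f, loss_f, loss_adv_f, linearize_param,
 normalize_param, loss_and_prox_op) = get_model_functions(
     rng_key, dim=10, arch='linear', nlayers=1, regularizer='w_l1',
     reg_coeff=1e-3, r=1)
# loss_f(model_param, x, y) is the penalized objective consumed by train().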
Example #8
def two_linear_model_loss_regularized_w1_lp(w0, w1, x, y, reg_coeff,
                                            norm_type):
    """Penalize by ||w1||, optimal when first layer is fixed faster than dydx."""
    loss = two_linear_model_loss(w0, w1, x, y)
    reg = norm_f(w1, norm_type)
    return loss + reg_coeff * reg
Example #9
def deep_linear_model_loss_regularized_w_lp(w, x, y, reg_coeff, norm_type):
    """Penalize by ||dL/dw||."""
    loss = deep_linear_model_loss(w, x, y)
    w_unravel = ravel_pytree(w)[0]
    reg = norm_f(w_unravel, norm_type)
    return loss + reg_coeff * reg
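
ravel_pytree flattens the list of per-layer weights into one vector so a
single norm covers all layers; a quick round-trip check:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

w = [jnp.ones((3, 3)), jnp.ones((3, 1))]
w_flat, unravel = ravel_pytree(w)
assert w_flat.shape == (12,)                    # 9 + 3 entries, flattened
assert jnp.allclose(unravel(w_flat)[0], w[0])   # inverse map restores layers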
Example #10
def train(model_param, train_test_data, predict_f, loss_f, loss_adv_f,
          linearize_f, normalize_f, loss_and_prox_op, summary, config,
          rng_key):
    """Train a model and log risks."""
    dloss_dw = jax.grad(loss_f, argnums=0)
    dloss_adv_dx = jax.grad(loss_adv_f, argnums=1)

    train_data = train_test_data[0]
    xtrain, ytrain = train_data

    # Precompute min-norm solutions
    if config.enable_cvxpy:
        min_norm_w = {}
        for norm_type in config.available_norm_types:
            min_norm_w[norm_type] = compute_min_norm_solution(
                xtrain, ytrain, norm_type)
        if config.adv.eps_from_cvxpy:
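            # The min dual-norm separator has margin 1 / ||w*|| in the attack
            # norm; the attack radius is set to twice that margin.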
            dual_norm = norm_type_dual(config.adv.norm_type)
            wcomp = min_norm_w[dual_norm]
            wnorm = norm_f(wcomp, dual_norm)
            margin = 1. / wnorm
            config.adv.eps_tot = config.adv.eps_iter = float(2 * margin)

    if config['optim']['name'] == 'cvxpy':
        norm_type = config['optim']['norm']
        cvxpy_sol = compute_min_norm_solution(xtrain, ytrain, norm_type)
        model_param = jnp.array(cvxpy_sol)

    # Train loop
    optim_step, optim_options = optim.get_optimizer_step(config['optim'])
    niters = optim_options['niters']
    for step in range(1, niters):
        # Take one optimization step
        if config['optim']['name'] != 'cvxpy':
            if config['optim']['adv_train']['enable']:
                # Adversarial training
                rng_key, rng_subkey = jax.random.split(rng_key)
                x_adv = adversarial.find_adversarial_samples(
                    train_data, dloss_adv_dx, model_param, normalize_f,
                    config.optim.adv_train, rng_subkey)
                train_data_new = x_adv, ytrain
            else:
                # Standard training
                train_data_new = train_data
            if config['optim']['name'] == 'fista':
                model_param, optim_options = optim_step(
                    train_data_new, loss_and_prox_op, model_param,
                    optim_options)
            else:
                model_param, optim_options = optim_step(
                    train_data_new, loss_f, model_param, optim_options)

        # Log risks and other statistics
        if (step + 1) % config.log_interval == 0:
            # Evaluate risk on train/test sets
            for do_train in [True, False]:
                data = train_test_data[0] if do_train else train_test_data[1]
                prefix = 'risk/train' if do_train else 'risk/test'
                risk = evaluate_risks(data, predict_f, loss_f, model_param)
                for rname, rvalue in risk.items():
                    summary.scalar('%s/%s' % (prefix, rname),
                                   rvalue,
                                   step=step)
                rng_key, rng_subkey = jax.random.split(rng_key)
                risk = evaluate_adversarial_risk(data, predict_f, dloss_adv_dx,
                                                 model_param, normalize_f,
                                                 config, rng_subkey)
                for rname, rvalue in risk.items():
                    summary.array('%s/%s' % (prefix, rname), rvalue, step=step)

            grad = dloss_dw(model_param, xtrain, ytrain)
            grad_ravel, _ = ravel_pytree(grad)
            model_param_ravel, _ = ravel_pytree(model_param)
            linear_param = linearize_f(model_param)
            for norm_type in config.available_norm_types:
                # Log gradient and weight norms w.r.t. various norms
                # (DFT norms are skipped here)
                if not norm_type.startswith('dft'):
                    summary.scalar('grad/norm/' + norm_type,
                                   norm_f(grad_ravel, norm_type),
                                   step=step)
                    wnorm = norm_f(model_param_ravel, norm_type)
                    summary.scalar('weight/norm/' + norm_type,
                                   wnorm,
                                   step=step)

                # Log margin for the equivalent linearized single-layer model
                min_score = jnp.min(ytrain * (linear_param.T @ xtrain))
                wcomp = linear_param / min_score
                wnorm = norm_f(wcomp, norm_type)
                margin = jnp.sign(min_score) * 1 / wnorm
                summary.scalar('margin/' + norm_type, margin, step=step)
                summary.scalar('weight/linear/norm/' + norm_type,
                               wnorm,
                               step=step)

                # Cosine similarity between the current params and min-norm solution
                if config.enable_cvxpy:

                    def cos_sim(a, b):
                        return jnp.dot(
                            a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

                    min_norm_w_ravel, _ = ravel_pytree(min_norm_w[norm_type])
                    cs = cos_sim(linear_param.flatten(), min_norm_w_ravel)
                    summary.scalar('csim_to_wmin/' + norm_type, cs, step=step)

            if 'step_size' in optim_options:
                summary.scalar('optim/step_size',
                               optim_options['step_size'],
                               step=step)

            logging.info('Epoch: [%d/%d]\t%s', step + 1, niters,
                         summary.last_scalars_to_str(config.log_keys))
            logging.flush()

    summary.flush()
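
For reference, the per-norm margin logged in the loop above reduces to
min_i y_i * (w.T @ x_i) / ||w||_p for the linearized weights w; restated in
isolation (same quantities, no new behavior):

scores = ytrain * (linear_param.T @ xtrain)
margin = jnp.min(scores) / norm_f(linear_param, norm_type)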