import numpy as np
import tensorflow as tf

# Repo-local dependencies -- the import paths below are assumed and may need
# adjusting to this project's layout:
from ops import tf_fun
from layers.feedforward import vgg16
from config import v2_big_working, get_aux


def build_model_center(data_tensor,
                       labels,
                       reuse,
                       training,
                       output_shape,
                       perturb_norm=False,
                       data_format='NHWC'):
    """Create the hgru from Learning long-range...

    Earlier variant that reads out a 4x4 window at the center of the
    fgru_0 grid; the fixed-coordinate build_model below is the active
    definition.
    """
    if isinstance(output_shape, list):
        output_shape = output_shape[-1]
    elif isinstance(output_shape, dict):
        output_shape = output_shape['output']
    # norm_moments_training = training  # Force instance norm
    # normalization_type = 'no_param_batch_norm_original'
    normalization_type = 'no_param_instance_norm'
    # output_normalization_type = 'batch_norm_original_renorm'
    output_normalization_type = 'instance_norm'
    data_tensor, long_data_format = tf_fun.interpret_data_format(
        data_tensor=data_tensor, data_format=data_format)

    # Prepare gammanet structure
    (compression, ff_kernels, ff_repeats, features, fgru_kernels,
     additional_readouts) = v2_big_working()
    gammanet_constructor = tf_fun.get_gammanet_constructor(
        compression=compression,
        ff_kernels=ff_kernels,
        ff_repeats=ff_repeats,
        features=features,
        fgru_kernels=fgru_kernels)
    aux = get_aux()

    # Build model
    with tf.variable_scope('vgg', reuse=reuse):
        # moments_file = "../undo_bias/neural_models/linear_moments/INSILICO_BSDS_vgg_gratings_simple_tb_feature_matrix.npz"
        # model_file = "../undo_bias/neural_models/linear_models/INSILICO_BSDS_vgg_gratings_simple_tb_model.joblib.npy"
        fb_moments_file = "../undo_bias/neural_models/linear_moments/tb_feature_matrix.npz"
        fb_model_file = "../undo_bias/neural_models/linear_models/tb_model.joblib.npy"
        ff_moments_file = "../undo_bias/neural_models/linear_moments/conv2_2_tb_feature_matrix.npz"
        ff_model_file = "../undo_bias/neural_models/linear_models/conv2_2_tb_model.joblib.npy"
        vgg = vgg16.Vgg16(
            vgg16_npy_path='/media/data_cifs_lrs/clicktionary/pretrained_weights/vgg16.npy',  # noqa
            reuse=reuse,
            aux=aux,
            moments_file=ff_moments_file,  # Perturb the FF drive
            model_file=ff_model_file,
            train=False,
            timesteps=8,
            perturb=0.95,  # Other values tried: 1e-7, .05, 1.5, 2., 1.001, 17.1
            perturb_norm=perturb_norm,
            fgru_normalization_type=normalization_type,
            ff_normalization_type=normalization_type)
        # gn = tf.get_default_graph()
        # with gn.gradient_override_map({'Conv2D': 'PerturbVizGrad'}):
        #
        # Scope the entire GN (with train=False). The lowest-level recurrent
        # tensor is the only trainable tensor. This grad op will freeze the
        # center unit, forcing other units to overcompensate to recreate the
        # model's prediction (mechanism sketched at the bottom of this file).
        # TODO: need to get the original model output... could precompute this.  # noqa
        vgg(rgb=data_tensor, label=labels, constructor=gammanet_constructor)
        activity = vgg.fgru_0

        # Load tuning curve transform for fb output
        moments = np.load(fb_moments_file)
        means = moments["means"]
        stds = moments["stds"]
        clf = np.load(fb_model_file).astype(np.float32)

        # Transform activity to outputs
        # Select a flattened 4x4 window of fgru units at the map center
        bs, h, w, _ = vgg.fgru_0.get_shape().as_list()
        hh, hw = h // 2, w // 2
        sel_units = tf.reshape(vgg.fgru_0[:, hh - 2:hh + 2, hw - 2:hw + 2, :],
                               [bs, -1])

        if perturb_norm:
            # Normalize with the stored moments (activities are otherwise
            # left unnormalized)
            sel_units = (sel_units - means) / stds

        # Map responses with the least-squares readout
        # (clf^T clf)^{-1} clf^T sel_units^T, i.e. the OLS solution
        # (see _check_lstsq_readout below).
        # inv_clf = tf.linalg.inv(tf.matmul(clf, clf, transpose_a=True))  # TF alternative  # noqa
        # Precompute the inversion on the host
        inv_clf = np.linalg.inv(clf.T.dot(clf)).astype(np.float32)
        activity = tf.matmul(tf.matmul(inv_clf, clf, transpose_b=True),
                             sel_units,
                             transpose_b=True)

    # bg = tf.reduce_mean(vgg.conv2_2 ** 2, reduction_indices=[-1], keep_dims=True)  # noqa
    # bg = tf.cast(tf.greater(bg, tf.reduce_mean(bg)), tf.float32)
    # bg_dil = dilation2d(img=bg, extent=5)
    # extra_activities = {"mask": bg, "mask_dil": bg_dil}
    # extra_activities = {"fgru": vgg.fgru_0, "penalty": tf.constant(0.), "conv": vgg.error_1}  # noqa
    extra_activities = {}
    if activity.dtype != tf.float32:
        activity = tf.cast(activity, tf.float32)
    # return [activity, h_deep], extra_activities
    return activity, extra_activities
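

# A minimal sanity check of the readout above -- a hypothetical helper, not
# part of the original model. It confirms that mapping responses with
# inv(clf.T @ clf) @ clf.T is exactly the ordinary least-squares solve
# performed by np.linalg.lstsq.
def _check_lstsq_readout(clf, sel):
    """clf: (n_features, n_outputs) linear model; sel: (batch, n_features)."""
    inv_clf = np.linalg.inv(clf.T.dot(clf))
    manual = inv_clf.dot(clf.T).dot(sel.T)  # (n_outputs, batch)
    lstsq = np.linalg.lstsq(clf, sel.T, rcond=None)[0]
    assert np.allclose(manual, lstsq, atol=1e-4)
    return manual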


def build_model(data_tensor,
                labels,
                reuse,
                training,
                output_shape,
                perturb_norm=False,
                data_format='NHWC'):
    """Create the hgru from Learning long-range...

    Active variant: reads out a 4x4 window at fixed coordinates
    (hh, hw) = (73, 180) and also returns the perturbation mask, distance
    penalty, and input patch.
    """
    if isinstance(output_shape, list):
        output_shape = output_shape[-1]
    elif isinstance(output_shape, dict):
        output_shape = output_shape['output']
    # norm_moments_training = training  # Force instance norm
    # normalization_type = 'no_param_batch_norm_original'
    normalization_type = 'no_param_instance_norm'
    # output_normalization_type = 'batch_norm_original_renorm'
    output_normalization_type = 'instance_norm'
    data_tensor, long_data_format = tf_fun.interpret_data_format(
        data_tensor=data_tensor, data_format=data_format)

    # Prepare gammanet structure
    (compression, ff_kernels, ff_repeats, features, fgru_kernels,
     additional_readouts) = v2_big_working()
    gammanet_constructor = tf_fun.get_gammanet_constructor(
        compression=compression,
        ff_kernels=ff_kernels,
        ff_repeats=ff_repeats,
        features=features,
        fgru_kernels=fgru_kernels)
    aux = get_aux()

    # Build model
    with tf.variable_scope('vgg', reuse=reuse):
        hh, hw = 73, 180  # Fixed perturbation coordinates on the fgru_0 grid
        fb_moments_file = "../undo_bias/neural_models/linear_moments/tb_feature_matrix.npz"
        fb_model_file = "../undo_bias/neural_models/linear_models/tb_model.joblib.npy"
        ff_moments_file = "../undo_bias/neural_models/linear_moments/conv2_2_tb_feature_matrix.npz"
        ff_model_file = "../undo_bias/neural_models/linear_models/conv2_2_tb_model.joblib.npy"
        vgg = vgg16.Vgg16(
            vgg16_npy_path='/media/data_cifs_lrs/clicktionary/pretrained_weights/vgg16.npy',  # noqa
            reuse=reuse,
            aux=aux,
            hh=hh,
            hw=hw,
            moments_file=ff_moments_file,  # Perturb the FF drive
            model_file=ff_model_file,
            train=False,
            timesteps=8,
            perturb=60.0,  # 70 might work; other values tried: 1.001, 1.5, 2., 17.1
            perturb_norm=perturb_norm,
            fgru_normalization_type=normalization_type,
            ff_normalization_type=normalization_type)
        # gn = tf.get_default_graph()
        # with gn.gradient_override_map({'Conv2D': 'PerturbVizGrad'}):
        #
        # Scope the entire GN (with train=False). The lowest-level recurrent
        # tensor is the only trainable tensor. This grad op will freeze the
        # center unit, forcing other units to overcompensate to recreate the
        # model's prediction (mechanism sketched at the bottom of this file).
        # TODO: need to get the original model output... could precompute this.  # noqa
        vgg(rgb=data_tensor, label=labels, constructor=gammanet_constructor)
        activity = vgg.fgru_0

        # Load tuning curve transform
        moments = np.load(fb_moments_file)
        means = moments["means"]
        stds = moments["stds"]
        clf = np.load(fb_model_file).astype(np.float32)

        # Transform activity to outputs
        bs, h, w, _ = vgg.fgru_0.get_shape().as_list()
        # Select the flattened 4x4 window of fgru units at (hh, hw)
        sel_units = tf.reshape(vgg.fgru_0[:, hh - 2:hh + 2, hw - 2:hw + 2, :],
                               [bs, -1])
        # grad0 = tf.gradients(sel_units, vgg.conv2_2)[0]

        if perturb_norm:
            # Normalize with the stored moments (activities are otherwise
            # left unnormalized)
            sel_units = (sel_units - means) / stds

        # Map responses with the least-squares readout
        # (clf^T clf)^{-1} clf^T sel_units^T, i.e. the OLS solution
        # (see _check_lstsq_readout above).
        # inv_clf = tf.linalg.inv(tf.matmul(clf, clf, transpose_a=True))  # TF alternative  # noqa
        # Precompute the inversion on the host
        inv_clf = np.linalg.inv(clf.T.dot(clf)).astype(np.float32)
        activity = tf.matmul(tf.matmul(inv_clf, clf, transpose_b=True),
                             sel_units,
                             transpose_b=True)

    # Crop the input patch under the perturbed window; fgru_0 sits at half the
    # input resolution, hence the factor of 2 on (hh, hw)
    impatch = data_tensor[:, hh * 2 - 14:hh * 2 + 18, hw * 2 - 14:hw * 2 + 18]
    # Build a radial distance map over the fgru grid: 0 at (hh, hw), 1 at the
    # farthest unit. It weights the penalty by distance from the perturbed site.
    mx = max(h, w)
    x, y = np.meshgrid(np.arange(mx), np.arange(mx))
    x = x[:h, :w]
    y = y[:h, :w]
    xy = np.stack((x, y), -1)
    coord = np.asarray([hh, hw])[None]
    dist = np.sqrt(((xy - coord) ** 2).mean(-1))
    dist = dist / dist.max()
    # Alternative penalties that were tried:
    # penalty = tf.reduce_mean(tf.tanh(tf.abs(tf.squeeze(vgg.fgru_0))), -1) * dist  # noqa
    # dist = dist * tf.squeeze(vgg.mask)  # zero out units in the RF
    # penalty = tf.cast(tf.sigmoid(tf.reduce_mean(tf.abs(tf.squeeze(vgg.fgru_0)), -1)) - 0.5, tf.float32) * dist.astype(np.float32)  # noqa
    penalty = tf.cast(
        tf.sigmoid(tf.reduce_mean(tf.abs(tf.squeeze(vgg.mult)), -1)) - 0.5,
        tf.float32) * dist.astype(np.float32)
    penalty = penalty * tf.cast(
        1. - tf.squeeze(vgg.mask),
        tf.float32)  # Zero out values inside the H-diameter
    penalty = penalty * 0.0  # Penalty is currently disabled
    extra_activities = {
        "fgru": vgg.fgru_0,
        "mask": vgg.mask,
        "penalty": penalty,
        "impatch": impatch
    }
    if activity.dtype != tf.float32:
        activity = tf.cast(activity, tf.float32)
    # return [activity, h_deep], extra_activities
    return activity, extra_activities
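

# The commented-out gradient_override_map lines inside both builders refer to
# the perturbation mechanism: a custom Conv2D gradient ('PerturbVizGrad',
# registered elsewhere in the repo) freezes the center unit during
# optimization. The hypothetical stand-in below is a minimal sketch of the
# same mechanism (assumes TF 1.x): it overrides Identity instead of Conv2D
# and zeroes the backward signal in a 4x4 window, so those units stay frozen
# while the rest of the map compensates.
@tf.RegisterGradient("FreezeWindowGrad")
def _freeze_window_grad(op, grad):
    """Zero the gradient inside the center 4x4 window; pass the rest through."""
    h, w = grad.get_shape().as_list()[1:3]
    mask = np.ones((1, h, w, 1), dtype=np.float32)
    mask[:, h // 2 - 2:h // 2 + 2, w // 2 - 2:w // 2 + 2] = 0.
    return grad * mask


def freeze_center_window(tensor):
    """Apply the override: forward pass is untouched, backprop skips the window."""
    g = tf.get_default_graph()
    with g.gradient_override_map({'Identity': 'FreezeWindowGrad'}):
        return tf.identity(tensor)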