Beispiel #1
0
    def __init__(self,
                 X,
                 timesteps=1,
                 SRF=1,
                 SSN=9,
                 SSF=29,
                 strides=[1, 1, 1, 1],
                 padding='SAME',
                 aux=None):
        """Global initializations and settings.

        Args:
            X: Input tensor with a static shape of [n, h, w, k].
            timesteps: Number of recurrent timesteps.
            SRF: Classical receptive field (CRF) kernel size.
            SSN: Near-surround kernel size; defaults to SRF * 3 if None.
            SSF: Far-surround kernel size; defaults to SRF * 5 if None.
            strides: Convolution strides.
            padding: Convolution padding mode.
            aux: Optional dict overriding entries of auxilliary_variables()
                (supplies e.g. self.separable, self.gate_filter,
                self.association_field, self.tuning_nl via update_params).
        """
        self.X = X
        self.n, self.h, self.w, self.k = [int(x) for x in X.get_shape()]
        self.timesteps = timesteps
        self.strides = strides
        self.padding = padding

        # Sort through and assign the auxilliary variables
        aux_vars = auxilliary_variables()
        if aux is not None and isinstance(aux, dict):
            for key, value in aux.items():  # .items() works in py2 and py3
                aux_vars[key] = value
        self.update_params(aux_vars)

        # Kernel shapes. Apply the None defaults BEFORE deriving the
        # extended sizes: the original computed SSN_ext/SSF_ext from the
        # raw (possibly None) arguments, which raised a TypeError and made
        # the None defaults ineffective for the derived shapes.
        self.SRF, self.SSN, self.SSF = SRF, SSN, SSF
        if self.SSN is None:
            self.SSN = self.SRF * 3
        if self.SSF is None:
            self.SSF = self.SRF * 5
        # Force odd ("extended") extents so each kernel has a center pixel.
        self.SSN_ext = 2 * py_utils.ifloor(self.SSN / 2.0) + 1
        self.SSF_ext = 2 * py_utils.ifloor(self.SSF / 2.0) + 1
        if self.separable:
            # Depthwise-style single-channel kernels.
            self.q_shape = [self.SRF, self.SRF, 1, 1]
            self.u_shape = [self.SRF, self.SRF, 1, 1]
            self.p_shape = [self.SSN_ext, self.SSN_ext, 1, 1]
            self.t_shape = [self.SSF_ext, self.SSF_ext, 1, 1]
        else:
            self.q_shape = [self.SRF, self.SRF, self.k, self.k]
            self.u_shape = [self.SRF, self.SRF, self.k, 1]
            self.p_shape = [self.SSN_ext, self.SSN_ext, self.k, self.k]
            self.t_shape = [self.SSF_ext, self.SSF_ext, self.k, self.k]
        self.i_shape = [self.gate_filter, self.gate_filter, self.k, self.k]
        self.o_shape = [self.gate_filter, self.gate_filter, self.k, self.k]
        self.bias_shape = [1, 1, 1, self.k]

        self.tuning_params = ['Q', 'P', 'T']  # Learned connectivity
        if self.association_field:
            # Association field: P becomes a full far-surround kernel and
            # its separate channel-tuning tensor is dropped.
            self.p_shape = [self.SSF_ext, self.SSF_ext, self.k, self.k]  # T
            self.tuning_params.remove('P')

        self.tuning_shape = [1, 1, self.k, self.k]

        # Nonlinearities and initializations
        self.u_nl = tf.identity
        self.t_nl = tf.identity
        self.q_nl = tf.identity
        self.p_nl = tf.identity
        self.tuning_nl = interpret_nl(self.tuning_nl)
    def __init__(self,
                 X,
                 model_version='full',
                 timesteps=1,
                 lesions=None,
                 SRF=1,
                 SSN=9,
                 SSF=29,
                 strides=[1, 1, 1, 1],
                 padding='SAME',
                 dropout=0.5,
                 dtype=tf.float32,
                 train=True,
                 return_weights=True):
        """Global initializations and settings.

        Args:
            X: Input tensor with a static shape of [n, h, w, k].
            model_version: Model variant selector string.
            timesteps: Number of recurrent timesteps.
            lesions: Iterable of connection names to lesion (or None).
            SRF: Classical receptive field (CRF) kernel size.
            SSN: Near-surround kernel size; defaults to SRF * 3 if None.
            SSF: Far-surround kernel size; defaults to SRF * 5 if None.
            strides: Convolution strides.
            padding: Convolution padding mode.
            dropout: Dropout keep/drop rate (used when training).
            dtype: TensorFlow dtype for created variables.
            train: Whether the model is in training mode.
            return_weights: Whether weight tensors should be returned.
        """
        self.X = X
        self.train = train
        self.dropout = dropout
        self.n, self.h, self.w, self.k = [int(x) for x in X.get_shape()]
        self.model_version = model_version
        self.timesteps = timesteps
        self.lesions = lesions
        self.strides = strides
        self.padding = padding
        self.dtype = dtype
        self.SRF, self.SSN, self.SSF = SRF, SSN, SSF

        # Apply the None defaults BEFORE deriving the extended sizes: the
        # original computed SSN_ext/SSF_ext from the raw (possibly None)
        # arguments, which raised a TypeError and left the trailing None
        # checks without effect on the derived kernel shapes.
        if self.SSN is None:
            self.SSN = self.SRF * 3
        if self.SSF is None:
            self.SSF = self.SRF * 5

        # Force odd ("extended") extents so each kernel has a center pixel.
        self.SSN_ext = 2 * py_utils.ifloor(self.SSN / 2.0) + 1
        self.SSF_ext = 2 * py_utils.ifloor(self.SSF / 2.0) + 1
        self.q_shape = [self.SRF, self.SRF, self.k, self.k]
        self.u_shape = [self.SRF, self.SRF, self.k, 1]
        self.p_shape = [self.SSN_ext, self.SSN_ext, self.k, self.k]
        self.t_shape = [self.SSF_ext, self.SSF_ext, self.k, self.k]
        self.i_shape = [self.SRF, self.SRF, self.k, self.k]
        self.o_shape = [self.SRF, self.SRF, self.k, self.k]
        self.u_nl = tf.identity
        self.t_nl = tf.identity
        self.q_nl = tf.identity
        self.p_nl = tf.identity
        self.tuning_nl = tf.nn.relu
        self.tuning_shape = [1, 1, self.k, self.k]
        self.tuning_params = ['Q', 'P', 'T']  # Learned connectivity
        self.recurrent_nl = tf.nn.relu
        self.gate_nl = tf.nn.sigmoid

        self.return_weights = return_weights
        self.normal_initializer = False
Beispiel #3
0
    def __init__(
            self,
            X,
            model_version='full_with_cell_states',
            timesteps=1,
            lesions=None,
            SRF=1,
            SSN=9,
            SSF=29,
            strides=[1, 1, 1, 1],
            padding='SAME',
            dtype=tf.float32,
            return_weights=True):
        """Global initializations and settings.

        Args:
            X: Input tensor with a static shape of [n, h, w, k].
            model_version: Model variant selector string.
            timesteps: Number of recurrent timesteps.
            lesions: Iterable of connection names to lesion (or None).
            SRF: Classical receptive field (CRF) kernel size.
            SSN: Near-surround kernel size; defaults to SRF * 3 if None.
            SSF: Far-surround kernel size; defaults to SRF * 5 if None.
            strides: Convolution strides.
            padding: Convolution padding mode.
            dtype: TensorFlow dtype for created variables.
            return_weights: Whether weight tensors should be returned.
        """
        self.X = X
        self.n, self.h, self.w, self.k = [int(x) for x in X.get_shape()]
        self.model_version = model_version
        self.timesteps = timesteps
        self.lesions = lesions
        self.strides = strides
        self.padding = padding
        self.dtype = dtype
        self.SRF, self.SSN, self.SSF = SRF, SSN, SSF

        # Apply the None defaults BEFORE deriving the extended sizes: the
        # original computed SSN_ext/SSF_ext from the raw (possibly None)
        # arguments, which raised a TypeError and left the trailing None
        # checks without effect on the derived kernel shapes.
        if self.SSN is None:
            self.SSN = self.SRF * 3
        if self.SSF is None:
            self.SSF = self.SRF * 5

        # Force odd ("extended") extents so each kernel has a center pixel.
        self.SSN_ext = 2 * py_utils.ifloor(self.SSN / 2.0) + 1
        self.SSF_ext = 2 * py_utils.ifloor(self.SSF / 2.0) + 1
        self.q_shape = [self.SRF, self.SRF, self.k, self.k]
        self.u_shape = [self.SRF, self.SRF, self.k, 1]
        self.p_shape = [self.SSN_ext, self.SSN_ext, self.k, self.k]
        self.t_shape = [self.SSF_ext, self.SSF_ext, self.k, self.k]
        self.i_shape = self.q_shape
        self.o_shape = self.q_shape
        self.u_nl = tf.identity
        self.t_nl = tf.identity
        self.q_nl = tf.identity
        self.p_nl = tf.identity
        self.i_nl = tf.nn.relu  # input non linearity
        self.o_nl = tf.nn.relu  # output non linearity

        self.return_weights = return_weights
        self.normal_initializer = False
    def __init__(self,
                 X,
                 timesteps=1,
                 SRF=1,
                 SSN=9,
                 SSF=29,
                 strides=[1, 1, 1, 1],
                 padding='SAME',
                 aux=None,
                 train=True):
        """Global initializations and settings.

        Args:
            X: Input tensor with a static shape of [n, h, w, k].
            timesteps: Number of recurrent timesteps.
            SRF: Classical receptive field (CRF) kernel size.
            SSN: Near-surround kernel size; defaults to SRF * 3 if None.
            SSF: Far-surround kernel size, or a list of sizes for
                multi-scale kernels; defaults to SRF * 5 if None.
            strides: Convolution strides.
            padding: Convolution padding mode.
            aux: Optional dict overriding entries of auxilliary_variables()
                (supplies e.g. self.gate_filter and self.integration_type
                via update_params).
            train: Whether the model is in training mode.
        """
        self.X = X
        self.n, self.h, self.w, self.k = [int(x) for x in X.get_shape()]
        self.timesteps = timesteps
        self.strides = strides
        self.padding = padding
        self.train = train

        # Sort through and assign the auxilliary variables
        aux_vars = auxilliary_variables()
        if aux is not None and isinstance(aux, dict):
            for key, value in aux.items():  # .items() works in py2 and py3
                aux_vars[key] = value
        self.update_params(aux_vars)

        # Kernel shapes. Apply the None defaults BEFORE deriving SSF_ext:
        # the original computed SSF_ext from the raw (possibly None) SSF,
        # which raised a TypeError and made the None defaults dead code
        # for the derived shapes.
        self.SRF, self.SSN, self.SSF = SRF, SSN, SSF
        if self.SSN is None:
            self.SSN = self.SRF * 3
        if self.SSF is None:
            self.SSF = self.SRF * 5

        # SSF may be a list of sizes (multi-scale, VGG-style kernels).
        # Force odd ("extended") extents so each kernel has a center pixel.
        if isinstance(self.SSF, list):
            self.SSF_ext = [2 * py_utils.ifloor(x / 2.0) + 1
                            for x in self.SSF]
        else:
            self.SSF_ext = 2 * py_utils.ifloor(self.SSF / 2.0) + 1

        self.q_shape = [self.SRF, self.SRF, self.k, self.k]
        self.u_shape = [self.SRF, self.SRF, self.k, 1]
        if isinstance(self.SSF, list):
            self.p_shape = [[ssf_ext, ssf_ext, self.k, self.k]
                            for ssf_ext in self.SSF_ext]
        else:
            self.p_shape = [self.SSF_ext, self.SSF_ext, self.k, self.k]
        self.i_shape = [self.gate_filter, self.gate_filter, self.k, self.k]
        self.o_shape = [self.gate_filter, self.gate_filter, self.k, self.k]
        self.bias_shape = [1, 1, 1, self.k]
        self.tuning_params = ['Q', 'P']  # Learned connectivity
        self.tuning_shape = [1, 1, self.k, self.k]

        # Nonlinearities and initializations
        self.u_nl = tf.identity
        self.q_nl = tf.identity
        self.p_nl = tf.identity

        # Set integration operations
        self.ii, self.oi = self.interpret_integration(self.integration_type)
    def prepare_tensors(self):
        """ Prepare recurrent/forward weight matrices.

        Builds ``self.weight_dict`` -- a mapping from connection names to
        the attribute names used for their weight/bias/tuning tensors --
        then creates each tensor with ``tf.get_variable`` and attaches it
        to ``self`` via ``setattr``. Q/U (and, unless
        ``self.association_field``, P) kernels are fixed box filters
        (``trainable=False``); connections named in ``self.lesions`` are
        zeroed and frozen. Bare python-2 ``print`` statements were changed
        to the parenthesized single-argument form, valid in both py2/py3.
        """
        self.weight_dict = {  # Weights lower/activity upper
            'U': {
                'r': {
                    'weight': 'u_r',
                    'activity': 'U_r'
                }
            },
            'P': {
                'r': {
                    'weight': 'p_r',
                    'activity': 'P_r',
                    'tuning': 'p_t'
                }
            },
            'Q': {
                'r': {
                    'weight': 'q_r',
                    'activity': 'Q_r',
                    'tuning': 'q_t'
                }
            },
            'I': {
                'r': {  # Recurrent state
                    'weight': 'i_r',
                    'bias': 'i_b',
                    'activity': 'I_r'
                },
                'f': {  # Recurrent state
                    'weight': 'i_f',
                    'activity': 'I_f'
                },
            },
            'O': {
                'r': {  # Recurrent state
                    'weight': 'o_r',
                    'bias': 'o_b',
                    'activity': 'O_r'
                },
                'f': {  # Recurrent state
                    'weight': 'o_f',
                    'activity': 'O_f'
                },
            },
            'xi': {
                'r': {  # Recurrent state
                    'weight': 'xi',
                }
            },
            'alpha': {
                'r': {  # Recurrent state
                    'weight': 'alpha',
                }
            },
            'beta': {
                'r': {  # Recurrent state
                    'weight': 'beta',
                }
            },
            'mu': {
                'r': {  # Recurrent state
                    'weight': 'mu',
                }
            },
            'nu': {
                'r': {  # Recurrent state
                    'weight': 'nu',
                }
            },
            'zeta': {
                'r': {  # Recurrent state
                    'weight': 'zeta',
                }
            },
            'gamma': {
                'r': {  # Recurrent state
                    'weight': 'gamma',
                }
            },
            'phi': {
                'r': {  # Recurrent state
                    'weight': 'phi',
                }
            },
            'kappa': {
                'r': {  # Recurrent state
                    'weight': 'kappa',
                }
            },
            'delta': {
                'r': {  # Recurrent state
                    'weight': 'delta',
                }
            }
        }

        # tuned summation: pooling in h, w dimensions
        #############################################
        # Uniform (mean-pooling) CRF excitation kernel; frozen.
        q_array = np.ones(self.q_shape) / np.prod(self.q_shape)
        if 'Q' in self.lesions:
            q_array = np.zeros_like(q_array).astype(np.float32)
            print('Lesioning CRF excitation.')
        setattr(
            self, self.weight_dict['Q']['r']['weight'],
            tf.get_variable(name=self.weight_dict['Q']['r']['weight'],
                            dtype=self.dtype,
                            initializer=q_array.astype(np.float32),
                            trainable=False))

        # untuned suppression: reduction across feature axis
        ####################################################
        u_array = np.ones(self.u_shape) / np.prod(self.u_shape)
        if 'U' in self.lesions:
            u_array = np.zeros_like(u_array).astype(np.float32)
            print('Lesioning CRF inhibition.')
        setattr(
            self, self.weight_dict['U']['r']['weight'],
            tf.get_variable(name=self.weight_dict['U']['r']['weight'],
                            dtype=self.dtype,
                            initializer=u_array.astype(np.float32),
                            trainable=False))

        # weakly tuned summation: pooling in h, w dimensions
        #############################################
        if isinstance(self.p_shape[0], list) and 'P' not in self.lesions:
            # VGG-style filters: one learnable kernel per listed scale,
            # registered under weight, weight_1, weight_2, ...
            for pidx, pext in enumerate(self.p_shape):
                if pidx == 0:
                    it_key = self.weight_dict['P']['r']['weight']
                else:
                    self.weight_dict['P']['r']['weight_%s' %
                                               pidx] = 'p_r_%s' % pidx
                    it_key = self.weight_dict['P']['r']['weight_%s' % pidx]
                setattr(
                    self, it_key,
                    tf.get_variable(
                        name=it_key,
                        dtype=self.dtype,
                        initializer=initialization.xavier_initializer(
                            shape=pext, uniform=self.normal_initializer),
                        trainable=True))
        else:
            # Fixed surround kernel with the center region zeroed.
            # NOTE(review): the slice mixes SSN and SSF bounds
            # (SSN // 2 - ... : SSF // 2 + ...); the sibling
            # prepare_tensors uses SSN // 2 on both sides -- confirm
            # this asymmetry is intended.
            p_array = np.ones(self.p_shape)
            p_array[self.SSN // 2 -
                    py_utils.ifloor(self.SRF / 2.0):self.SSF // 2 +
                    py_utils.iceil(self.SSN / 2.0), self.SSN // 2 -
                    py_utils.ifloor(self.SRF / 2.0):self.SSF // 2 +
                    py_utils.iceil(self.SSN / 2.0), :,  # exclude CRF!
                    :] = 0.0
            p_array = p_array / p_array.sum()
            if 'P' in self.lesions:
                print('Lesioning near eCRF.')
                p_array = np.zeros_like(p_array).astype(np.float32)

            # Association field is fully learnable
            if self.association_field and 'P' not in self.lesions:
                setattr(
                    self, self.weight_dict['P']['r']['weight'],
                    tf.get_variable(
                        name=self.weight_dict['P']['r']['weight'],
                        dtype=self.dtype,
                        initializer=initialization.xavier_initializer(
                            shape=self.p_shape,
                            uniform=self.normal_initializer),
                        trainable=True))
            else:
                setattr(
                    self, self.weight_dict['P']['r']['weight'],
                    tf.get_variable(name=self.weight_dict['P']['r']['weight'],
                                    dtype=self.dtype,
                                    initializer=p_array.astype(np.float32),
                                    trainable=False))

        # Connectivity tensors -- Q/P/T
        # Channel-mixing ("tuning") tensor for Q; zeroed+frozen if lesioned.
        if 'Q' in self.lesions:
            print('Lesioning CRF excitation connectivity.')
            setattr(
                self, self.weight_dict['Q']['r']['tuning'],
                tf.get_variable(name=self.weight_dict['Q']['r']['tuning'],
                                dtype=self.dtype,
                                trainable=False,
                                initializer=np.zeros(self.tuning_shape).astype(
                                    np.float32)))
        else:
            setattr(
                self, self.weight_dict['Q']['r']['tuning'],
                tf.get_variable(name=self.weight_dict['Q']['r']['tuning'],
                                dtype=self.dtype,
                                trainable=True,
                                initializer=initialization.xavier_initializer(
                                    shape=self.tuning_shape,
                                    uniform=self.normal_initializer,
                                    mask=None)))
        # Gate weights (input gate: recurrent and feedforward kernels + bias)
        setattr(
            self, self.weight_dict['I']['r']['weight'],
            tf.get_variable(name=self.weight_dict['I']['r']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.i_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(
            self, self.weight_dict['I']['f']['weight'],
            tf.get_variable(name=self.weight_dict['I']['f']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.i_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(
            self, self.weight_dict['I']['r']['bias'],
            tf.get_variable(name=self.weight_dict['I']['r']['bias'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=tf.ones(self.bias_shape)))

        # Output gate: recurrent and feedforward kernels + bias
        setattr(
            self, self.weight_dict['O']['r']['weight'],
            tf.get_variable(name=self.weight_dict['O']['r']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.o_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(
            self, self.weight_dict['O']['f']['weight'],
            tf.get_variable(name=self.weight_dict['O']['f']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.o_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(  # TODO: smart initialization of these
            self, self.weight_dict['O']['r']['bias'],
            tf.get_variable(name=self.weight_dict['O']['r']['bias'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=tf.ones(self.bias_shape)))

        # Degree of freedom weights (vectors): per-channel [1, 1, 1, k]
        # gains initialized to ones and biases to zeros.
        w_array = np.ones([1, 1, 1, self.k]).astype(np.float32)
        b_array = np.zeros([1, 1, 1, self.k]).astype(np.float32)

        # Divisive params
        self.alpha = tf.get_variable(name='alpha', initializer=w_array)
        self.beta = tf.get_variable(name='beta', initializer=w_array)

        # Subtractive params
        self.mu = tf.get_variable(name='mu', initializer=b_array)
        self.nu = tf.get_variable(name='nu', initializer=b_array)
        # The following flags (set via update_params) choose between a
        # learnable per-channel gain and a fixed scalar constant 1.
        if self.zeta:
            self.zeta = tf.get_variable(name='zeta', initializer=w_array)
        else:
            self.zeta = tf.constant(1.)
        if self.gamma:
            self.gamma = tf.get_variable(name='gamma', initializer=w_array)
        else:
            self.gamma = tf.constant(1.)
        if self.delta:
            self.delta = tf.get_variable(name='delta', initializer=w_array)
        else:
            self.delta = tf.constant(1.)
        if self.xi:
            self.xi = tf.get_variable(name='xi', initializer=w_array)
        else:
            self.xi = tf.constant(1.)
        if self.multiplicative_excitation:
            # NOTE(review): 'omega' has no weight_dict entry while 'phi'
            # and 'kappa' are listed -- confirm the bookkeeping.
            self.kappa = tf.get_variable(name='kappa', initializer=w_array)
            self.omega = tf.get_variable(name='omega', initializer=w_array)
        else:
            self.kappa = tf.constant(1.)
            self.omega = tf.constant(1.)
    def prepare_tensors(self):
        """ Prepare recurrent/forward weight matrices.

        Builds ``self.weight_dict`` (connection name -> attribute names for
        weight/tuning/activity tensors), then creates every tensor with
        ``tf.get_variable`` and attaches it to ``self``. Q/U/P/T pooling
        kernels are fixed box filters (``trainable=False``); tuning,
        gate, and per-channel vector parameters are trainable.
        """
        self.weight_dict = {  # Weights lower/activity upper
            'U': {
                'r': {
                    'weight': 'u_r',
                    'activity': 'U_r'
                }
            },
            'T': {
                'r': {
                    'weight': 't_r',
                    'activity': 'T_r',
                    'tuning': 't_t'
                }
            },
            'P': {
                'r': {
                    'weight': 'p_r',
                    'activity': 'P_r',
                    'tuning': 'p_t'
                }
            },
            'Q': {
                'r': {
                    'weight': 'q_r',
                    'activity': 'Q_r',
                    'tuning': 'q_t'
                }
            },
            'I': {
                'r': {  # Recurrent state
                    'weight': 'i_r',
                    'activity': 'I_r'
                },
                'f': {  # Recurrent state
                    'weight': 'i_f',
                    'activity': 'I_f'
                },
            },
            'O': {
                'r': {  # Recurrent state
                    'weight': 'o_r',
                    'activity': 'O_r'
                },
                'f': {  # Recurrent state
                    'weight': 'o_f',
                    'activity': 'O_f'
                },
            },
            'xi': {
                'r': {  # Recurrent state
                    'weight': 'xi',
                }
            },
            'alpha': {
                'r': {  # Recurrent state
                    'weight': 'alpha',
                }
            },
            'beta': {
                'r': {  # Recurrent state
                    'weight': 'beta',
                }
            },
            'mu': {
                'r': {  # Recurrent state
                    'weight': 'mu',
                }
            },
            'nu': {
                'r': {  # Recurrent state
                    'weight': 'nu',
                }
            },
            'zeta': {
                'r': {  # Recurrent state
                    'weight': 'zeta',
                }
            },
            'gamma': {
                'r': {  # Recurrent state
                    'weight': 'gamma',
                }
            },
            'delta': {
                'r': {  # Recurrent state
                    'weight': 'delta',
                }
            }
        }

        def _frozen(key, arr):
            # Fixed (non-trainable) kernel built from a numpy array.
            setattr(
                self, key,
                tf.get_variable(name=key,
                                dtype=self.dtype,
                                initializer=arr.astype(np.float32),
                                trainable=False))

        def _learned(key, shape):
            # Trainable xavier-initialized tensor.
            setattr(
                self, key,
                tf.get_variable(name=key,
                                dtype=self.dtype,
                                initializer=initialization.xavier_initializer(
                                    shape=shape,
                                    uniform=self.normal_initializer,
                                    mask=None)))

        # tuned summation: uniform box filter pooling the CRF in h, w
        q_array = np.ones(self.q_shape) / np.prod(self.q_shape)
        _frozen(self.weight_dict['Q']['r']['weight'], q_array)

        # untuned suppression: uniform reduction across the feature axis
        u_array = np.ones(self.u_shape) / np.prod(self.u_shape)
        _frozen(self.weight_dict['U']['r']['weight'], u_array)

        # weakly tuned summation: near-surround pooling with the CRF
        # region zeroed out, then normalized to sum to one
        p_array = np.ones(self.p_shape)
        crf_lo = self.SSN // 2 - py_utils.ifloor(self.SRF / 2.0)
        crf_hi = self.SSN // 2 + py_utils.iceil(self.SRF / 2.0)
        p_array[crf_lo:crf_hi, crf_lo:crf_hi, :, :] = 0.0  # exclude CRF!
        p_array = p_array / p_array.sum()
        _frozen(self.weight_dict['P']['r']['weight'], p_array)

        # weakly tuned suppression: far-surround pooling with the near
        # surround zeroed out, then normalized to sum to one
        t_array = np.ones(self.t_shape)
        near_lo = self.SSF // 2 - py_utils.ifloor(self.SSN / 2.0)
        near_hi = self.SSF // 2 + py_utils.iceil(self.SSN / 2.0)
        # exclude near surround!
        t_array[near_lo:near_hi, near_lo:near_hi, :, :] = 0.0
        t_array = t_array / t_array.sum()
        _frozen(self.weight_dict['T']['r']['weight'], t_array)

        # Connectivity tensors -- Q/P/T channel-mixing (tuning) weights
        for conn in ('Q', 'P', 'T'):
            _learned(self.weight_dict[conn]['r']['tuning'],
                     self.tuning_shape)

        # Input gates (recurrent and feedforward kernels)
        _learned(self.weight_dict['I']['r']['weight'], self.i_shape)
        _learned(self.weight_dict['I']['f']['weight'], self.i_shape)

        # Output gates (recurrent and feedforward kernels)
        _learned(self.weight_dict['O']['r']['weight'], self.o_shape)
        _learned(self.weight_dict['O']['f']['weight'], self.o_shape)

        # Per-channel vector weights: gains start at one, biases at zero.
        w_array = np.ones([1, 1, 1, self.k]).astype(np.float32)
        b_array = np.zeros([1, 1, 1, self.k]).astype(np.float32)
        for pname, init in (('xi', w_array), ('alpha', w_array),
                            ('beta', w_array), ('mu', b_array),
                            ('nu', b_array), ('zeta', w_array),
                            ('gamma', w_array), ('delta', w_array)):
            setattr(self, pname,
                    tf.get_variable(name=pname, initializer=init))
Beispiel #7
0
    def prepare_tensors(self):
        """ Prepare recurrent/forward weight matrices.

        Builds ``self.weight_dict`` (connection name -> attribute names for
        weight/bias/activity tensors), then creates every tensor with
        ``tf.get_variable`` and attaches it to ``self``. Q/U kernels are
        learnable; P/T are fixed channel-diagonal surround masks
        (``trainable=False``); the trailing scalar parameters are all
        trainable and initialized to one.
        """
        self.weight_dict = {  # Weights lower/activity upper
            'U': {
                'r': {
                    'weight': 'u_r',
                    'activity': 'U_r'
                },
                'f': {
                    'weight': 'u_f',
                    'bias': 'ub_f',
                    'activity': 'U_f'
                }
            },
            'T': {
                'r': {
                    'weight': 't_r',
                    'activity': 'T_r'
                },
                'f': {
                    'weight': 't_f',
                    'bias': 'tb_f',
                    'activity': 'T_f'
                }
            },
            'P': {
                'r': {
                    'weight': 'p_r',
                    'activity': 'P_r'
                },
                'f': {
                    'weight': 'p_f',
                    'bias': 'pb_f',
                    'activity': 'P_f'
                }
            },
            'Q': {
                'r': {
                    'weight': 'q_r',
                    'activity': 'Q_r'
                },
                'f': {
                    'weight': 'q_f',
                    'bias': 'qb_f',
                    'activity': 'Q_f'
                }
            },
            'I': {
                'r': {  # Recurrent state
                    'weight': 'i_r',
                    'activity': 'I_r'
                }
            },
            'O': {
                'r': {  # Recurrent state
                    'weight': 'o_r',
                    'activity': 'O_r'
                }
            }
        }

        def _learned(key, shape):
            # Trainable xavier-initialized kernel.
            setattr(
                self, key,
                tf.get_variable(name=key,
                                dtype=self.dtype,
                                initializer=initialization.xavier_initializer(
                                    shape=shape,
                                    uniform=self.normal_initializer,
                                    mask=None)))

        def _frozen(key, arr):
            # Fixed (non-trainable) kernel from a numpy array.
            setattr(
                self, key,
                tf.get_variable(name=key,
                                dtype=self.dtype,
                                initializer=arr.astype(np.float32),
                                trainable=False))

        # tuned summation: learnable pooling in h, w dimensions
        _learned(self.weight_dict['Q']['r']['weight'], self.q_shape)

        # untuned suppression: learnable reduction across the feature axis
        _learned(self.weight_dict['U']['r']['weight'], self.u_shape)

        # tuned summation: channel-diagonal box over the near surround
        # with the CRF region zeroed out
        p_array = np.zeros(self.p_shape)
        for channel in range(self.k):
            p_array[:self.SSN, :self.SSN, channel, channel] = 1.0
        crf_lo = self.SSN // 2 - py_utils.ifloor(self.SRF / 2.0)
        crf_hi = self.SSN // 2 + py_utils.iceil(self.SRF / 2.0)
        p_array[crf_lo:crf_hi, crf_lo:crf_hi, :, :] = 0.0  # exclude CRF!
        _frozen(self.weight_dict['P']['r']['weight'], p_array)

        # tuned suppression: channel-diagonal box over the far surround
        # with the near surround zeroed out
        t_array = np.zeros(self.t_shape)
        for channel in range(self.k):
            t_array[:self.SSF, :self.SSF, channel, channel] = 1.0
        near_lo = self.SSF // 2 - py_utils.ifloor(self.SSN / 2.0)
        near_hi = self.SSF // 2 + py_utils.iceil(self.SSN / 2.0)
        # exclude near surround!
        t_array[near_lo:near_hi, near_lo:near_hi, :, :] = 0.0
        _frozen(self.weight_dict['T']['r']['weight'], t_array)

        # Scalar weights, each a trainable scalar initialized to 1.
        for pname in ('xi', 'alpha', 'beta', 'mu', 'nu', 'zeta', 'gamma',
                      'delta', 'eps', 'eta', 'sig', 'tau'):
            setattr(self, pname,
                    tf.get_variable(name=pname, initializer=1.))
Beispiel #8
0
def contextual_div_norm_2d(x,
                           CRF_sum_window,
                           CRF_sup_window,
                           eCRF_sum_window,
                           eCRF_sup_window,
                           strides,
                           padding,
                           gamma=None,
                           beta=None,
                           eps=1.0,
                           scope="dn",
                           name="dn_out",
                           return_mean=False):
    """Applies CRF + eCRF divisive normalization on CNN feature maps.

    Collects means and variances of x over a classical receptive field
    (CRF) window and a surrounding extra-classical (eCRF) annulus, then
    applies normalization as below:
      x_ = gamma * (x - mean) / sqrt(var + eps) + beta
    Adapted from https://github.com/renmengye/div-norm/blob/master/div_norm.py

    Args:
      x: Input tensor, [B, H, W, C].
      CRF_sum_window: CRF summation window size, int or [H, W] list.
      CRF_sup_window: CRF suppression window size, int or [H, W] list.
      eCRF_sum_window: eCRF summation window size, int or [H, W] list.
      eCRF_sup_window: eCRF suppression window size, int or [H, W] list.
      strides: Strides for the pooling convolutions.
      padding: Padding for the pooling convolutions ('SAME'/'VALID').
      gamma: Optional scaling parameter.
      beta: Optional bias parameter.
      eps: Denominator bias.
      scope: Variable-scope name for the kernels.
      name: Name given to the output tensor.
      return_mean: Whether to also return the second value below.
        NOTE(review): despite the name, the extra return value is x2
        (the squared centered activations), not the mean -- confirm
        against callers before renaming or changing this.

    Returns:
      normed: Divisive-normalized tensor.
      x2: Squared centered activations (only if return_mean).
    """
    # Broadcast scalar window sizes to [H, W] pairs.
    if not isinstance(CRF_sum_window, list):
        CRF_sum_window = list(np.repeat(CRF_sum_window, 2))
    if not isinstance(CRF_sup_window, list):
        CRF_sup_window = list(np.repeat(CRF_sup_window, 2))
    if not isinstance(eCRF_sum_window, list):
        eCRF_sum_window = list(np.repeat(eCRF_sum_window, 2))
    if not isinstance(eCRF_sup_window, list):
        eCRF_sup_window = list(np.repeat(eCRF_sup_window, 2))
    k = int(x.get_shape()[-1])
    with tf.variable_scope(scope):

        # Q: uniform CRF summation kernel (mixes across channels).
        q_array = np.ones((CRF_sum_window + [k, k]))
        q_array /= q_array.sum()
        w_sum = tf.cast(tf.constant(q_array), tf.float32)
        # U: uniform CRF suppression kernel (reduces channels to 1).
        u_array = np.ones((CRF_sum_window + [k, 1]))
        u_array /= u_array.sum()
        w_sup = tf.cast(tf.constant(u_array), tf.float32)
        CRF_sum_window = CRF_sum_window[0]
        CRF_sup_window = CRF_sup_window[0]
        # P: per-channel eCRF summation kernel with the CRF punched out.
        p_shape = eCRF_sum_window + [k, k]
        eCRF_sum_window = eCRF_sum_window[0]
        p_array = np.zeros(p_shape)
        for pdx in range(k):
            p_array[:eCRF_sum_window, :eCRF_sum_window, pdx, pdx] = 1.0
        # BUGFIX: the second spatial slice previously started at
        # CRF_sum_window // 2 instead of eCRF_sum_window // 2, which
        # mis-centered the excluded CRF rectangle (cf. first dim here
        # and the T block below).
        p_array[eCRF_sum_window // 2 -
                py_utils.ifloor(CRF_sum_window / 2.0):eCRF_sum_window // 2 +
                py_utils.iceil(CRF_sum_window / 2.0), eCRF_sum_window // 2 -
                py_utils.ifloor(CRF_sum_window / 2.0):eCRF_sum_window // 2 +
                py_utils.iceil(CRF_sum_window / 2.0), :,  # exclude CRF!
                :] = 0.0
        w_esum = tf.cast(tf.constant(p_array) / p_array.sum(), tf.float32)

        # T: per-channel eCRF suppression kernel, near surround punched out.
        t_shape = eCRF_sup_window + [k, k]
        eCRF_sup_window = eCRF_sup_window[0]
        t_array = np.zeros(t_shape)
        for tdx in range(k):
            t_array[:eCRF_sup_window, :eCRF_sup_window, tdx, tdx] = 1.0
        t_array[eCRF_sup_window // 2 -
                py_utils.ifloor(CRF_sup_window / 2.0):eCRF_sup_window // 2 +
                py_utils.iceil(CRF_sup_window / 2.0), eCRF_sup_window // 2 -
                py_utils.ifloor(CRF_sup_window / 2.0):eCRF_sup_window // 2 +
                py_utils.iceil(CRF_sup_window /
                               2.0), :,  # exclude near surround!
                :] = 0.0

        w_esup = tf.cast(tf.constant(t_array) / t_array.sum(), tf.float32)

        # SUM: subtract local CRF + eCRF means.
        x_mean_CRF = tf.nn.conv2d(x, w_sum, strides=strides, padding=padding)
        x_mean_eCRF = tf.nn.conv2d(x, w_esum, strides=strides, padding=padding)
        normed = x - x_mean_CRF - x_mean_eCRF
        x2 = tf.square(normed)

        # SUP: divide by the local CRF + eCRF standard deviation.
        x2_mean_CRF = tf.nn.conv2d(x2, w_sup, strides=strides, padding=padding)
        x2_mean_eCRF = tf.nn.conv2d(x2,
                                    w_esup,
                                    strides=strides,
                                    padding=padding)
        denom = tf.sqrt(x2_mean_CRF + x2_mean_eCRF + eps)
        normed = normed / denom
        if gamma is not None:
            normed *= gamma
        if beta is not None:
            normed += beta
    normed = tf.identity(normed, name=name)
    if return_mean:
        return normed, x2
    else:
        return normed
Beispiel #9
0
    def prepare_tensors(self):
        """ Prepare recurrent/forward weight matrices."""
        self.weight_dict = {  # Weights lower/activity upper
            'U': {
                'r': {
                    'weight': 'u_r',
                    'activity': 'U_r'
                }
            },
            'T': {
                'r': {
                    'weight': 't_r',
                    'activity': 'T_r',
                    'tuning': 't_t'
                }
            },
            'P': {
                'r': {
                    'weight': 'p_r',
                    'activity': 'P_r',
                    'tuning': 'p_t'
                }
            },
            'Q': {
                'r': {
                    'weight': 'q_r',
                    'activity': 'Q_r',
                    'tuning': 'q_t'
                }
            },
            'I': {
                'r': {  # Recurrent state
                    'weight': 'i_r',
                    'bias': 'i_b',
                    'activity': 'I_r'
                },
                'f': {  # Recurrent state
                    'weight': 'i_f',
                    'activity': 'I_f'
                },
            },
            'O': {
                'r': {  # Recurrent state
                    'weight': 'o_r',
                    'bias': 'o_b',
                    'activity': 'O_r'
                },
                'f': {  # Recurrent state
                    'weight': 'o_f',
                    'activity': 'O_f'
                },
            },
            'xi': {
                'r': {  # Recurrent state
                    'weight': 'xi',
                }
            },
            'alpha': {
                'r': {  # Recurrent state
                    'weight': 'alpha',
                }
            },
            'beta': {
                'r': {  # Recurrent state
                    'weight': 'beta',
                }
            },
            'mu': {
                'r': {  # Recurrent state
                    'weight': 'mu',
                }
            },
            'nu': {
                'r': {  # Recurrent state
                    'weight': 'nu',
                }
            },
            'zeta': {
                'r': {  # Recurrent state
                    'weight': 'zeta',
                }
            },
            'gamma': {
                'r': {  # Recurrent state
                    'weight': 'gamma',
                }
            },
            'delta': {
                'r': {  # Recurrent state
                    'weight': 'delta',
                }
            }
        }

        # tuned summation: pooling in h, w dimensions
        #############################################
        q_array = np.ones(self.q_shape) / np.prod(self.q_shape)
        if 'Q' in self.lesions:
            q_array = np.zeros_like(q_array).astype(np.float32)
            print 'Lesioning CRF excitation.'
        setattr(
            self, self.weight_dict['Q']['r']['weight'],
            tf.get_variable(name=self.weight_dict['Q']['r']['weight'],
                            dtype=self.dtype,
                            initializer=q_array.astype(np.float32),
                            trainable=False))

        # untuned suppression: reduction across feature axis
        ####################################################
        u_array = np.ones(self.u_shape) / np.prod(self.u_shape)
        if 'U' in self.lesions:
            u_array = np.zeros_like(u_array).astype(np.float32)
            print 'Lesioning CRF inhibition.'
        setattr(
            self, self.weight_dict['U']['r']['weight'],
            tf.get_variable(name=self.weight_dict['U']['r']['weight'],
                            dtype=self.dtype,
                            initializer=u_array.astype(np.float32),
                            trainable=False))

        # weakly tuned summation: pooling in h, w dimensions
        #############################################
        p_array = np.ones(self.p_shape)
        #Try removing CRF punching
        if self.exclude_CRF:
            # Punch out the CRF
            p_array[self.SSN // 2 -
                    py_utils.ifloor(self.SRF / 2.0):self.SSF // 2 +
                    py_utils.iceil(self.SSN / 2.0), self.SSN // 2 -
                    py_utils.ifloor(self.SRF / 2.0):self.SSF // 2 +
                    py_utils.iceil(self.SSN / 2.0), :,  # exclude CRF!
                    :] = 0.0

        p_array = p_array / p_array.sum()

        if 'P' in self.lesions:
            print 'Lesioning near eCRF.'
            p_array = np.zeros_like(p_array).astype(np.float32)

        # Association field is fully learnable
        if self.association_field and 'P' not in self.lesions:
            setattr(
                self, self.weight_dict['P']['r']['weight'],
                tf.get_variable(name=self.weight_dict['P']['r']['weight'],
                                dtype=self.dtype,
                                initializer=initialization.xavier_initializer(
                                    shape=self.p_shape,
                                    uniform=self.normal_initializer),
                                trainable=True))
        else:
            setattr(
                self, self.weight_dict['P']['r']['weight'],
                tf.get_variable(name=self.weight_dict['P']['r']['weight'],
                                dtype=self.dtype,
                                initializer=p_array.astype(np.float32),
                                trainable=False))

        # weakly tuned suppression: pooling in h, w dimensions
        ###############################################
        t_array = np.ones(self.t_shape)
        #Try without punching CRF
        if self.exclude_CRF:
            # Punch out the CRF
            t_array[self.SSF // 2 -
                    py_utils.ifloor(self.SSN / 2.0):self.SSF // 2 +
                    py_utils.iceil(self.SSN / 2.0), self.SSF // 2 -
                    py_utils.ifloor(self.SSN / 2.0):self.SSF // 2 +
                    py_utils.iceil(self.SSN /
                                   2.0), :,  # exclude near surround!
                    :] = 0.0

        t_array = t_array / t_array.sum()
        if 'T' in self.lesions:
            print 'Lesioning Far eCRF.'
            t_array = np.zeros_like(t_array).astype(np.float32)

        #Always set full_far_eCRF to True/initialize with Xavier
        if self.full_far_eCRF:
            setattr(
                self, self.weight_dict['T']['r']['weight'],
                tf.get_variable(name=self.weight_dict['T']['r']['weight'],
                                dtype=self.dtype,
                                initializer=initialization.xavier_initializer(
                                    shape=self.p_shape,
                                    uniform=self.normal_initializer),
                                trainable=True))
        else:
            setattr(
                self, self.weight_dict['T']['r']['weight'],
                tf.get_variable(name=self.weight_dict['T']['r']['weight'],
                                dtype=self.dtype,
                                initializer=t_array.astype(np.float32),
                                trainable=False))

        # Connectivity tensors -- Q/P/T
        if 'Q' in self.lesions:
            print 'Lesioning CRF excitation connectivity.'
            setattr(
                self, self.weight_dict['Q']['r']['tuning'],
                tf.get_variable(name=self.weight_dict['Q']['r']['tuning'],
                                dtype=self.dtype,
                                trainable=False,
                                initializer=np.zeros(self.tuning_shape).astype(
                                    np.float32)))
        else:
            setattr(
                self, self.weight_dict['Q']['r']['tuning'],
                tf.get_variable(name=self.weight_dict['Q']['r']['tuning'],
                                dtype=self.dtype,
                                trainable=True,
                                initializer=initialization.xavier_initializer(
                                    shape=self.tuning_shape,
                                    uniform=self.normal_initializer,
                                    mask=None)))
        if not self.association_field:
            # Need a tuning tensor for near surround
            if 'P' in self.lesions:
                print 'Lesioning near eCRF connectivity.'
                setattr(
                    self, self.weight_dict['P']['r']['tuning'],
                    tf.get_variable(name=self.weight_dict['P']['r']['tuning'],
                                    dtype=self.dtype,
                                    trainable=False,
                                    initializer=np.zeros(
                                        self.tuning_shape).astype(np.float32)))
            else:
                setattr(
                    self, self.weight_dict['P']['r']['tuning'],
                    tf.get_variable(
                        name=self.weight_dict['P']['r']['tuning'],
                        dtype=self.dtype,
                        trainable=True,
                        initializer=initialization.xavier_initializer(
                            shape=self.tuning_shape,
                            uniform=self.normal_initializer,
                            mask=None)))

        #Again, full_far_eCRF should be set to True for now
        import ipdb
        ipdb.set_trace()
        if not self.full_far_eCRF:
            # Need a tuning tensor for near surround
            if 'T' in self.lesions:
                print 'Lesioning far eCRF connectivity.'
                setattr(
                    self, self.weight_dict['T']['r']['tuning'],
                    tf.get_variable(name=self.weight_dict['T']['r']['tuning'],
                                    dtype=self.dtype,
                                    trainable=False,
                                    initializer=np.zeros(
                                        self.tuning_shape).astype(np.float32)))
        else:
            setattr(
                self, self.weight_dict['T']['r']['tuning'],
                tf.get_variable(name=self.weight_dict['T']['r']['tuning'],
                                dtype=self.dtype,
                                trainable=True,
                                initializer=initialization.xavier_initializer(
                                    shape=self.tuning_shape,
                                    uniform=self.normal_initializer,
                                    mask=None)))

        # Input
        setattr(
            self, self.weight_dict['I']['r']['weight'],
            tf.get_variable(name=self.weight_dict['I']['r']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.i_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(
            self, self.weight_dict['I']['f']['weight'],
            tf.get_variable(name=self.weight_dict['I']['f']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.i_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(
            self, self.weight_dict['I']['r']['bias'],
            tf.get_variable(name=self.weight_dict['I']['r']['bias'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=tf.ones(self.bias_shape)))

        # Output
        setattr(
            self, self.weight_dict['O']['r']['weight'],
            tf.get_variable(name=self.weight_dict['O']['r']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.o_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(
            self, self.weight_dict['O']['f']['weight'],
            tf.get_variable(name=self.weight_dict['O']['f']['weight'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=initialization.xavier_initializer(
                                shape=self.o_shape,
                                uniform=self.normal_initializer,
                                mask=None)))
        setattr(
            self, self.weight_dict['O']['r']['bias'],
            tf.get_variable(name=self.weight_dict['O']['r']['bias'],
                            dtype=self.dtype,
                            trainable=True,
                            initializer=tf.ones(self.bias_shape)))

        # Vector weights
        w_array = np.ones([1, 1, 1, self.k]).astype(np.float32)
        b_array = np.zeros([1, 1, 1, self.k]).astype(np.float32)
        self.xi = tf.get_variable(name='xi', initializer=w_array)
        self.alpha = tf.get_variable(name='alpha', initializer=w_array)
        self.beta = tf.get_variable(name='beta', initializer=w_array)
        self.mu = tf.get_variable(name='mu', initializer=b_array)
        self.nu = tf.get_variable(name='nu', initializer=b_array)
        self.zeta = tf.get_variable(name='zeta', initializer=w_array)
        self.gamma = tf.get_variable(name='gamma', initializer=w_array)
        self.delta = tf.get_variable(name='delta', initializer=w_array)
Beispiel #10
0
    def prepare_tensors(self):
        """ Prepare recurrent/forward weight matrices.

        Registers variables on ``self`` (via ``setattr``) under the names
        in ``self.weight_dict``: recurrent ('r') kernels for Q/U/P/T,
        input-facing ('f') kernels + biases unless
        ``model_version == 'no_input_facing'``, I/O cell-state kernels
        when ``model_version == 'full_with_cell_states'``, and six scalar
        coefficients (alpha, tau, eta, omega, eps, gamma).
        """
        self.weight_dict = {  # Weights lower/activity upper
            'U': {
                'r': {
                    'weight': 'u_r',
                    'activity': 'U_r'
                    },
                'f': {
                    'weight': 'u_f',
                    'bias': 'ub_f',
                    'activity': 'U_f'
                    }
                },
            'T': {
                'r': {
                    'weight': 't_r',
                    'activity': 'T_r'
                    },
                'f': {
                    'weight': 't_f',
                    'bias': 'tb_f',
                    'activity': 'T_f'
                    }
                },
            'P': {
                'r': {
                    'weight': 'p_r',
                    'activity': 'P_r'
                    },
                'f': {
                    'weight': 'p_f',
                    'bias': 'pb_f',
                    'activity': 'P_f'
                    }
                },
            'Q': {
                'r': {
                    'weight': 'q_r',
                    'activity': 'Q_r'
                    },
                'f': {
                    'weight': 'q_f',
                    'bias': 'qb_f',
                    'activity': 'Q_f'
                    }
                },
            'I': {
                'r': {  # Recurrent state
                    'weight': 'i_r',
                    'bias': 'ib_r',
                    'activity': 'I_r'
                }
            },
            'O': {
                'r': {  # Recurrent state
                    'weight': 'o_r',
                    'bias': 'ob_r',
                    'activity': 'O_r'
                }
            }
        }

        # tuned summation: pooling in h, w dimensions
        #############################################
        setattr(
            self,
            self.weight_dict['Q']['r']['weight'],
            tf.get_variable(
                name=self.weight_dict['Q']['r']['weight'],
                dtype=self.dtype,
                initializer=initialization.xavier_initializer(
                    shape=self.q_shape,
                    uniform=self.normal_initializer,
                    mask=None)))

        # untuned suppression: reduction across feature axis
        ####################################################
        setattr(
            self,
            self.weight_dict['U']['r']['weight'],
            tf.get_variable(
                name=self.weight_dict['U']['r']['weight'],
                dtype=self.dtype,
                initializer=initialization.xavier_initializer(
                    shape=self.u_shape,
                    uniform=self.normal_initializer,
                    mask=None)))

        # tuned summation: pooling in h, w dimensions
        #############################################
        # Binary mask: per-channel SSN x SSN block with the central
        # SRF-sized region zeroed out (near-surround annulus).
        p_array = np.zeros(self.p_shape)
        for pdx in range(self.k):
            p_array[:self.SSN, :self.SSN, pdx, pdx] = 1.0
        p_array[
            self.SSN // 2 - py_utils.ifloor(
                self.SRF / 2.0):self.SSN // 2 + py_utils.iceil(
                self.SRF / 2.0),
            self.SSN // 2 - py_utils.ifloor(
                self.SRF / 2.0):self.SSN // 2 + py_utils.iceil(
                self.SRF / 2.0),
            :,  # exclude CRF!
            :] = 0.0

        # NOTE(review): mask=p_array presumably restricts the learnable
        # weights to the annulus -- confirm against
        # initialization.xavier_initializer's mask semantics.
        setattr(
            self,
            self.weight_dict['P']['r']['weight'],
            tf.get_variable(
                name=self.weight_dict['P']['r']['weight'],
                dtype=self.dtype,
                initializer=initialization.xavier_initializer(
                    shape=self.p_shape,
                    uniform=self.normal_initializer,
                    mask=p_array)))

        # tuned suppression: pooling in h, w dimensions
        ###############################################
        # Same construction as p_array, one scale up: SSF x SSF block with
        # the central SSN-sized (near surround) region zeroed out.
        t_array = np.zeros(self.t_shape)
        for tdx in range(self.k):
            t_array[:self.SSF, :self.SSF, tdx, tdx] = 1.0
        t_array[
            self.SSF // 2 - py_utils.ifloor(
                self.SSN / 2.0):self.SSF // 2 + py_utils.iceil(
                self.SSN / 2.0),
            self.SSF // 2 - py_utils.ifloor(
                self.SSN / 2.0):self.SSF // 2 + py_utils.iceil(
                self.SSN / 2.0),
            :,  # exclude near surround!
            :] = 0.0
        setattr(
            self,
            self.weight_dict['T']['r']['weight'],
            tf.get_variable(
                name=self.weight_dict['T']['r']['weight'],
                dtype=self.dtype,
                initializer=initialization.xavier_initializer(
                    shape=self.t_shape,
                    uniform=self.normal_initializer,
                    mask=t_array)))

        if self.model_version != 'no_input_facing':
            # Also create input-facing weight matrices
            # Q
            setattr(
                self,
                self.weight_dict['Q']['f']['weight'],
                tf.get_variable(
                    name=self.weight_dict['Q']['f']['weight'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=self.q_shape,
                        uniform=self.normal_initializer)))
            setattr(
                self,
                self.weight_dict['Q']['f']['bias'],
                tf.get_variable(
                    name=self.weight_dict['Q']['f']['bias'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=[self.q_shape[-1]],
                        uniform=self.normal_initializer)))

            # U
            setattr(
                self,
                self.weight_dict['U']['f']['weight'],
                tf.get_variable(
                    name=self.weight_dict['U']['f']['weight'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=self.u_shape,
                        uniform=self.normal_initializer)))
            # NOTE: the first positional argument below is the shape,
            # passed positionally here unlike the keyword style elsewhere.
            setattr(
                self,
                self.weight_dict['U']['f']['bias'],
                tf.get_variable(
                    name=self.weight_dict['U']['f']['bias'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        [self.u_shape[-1]],
                        uniform=self.normal_initializer)))

            # P (shares the annulus mask with the recurrent P kernel)
            setattr(
                self,
                self.weight_dict['P']['f']['weight'],
                tf.get_variable(
                    name=self.weight_dict['P']['f']['weight'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        self.p_shape,
                        uniform=self.normal_initializer,
                        mask=p_array)))
            setattr(
                self,
                self.weight_dict['P']['f']['bias'],
                tf.get_variable(
                    name=self.weight_dict['P']['f']['bias'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        [self.p_shape[-1]],
                        uniform=self.normal_initializer,
                        mask=None)))

            # T (shares the far-annulus mask with the recurrent T kernel)
            setattr(
                self,
                self.weight_dict['T']['f']['weight'],
                tf.get_variable(
                    name=self.weight_dict['T']['f']['weight'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=self.t_shape,
                        uniform=self.normal_initializer,
                        mask=t_array)))
            setattr(
                self,
                self.weight_dict['T']['f']['bias'],
                tf.get_variable(
                    name=self.weight_dict['T']['f']['bias'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=[self.t_shape[-1]],
                        uniform=self.normal_initializer,
                        mask=None)))

        if self.model_version == 'full_with_cell_states':
            # Input
            setattr(
                self,
                self.weight_dict['I']['r']['weight'],
                tf.get_variable(
                    name=self.weight_dict['I']['r']['weight'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=self.i_shape,
                        uniform=self.normal_initializer,
                        mask=None)))
            setattr(
                self,
                self.weight_dict['I']['r']['bias'],
                tf.get_variable(
                    name=self.weight_dict['I']['r']['bias'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=[self.k],
                        uniform=self.normal_initializer,
                        mask=None)))

            # Output
            setattr(
                self,
                self.weight_dict['O']['r']['weight'],
                tf.get_variable(
                    name=self.weight_dict['O']['r']['weight'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=self.o_shape,
                        uniform=self.normal_initializer,
                        mask=None)))
            setattr(
                self,
                self.weight_dict['O']['r']['bias'],
                tf.get_variable(
                    name=self.weight_dict['O']['r']['bias'],
                    dtype=self.dtype,
                    initializer=initialization.xavier_initializer(
                        shape=[self.k],
                        uniform=self.normal_initializer,
                        mask=None)))

        # Scalar weights (all initialized to 1.0)
        self.alpha = tf.get_variable(name='alpha', initializer=1.)
        self.tau = tf.get_variable(name='tau', initializer=1.)
        self.eta = tf.get_variable(name='eta', initializer=1.)
        self.omega = tf.get_variable(name='omega', initializer=1.)
        self.eps = tf.get_variable(name='eps', initializer=1.)
        self.gamma = tf.get_variable(name='gamma', initializer=1.)
Beispiel #11
0
    def prepare_tensors(self):
        """ Prepare recurrent/forward weight matrices.
        9 * k + (2 * k^2) params in the greek letters/gates.
        (np.prod([h, w, k]) / 2) - k params in the surround filter
        """
        self.weight_dict = {  # Weights lower/activity upper
            'P': {
                'r': {
                    'weight': 'p_r',
                    'activity': 'P_r',
                    'tuning': 'p_t',
                    # 'bias': 'i_b'
                }
            },
            'I': {
                'r': {  # Recurrent state
                    'weight': 'i_r',
                    'bias': 'i_b',
                    'activity': 'I_r'
                },
                # 'f': {  # Recurrent state
                #     'weight': 'i_f',
                #     'activity': 'I_f'
                # },
            },
            'O': {
                'r': {  # Recurrent state
                    'weight': 'o_r',
                    'bias': 'o_b',
                    'activity': 'O_r'
                },
                # 'f': {  # Recurrent state
                #     'weight': 'o_f',
                #     'activity': 'O_f'
                # },
            },
            'xi': {
                'r': {  # Recurrent state
                    'weight': 'xi',
                }
            },
            # 'alpha': {
            #     'r': {  # Recurrent state
            #         'weight': 'alpha',
            #     }
            # },
            'beta': {
                'r': {  # Recurrent state
                    'weight': 'beta',
                }
            },
            # 'mu': {
            #     'r': {  # Recurrent state
            #         'weight': 'mu',
            #     }
            # },
            'nu': {
                'r': {  # Recurrent state
                    'weight': 'nu',
                }
            },
            'zeta': {
                'r': {  # Recurrent state
                    'weight': 'zeta',
                }
            },
            'gamma': {
                'r': {  # Recurrent state
                    'weight': 'gamma',
                }
            },
            'phi': {
                'r': {  # Recurrent state
                    'weight': 'phi',
                }
            },
            'kappa': {
                'r': {  # Recurrent state
                    'weight': 'kappa',
                }
            },
            'rho': {
                'r': {  # Recurrent state
                    'weight': 'rho',
                }
            },
        }

        # Weakly tuned summation: pooling in h, w dimensions.
        # Builds the near-eCRF ('P') kernel: either a stack of learnable
        # VGG-style filters, a fully learnable association field, or a fixed
        # hand-crafted surround-pooling array.
        #############################################
        with tf.variable_scope('contextual_circuit'):
            if isinstance(self.p_shape[0], list) and 'P' not in self.lesions:
                # VGG-style filters: p_shape is a list of per-layer kernel
                # shapes.  One xavier-initialized trainable variable is made
                # per entry; entries after the first register additional
                # 'weight_<idx>' keys in weight_dict so they can be found
                # later by name.
                for pidx, pext in enumerate(self.p_shape):
                    if pidx == 0:
                        it_key = self.weight_dict['P']['r']['weight']
                    else:
                        self.weight_dict['P']['r']['weight_%s' %
                                                   pidx] = 'p_r_%s' % pidx
                        it_key = self.weight_dict['P']['r']['weight_%s' % pidx]
                    setattr(
                        self, it_key,
                        tf.get_variable(
                            name=it_key,
                            dtype=self.dtype,
                            initializer=initialization.xavier_initializer(
                                shape=pext, uniform=self.normal_initializer),
                            trainable=True))
            else:
                # Hand-crafted surround kernel: ones everywhere except a
                # zeroed central window (so the classical RF is excluded),
                # then normalized to sum to 1.
                p_array = np.ones(self.p_shape)
                # NOTE(review): the slice bounds mix SSN and SSF
                # (start = SSN//2 - ifloor(SRF/2), stop = SSF//2 +
                # iceil(SSN/2)) — confirm this asymmetric window is
                # intentional rather than a copy-paste of SSF for SSN.
                p_array[self.SSN // 2 -
                        py_utils.ifloor(self.SRF / 2.0):self.SSF // 2 +
                        py_utils.iceil(self.SSN / 2.0), self.SSN // 2 -
                        py_utils.ifloor(self.SRF / 2.0):self.SSF // 2 +
                        py_utils.iceil(self.SSN / 2.0), :,  # exclude CRF!
                        :] = 0.0
                p_array = p_array / p_array.sum()
                if 'P' in self.lesions:
                    print 'Lesioning near eCRF.'
                    # Lesion: zero out the kernel entirely.
                    p_array = np.zeros_like(p_array).astype(np.float32)

                # Association field is fully learnable (xavier init,
                # trainable); otherwise the fixed p_array is frozen into a
                # non-trainable variable.
                if self.association_field and 'P' not in self.lesions:
                    setattr(
                        self,
                        self.weight_dict['P']['r']['weight'],
                        tf.get_variable(
                            name=self.weight_dict['P']['r']['weight'],
                            dtype=self.dtype,
                            # shape=self.p_shape,
                            initializer=initialization.xavier_initializer(
                                shape=self.p_shape,
                                uniform=self.normal_initializer),
                            trainable=True))
                else:
                    setattr(
                        self, self.weight_dict['P']['r']['weight'],
                        tf.get_variable(
                            name=self.weight_dict['P']['r']['weight'],
                            dtype=self.dtype,
                            initializer=p_array.astype(np.float32),
                            trainable=False))

            # Gate weights: input-gate ('I') recurrent kernel, xavier init.
            setattr(
                self, self.weight_dict['I']['r']['weight'],
                tf.get_variable(name=self.weight_dict['I']['r']['weight'],
                                dtype=self.dtype,
                                trainable=True,
                                initializer=initialization.xavier_initializer(
                                    shape=self.i_shape,
                                    uniform=self.normal_initializer,
                                    mask=None)))
            # Feedforward input-gate weight, currently disabled.
            # setattr(
            #     self,
            #     self.weight_dict['I']['f']['weight'],
            #     tf.get_variable(
            #         name=self.weight_dict['I']['f']['weight'],
            #         dtype=self.dtype,
            #         trainable=True,
            #         initializer=initialization.xavier_initializer(
            #             shape=self.i_shape,
            #             uniform=self.normal_initializer,
            #             mask=None)))
            # Chronos-style gate bias: -log(U[1, timesteps - 1]), which
            # biases gate time constants toward the task horizon; otherwise
            # a constant bias of ones.
            if self.gate_bias_init == 'chronos':
                bias_init = -tf.log(
                    tf.random_uniform(
                        self.bias_shape, minval=1, maxval=self.timesteps - 1))
            else:
                bias_init = tf.ones(self.bias_shape)
            setattr(
                self, self.weight_dict['I']['r']['bias'],
                tf.get_variable(name=self.weight_dict['I']['r']['bias'],
                                dtype=self.dtype,
                                trainable=True,
                                initializer=bias_init))

            # Output-gate ('O') recurrent kernel, xavier init.
            setattr(
                self, self.weight_dict['O']['r']['weight'],
                tf.get_variable(name=self.weight_dict['O']['r']['weight'],
                                dtype=self.dtype,
                                trainable=True,
                                initializer=initialization.xavier_initializer(
                                    shape=self.o_shape,
                                    uniform=self.normal_initializer,
                                    mask=None)))
            # Feedforward output-gate weight, currently disabled.
            # setattr(
            #     self,
            #     self.weight_dict['O']['f']['weight'],
            #     tf.get_variable(
            #         name=self.weight_dict['O']['f']['weight'],
            #         dtype=self.dtype,
            #         trainable=True,
            #         initializer=initialization.xavier_initializer(
            #             shape=self.o_shape,
            #             uniform=self.normal_initializer,
            #             mask=None)))
            if self.gate_bias_init == 'chronos':
                # bias_init = -tf.log(
                #     tf.random_uniform(
                #         self.bias_shape, minval=1, maxval=self.timesteps - 1))
                # NOTE(review): reuses (and negates) the *same* random draw
                # as the input-gate bias, deliberately coupling the two
                # gates' initial time constants — confirm this pairing is
                # intended rather than an independent sample.
                bias_init = -bias_init
            else:
                bias_init = tf.ones(self.bias_shape)
            setattr(  # TODO: smart initialization of these
                self, self.weight_dict['O']['r']['bias'],
                tf.get_variable(name=self.weight_dict['O']['r']['bias'],
                                dtype=self.dtype,
                                trainable=True,
                                initializer=bias_init))

            # Degree-of-freedom weights (per-channel [1, 1, 1, k] vectors).
            # Pattern used throughout this section: each attribute (beta, nu,
            # zeta, gamma, xi, kappa, omega) is first read as a config
            # *boolean* from aux vars, then OVERWRITTEN with the tf tensor it
            # enables — a learnable variable when the flag is on, a constant
            # 1. when off, or 0. when lesioned.
            w_shape = [1, 1, 1, self.k]
            b_shape = [1, 1, 1, self.k]
            # w_array = np.ones(w_shape).astype(np.float32)
            # b_array = np.zeros(b_shape).astype(np.float32)

            # Divisive params
            if self.beta and not self.lesion_beta:
                self.beta = tf.get_variable(
                    name='beta',
                    initializer=initialization.xavier_initializer(
                        shape=w_shape,
                        uniform=self.normal_initializer,
                        mask=None))
                # initializer=tf.ones(w_shape, dtype=tf.float32))
            elif self.lesion_beta:
                self.beta = tf.constant(0.)
            else:
                self.beta = tf.constant(1.)

            if self.nu and not self.lesion_nu:
                self.nu = tf.get_variable(
                    name='nu',
                    initializer=initialization.xavier_initializer(
                        shape=b_shape,
                        uniform=self.normal_initializer,
                        mask=None))
                # initializer=tf.zeros(b_shape, dtype=tf.float32))
            elif self.lesion_nu:
                self.nu = tf.constant(0.)
            else:
                self.nu = tf.constant(1.)
            if self.zeta:
                self.zeta = tf.get_variable(
                    name='zeta',
                    initializer=initialization.xavier_initializer(
                        shape=w_shape,
                        uniform=self.normal_initializer,
                        mask=None))
            else:
                self.zeta = tf.constant(1.)
            if self.gamma:
                self.gamma = tf.get_variable(
                    name='gamma',
                    initializer=initialization.xavier_initializer(
                        shape=w_shape,
                        uniform=self.normal_initializer,
                        mask=None))
            else:
                self.gamma = tf.constant(1.)
            # # TODO
            # self.ebias = tf.get_variable(
            #     name='ebias',
            #     initializer=initialization.xavier_initializer(
            #         shape=b_shape,
            #         uniform=self.normal_initializer,
            #         mask=None))

            if self.xi:
                self.xi = tf.get_variable(
                    name='xi',
                    initializer=initialization.xavier_initializer(
                        shape=w_shape,
                        uniform=self.normal_initializer,
                        mask=None))
            else:
                self.xi = tf.constant(1.)
            # Multiplicative excitation: kappa (additive term gain) and omega
            # (multiplicative term gain); both collapse to 1. when the
            # feature is off, 0. when lesioned.
            if self.multiplicative_excitation:
                if self.lesion_kappa:
                    self.kappa = tf.constant(0.)
                else:
                    self.kappa = tf.get_variable(
                        name='kappa',
                        initializer=initialization.xavier_initializer(
                            shape=w_shape,
                            uniform=self.normal_initializer,
                            mask=None))
                    # initializer=tf.zeros(w_shape, dtype=tf.float32) + 0.5)

                if self.lesion_omega:
                    self.omega = tf.constant(0.)
                else:
                    self.omega = tf.get_variable(
                        name='omega',
                        initializer=initialization.xavier_initializer(
                            shape=w_shape,
                            uniform=self.normal_initializer,
                            mask=None))
                    # initializer=tf.zeros(w_shape, dtype=tf.float32) + 0.5)
            else:
                self.kappa = tf.constant(1.)
                self.omega = tf.constant(1.)
            # Per-timestep adaptation gains.
            # NOTE(review): attribute is spelled 'adapation' (sic) — it is
            # set that way by the aux-var config elsewhere; do not "fix" the
            # spelling here alone.
            if self.adapation:
                self.rho = tf.get_variable(name='rho',
                                           initializer=tf.ones(
                                               self.timesteps,
                                               dtype=tf.float32))
            # Lesion overrides re-applied unconditionally so they also take
            # effect when multiplicative_excitation is False (where kappa and
            # omega were just set to the constant 1. above).
            if self.lesion_omega:
                self.omega = tf.constant(0.)
            if self.lesion_kappa:
                self.kappa = tf.constant(0.)
            self.lateral_bias = tf.get_variable(
                name='lateral_bias',
                initializer=initialization.xavier_initializer(
                    shape=b_shape, uniform=self.normal_initializer, mask=None))