def set_hidden_state(self, state):
    """Overwrite this layer's hidden-state variables with the arrays in ``state``.

    ``state`` must contain exactly one array per entry of ``self.state``, each
    with the same shape as the variable it replaces. Values are converted with
    ``floatX`` before being written via ``set_value``.

    Every count and shape is checked before any variable is written, so an
    invalid entry leaves the existing hidden state untouched.

    Raises:
        Exception: if the number of arrays or any array's shape does not match
            the layer's hidden-state variables.
    """
    if len(state) != len(self.state):
        raise Exception("Provided hidden state array does not match layer hidden states")
    pairs = list(zip(self.state, state))
    # Validate all entries up front so a late mismatch cannot leave the
    # hidden state partially overwritten.
    for s, h in pairs:
        if s.eval().shape != h.shape:
            raise Exception("Hidden state shape not compatible")
    for s, h in pairs:
        s.set_value(floatX(h))
def set_hidden_state(self, state):
    """Overwrite this layer's hidden-state variables with the arrays in ``state``.

    ``state`` must contain exactly one array per entry of ``self.state``, each
    with the same shape as the variable it replaces. Values are converted with
    ``floatX`` before being written via ``set_value``.

    Every count and shape is checked before any variable is written, so an
    invalid entry leaves the existing hidden state untouched.

    Raises:
        Exception: if the number of arrays or any array's shape does not match
            the layer's hidden-state variables.
    """
    if len(state) != len(self.state):
        raise Exception(
            "Provided hidden state array does not match layer hidden states"
        )
    pairs = list(zip(self.state, state))
    # Validate all entries up front so a late mismatch cannot leave the
    # hidden state partially overwritten.
    for s, h in pairs:
        if s.eval().shape != h.shape:
            raise Exception("Hidden state shape not compatible")
    for s, h in pairs:
        s.set_value(floatX(h))
def get_hsn_features(self, split, data_key):
    """Return the source-language final hidden features for ``split[data_key]``.

    Reads ``self.source_dataset[split][data_key]['final_hidden_features']``
    and converts it with ``floatX``.

    Raises:
        KeyError: if the split, the entry, or its ``'final_hidden_features'``
            field is missing. The raised error carries ``data_key`` so the
            requester knows which item to skip.
    """
    try:
        features = self.source_dataset[split][data_key]['final_hidden_features']
    except KeyError:
        # This image--description pair doesn't have a source-language vector.
        # Re-raise a KeyError so the requester can deal with the missing data.
        logger.warning("Skipping '%s' because it doesn't have a source vector",
                       data_key)
        raise KeyError(data_key)
    return floatX(features)
def set_weights(self, weights):
    """Load parameter values, and optionally hidden state, into this layer.

    ``weights`` holds one array per entry of ``self.params``. If the layer has
    hidden state (``self.state`` is non-empty), ``weights`` may instead hold
    those parameter arrays followed by one array per hidden state. The
    trailing arrays are then passed to ``self.set_hidden_state``.

    Parameter count and shapes are validated before anything is written, so a
    bad parameter array does not leave the layer half-updated.

    Raises:
        AssertionError: if the number of parameter arrays does not match
            ``len(self.params)``.
        Exception: if a weight's shape differs from its parameter's shape
            (``set_hidden_state`` may raise its own errors as well).
    """
    n_params = len(self.params)
    n_state = len(self.state)
    params = list(self.params)
    state = None

    # Split off trailing hidden-state arrays only when the layer has some.
    # With n_state == 0, ``weights[-0:]`` would be the whole list and would be
    # misrouted to set_hidden_state.
    if n_state and len(weights) == n_params + n_state:
        state = weights[n_params:]
        weights = weights[:n_params]

    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # AssertionError is kept for callers that already catch it.
    if len(params) != len(weights):
        raise AssertionError(
            'Provided weight array does not match layer weights '
            '(%d layer params vs. %d provided weights)' % (len(params), len(weights)))

    for p, w in zip(params, weights):
        layer_shape = p.eval().shape
        if layer_shape != w.shape:
            raise Exception("Layer shape %s not compatible with weight shape %s."
                            % (layer_shape, w.shape))

    if state is not None:
        self.set_hidden_state(state)
    for p, w in zip(params, weights):
        p.set_value(floatX(w))
def set_weights(self, weights):
    """Load parameter values, and optionally hidden state, into this layer.

    ``weights`` holds one array per entry of ``self.params``. If the layer has
    hidden state (``self.state`` is non-empty), ``weights`` may instead hold
    those parameter arrays followed by one array per hidden state. The
    trailing arrays are then passed to ``self.set_hidden_state``.

    Parameter count and shapes are validated before anything is written, so a
    bad parameter array does not leave the layer half-updated.

    Raises:
        AssertionError: if the number of parameter arrays does not match
            ``len(self.params)``.
        Exception: if a weight's shape differs from its parameter's shape
            (``set_hidden_state`` may raise its own errors as well).
    """
    n_params = len(self.params)
    n_state = len(self.state)
    params = list(self.params)
    state = None

    # Split off trailing hidden-state arrays only when the layer has some.
    # With n_state == 0, ``weights[-0:]`` would be the whole list and would be
    # misrouted to set_hidden_state.
    if n_state and len(weights) == n_params + n_state:
        state = weights[n_params:]
        weights = weights[:n_params]

    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # AssertionError is kept for callers that already catch it.
    if len(params) != len(weights):
        raise AssertionError(
            'Provided weight array does not match layer weights '
            '(%d layer params vs. %d provided weights)' % (len(params), len(weights)))

    for p, w in zip(params, weights):
        layer_shape = p.eval().shape
        if layer_shape != w.shape:
            raise Exception(
                "Layer shape %s not compatible with weight shape %s."
                % (layer_shape, w.shape))

    if state is not None:
        self.set_hidden_state(state)
    for p, w in zip(params, weights):
        p.set_value(floatX(w))
def set_weights(self, weights):
    """Copy ``weights`` into ``self.params`` position by position.

    Each value is passed through ``floatX`` before ``set_value`` is called on
    the matching parameter.

    NOTE(review): pairing is done with ``zip``, so when the two sequences
    differ in length the surplus entries are silently ignored. Consider a
    length check if that is not intended.
    """
    paired = zip(self.params, weights)
    for param, value in paired:
        param.set_value(floatX(value))
def set_weights(self, weights):
    """Restore batch-norm running statistics and the inherited parameters.

    The last two entries of ``weights`` are the running mean and the running
    std, in that order. Each is converted with ``floatX`` and stored in
    ``self.running_mean`` / ``self.running_std``. Every entry before them is
    forwarded to the parent class's ``set_weights``.
    """
    mean_value, std_value = weights[-2], weights[-1]
    self.running_mean.set_value(floatX(mean_value))
    self.running_std.set_value(floatX(std_value))
    layer_weights = weights[:-2]
    super(BatchNormalization, self).set_weights(layer_weights)
def set_state(self, value_list):
    """Assign saved values to the variables tracked in ``self.updates``.

    ``value_list`` must have one entry per element of ``self.updates``. The
    first item of each update entry (presumably the shared variable being
    updated; confirm against the caller) receives the ``floatX``-converted
    value at the same position.
    """
    assert len(self.updates) == len(value_list)
    for update, saved_value in zip(self.updates, value_list):
        target = update[0]
        target.set_value(floatX(saved_value))