示例#1
0
 def set_hidden_state(self, state):
     """Overwrite the layer's hidden-state variables with new values.

     All values are converted and validated before any variable is modified,
     so a bad entry leaves the existing hidden state untouched.

     Args:
         state: sequence of array-likes, one per entry of ``self.state`` and
             in the same order. Each is converted with ``floatX`` before use.

     Raises:
         ValueError: if ``state`` has a different number of entries than
             ``self.state``, or if any converted entry's shape differs from
             the shape of the variable it would replace. (ValueError is an
             Exception subclass, so existing ``except Exception`` callers
             still catch it.)
     """
     if len(state) != len(self.state):
         raise ValueError("Provided hidden state array does not match layer hidden states")
     # Convert everything up front; this also lets plain lists be passed,
     # since the shape check below runs on the converted values.
     values = [floatX(h) for h in state]
     # Validate every entry before assigning any, so a mismatch on a later
     # entry cannot leave the earlier ones already overwritten.
     for index, (variable, value) in enumerate(zip(self.state, values)):
         # get_value(borrow=True) reads the stored array without compiling a
         # function the way .eval() does.
         current_shape = variable.get_value(borrow=True).shape
         if current_shape != value.shape:
             raise ValueError("Hidden state shape not compatible (state %d: layer has %s, got %s)"
                              % (index, current_shape, value.shape))
     for variable, value in zip(self.state, values):
         variable.set_value(value)
示例#2
0
 def set_hidden_state(self, state):
     """Overwrite the layer's hidden-state variables with new values.

     Every value is converted and checked before any variable is touched, so
     a mismatch leaves the current hidden state unchanged.

     Args:
         state: sequence of array-likes, one per entry of ``self.state``, in
             the same order; each is converted with ``floatX``.

     Raises:
         ValueError: if the number of entries differs from ``self.state`` or
             a converted entry's shape differs from its target variable's
             shape (a subclass of Exception, so broader handlers still work).
     """
     if len(state) != len(self.state):
         raise ValueError(
             "Provided hidden state array does not match layer hidden states"
         )
     # Converting first means list inputs work and shapes are compared on the
     # exact arrays that will be stored.
     values = [floatX(h) for h in state]
     # Check all entries before assigning any, so a failure is all-or-nothing.
     for index, (variable, value) in enumerate(zip(self.state, values)):
         current_shape = variable.get_value(borrow=True).shape
         if current_shape != value.shape:
             raise ValueError(
                 "Hidden state shape not compatible (state %d: layer has %s, got %s)"
                 % (index, current_shape, value.shape)
             )
     for variable, value in zip(self.state, values):
         variable.set_value(value)
 def get_hsn_features(self, split, data_key):
     """Return the 'final_hidden_features' vector stored for split[data_key].

     Reads ``self.source_dataset[split][data_key]['final_hidden_features']``
     and converts it with ``floatX``.

     Args:
         split: key of the split inside ``self.source_dataset``.
         data_key: key of the item within that split.

     Raises:
         KeyError: with ``data_key`` as its argument, if the split, the item,
             or its 'final_hidden_features' entry is missing. A warning is
             logged first so the requester can simply skip the item.
     """
     # Keep the try body to the lookup alone, so only a missing key is
     # treated as "no source vector".
     try:
         features = self.source_dataset[split][data_key]['final_hidden_features']
     except KeyError:
         # This image -- description pair doesn't have a source-language
         # vector. Raise a KeyError naming the item so the requester can
         # deal with the missing data.
         logger.warning("Skipping '%s' because it doesn't have a source vector", data_key)
         raise KeyError(data_key)
     return floatX(features)
示例#4
0
 def set_weights(self, weights):
     """Load parameter values, and optionally hidden-state values, into the layer.

     ``weights`` is either one array per entry of ``self.params``, or, for a
     layer that has hidden state, one array per parameter followed by one
     array per entry of ``self.state``. The trailing state arrays are handed
     to ``set_hidden_state``.

     Parameter count and shapes are validated before anything is assigned.
     A mismatch therefore leaves both the parameters and the hidden state
     untouched.

     Raises:
         AssertionError: if the number of parameter arrays does not match
             ``self.params`` (same exception type as the former ``assert``).
         ValueError: if a parameter array's shape differs from the layer's.
     """
     n_params = len(self.params)
     n_state = len(self.state)
     hidden_state = None
     # Require n_state > 0: with n_state == 0, weights[-0:] would be the whole
     # list, and a stateless layer given exactly its params would push them
     # into set_hidden_state.
     if n_state and len(weights) == n_params + n_state:
         hidden_state = weights[n_params:]
         weights = weights[:n_params]
     # Raised explicitly rather than via `assert` so the check survives
     # `python -O`; otherwise zip() below would silently truncate.
     if n_params != len(weights):
         raise AssertionError('Provided weight array does not match layer weights '
                              '(%d layer params vs. %d provided weights)' % (n_params, len(weights)))
     values = [floatX(w) for w in weights]
     for param, value in zip(self.params, values):
         # get_value(borrow=True) reads the stored array without compiling a
         # function the way .eval() does.
         current_shape = param.get_value(borrow=True).shape
         if current_shape != value.shape:
             raise ValueError("Layer shape %s not compatible with weight shape %s." % (current_shape, value.shape))
     if hidden_state is not None:
         self.set_hidden_state(hidden_state)
     for param, value in zip(self.params, values):
         param.set_value(value)
示例#5
0
 def set_weights(self, weights):
     """Load parameter values, and optionally hidden-state values, into the layer.

     ``weights`` holds one array per entry of ``self.params``. For a layer
     with hidden state, it may instead hold those arrays followed by one array
     per entry of ``self.state``; the trailing ones go to
     ``set_hidden_state``.

     Every parameter is validated before any assignment, so a bad input
     leaves the layer unchanged.

     Raises:
         AssertionError: if the number of parameter arrays does not match
             ``self.params`` (same type the original ``assert`` raised).
         ValueError: if a parameter array's shape differs from the layer's.
     """
     n_params = len(self.params)
     n_state = len(self.state)
     hidden_state = None
     # Guard on n_state: weights[-0:] is the *entire* list, which would send
     # a stateless layer's parameter weights into set_hidden_state.
     if n_state and len(weights) == n_params + n_state:
         hidden_state = weights[n_params:]
         weights = weights[:n_params]
     # Explicit raise so the check is not stripped under `python -O`.
     if n_params != len(weights):
         raise AssertionError(
             "Provided weight array does not match layer weights "
             "(%d layer params vs. %d provided weights)" % (n_params, len(weights))
         )
     values = [floatX(w) for w in weights]
     for param, value in zip(self.params, values):
         current_shape = param.get_value(borrow=True).shape
         if current_shape != value.shape:
             raise ValueError(
                 "Layer shape %s not compatible with weight shape %s."
                 % (current_shape, value.shape)
             )
     if hidden_state is not None:
         self.set_hidden_state(hidden_state)
     for param, value in zip(self.params, values):
         param.set_value(value)
示例#6
0
 def set_weights(self, weights):
     """Assign each array in ``weights`` to the matching entry of ``self.params``.

     Arrays are paired with parameters by position and converted with
     ``floatX`` before assignment. Pairing stops at the shorter sequence, so
     surplus entries on either side are ignored without error.
     """
     for param, value in zip(self.params, weights):
         converted = floatX(value)
         param.set_value(converted)
示例#7
0
 def set_weights(self, weights):
     """Load running statistics plus the remaining layer weights.

     The last two entries of ``weights`` are the running mean and running
     std, in that order. They are converted with ``floatX`` and stored first.
     Everything before them is passed to the parent class's ``set_weights``.
     """
     mean_value, std_value = weights[-2], weights[-1]
     self.running_mean.set_value(floatX(mean_value))
     self.running_std.set_value(floatX(std_value))
     layer_weights = weights[:-2]
     super(BatchNormalization, self).set_weights(layer_weights)
示例#8
0
 def set_state(self, value_list):
     """Assign saved values back to the variables held in ``self.updates``.

     Each entry of ``self.updates`` is indexed with ``[0]`` and that object
     receives the matching value (converted with ``floatX``) via
     ``set_value``. The entries are presumably (variable, update) pairs;
     confirm against where ``self.updates`` is built.

     Args:
         value_list: one value per entry of ``self.updates``, same order.

     Raises:
         AssertionError: if ``value_list`` and ``self.updates`` differ in
             length.
     """
     # Explicit raise instead of `assert`: the check must still run under
     # `python -O`, otherwise zip() would silently drop the extra entries.
     # AssertionError is kept so existing callers see the same type.
     if len(self.updates) != len(value_list):
         raise AssertionError("Provided state has %d values but there are %d updates"
                              % (len(value_list), len(self.updates)))
     for u, v in zip(self.updates, value_list):
         u[0].set_value(floatX(v))
示例#9
0
 def set_state(self, value_list):
     """Restore saved values into the variables referenced by ``self.updates``.

     For each position, ``floatX(value)`` is passed to ``set_value`` on the
     first element (``[0]``) of the matching ``self.updates`` entry.
     NOTE(review): entries look like (variable, update) pairs; verify at the
     site that builds ``self.updates``.

     Args:
         value_list: one value per entry of ``self.updates``, same order.

     Raises:
         AssertionError: if the lengths of ``value_list`` and
             ``self.updates`` differ.
     """
     # Not an `assert`: the check must survive `python -O`, and AssertionError
     # is raised explicitly to keep the exception type callers already see.
     if len(self.updates) != len(value_list):
         raise AssertionError("Provided state has %d values but there are %d updates"
                              % (len(value_list), len(self.updates)))
     for u, v in zip(self.updates, value_list):
         u[0].set_value(floatX(v))