class DRAW(object):
    """Symbolic graph for a DRAW-style recurrent variational autoencoder.

    Builds (in Theano) an encoder/decoder LSTM pair with read and write
    attention heads, unrolled for `n_steps` time steps, plus a purely
    generative unroll of the decoder for sampling.
    """

    def __init__(self, imgX, imgY, input=None, n_hidden_enc=100,
                 n_hidden_dec=100, n_z=100, n_steps=8, batch_size=100,
                 rng=rng):
        """Construct the unrolled DRAW computation graph.

        Parameters
        ----------
        imgX, imgY : int
            Image dimensions; images are handled flattened to imgX*imgY.
        input : theano shared/symbolic matrix or None
            Batch of flattened images, shape (batch_size, imgX*imgY).
            Defaults to a zero-filled shared matrix.
            NOTE(review): the name shadows the `input` builtin; kept for
            backward compatibility with existing callers.
        n_hidden_enc, n_hidden_dec : int
            Hidden sizes of the encoder / decoder LSTMs.
        n_z : int
            Dimensionality of the latent variable z.
        n_steps : int
            Number of glimpse/draw iterations to unroll.
        batch_size : int
            Number of images per minibatch.
        rng
            Random stream object (default bound at class-definition time
            to the module-level `rng`).
        """
        if input is None:
            input = theano.shared(numpy.zeros((batch_size, imgX * imgY)))

        # Initial canvas (shared so it is learned along with the rest).
        self.c0 = theano.shared(
            name='c0',
            value=numpy.random.uniform(-1.0, 1.0, (imgX * imgY))
                  .astype(theano.config.floatX))

        # Encoder sees the read vector (2*imgX*imgY: image + error image)
        # concatenated with the previous decoder hidden state.
        self.rnn_enc = LSTM(n_hidden_dec + 2 * imgX * imgY, n_hidden_enc)
        self.rnn_dec = LSTM(n_z, n_hidden_dec)
        self.Z = RandomVariable(rng, n_in=n_hidden_enc, n_out=n_z)
        self.readHead = ReadHead(n_hidden_enc)
        self.writeHead = WriteHead(imgX, imgY, n_hidden_dec)
        self.X = RandomVariable(rng, n_in=imgX * imgY, n_out=imgX * imgY,
                                sigmoid_mean=True)

        # One pre-drawn noise tensor per unrolled step.
        self.randSeq = rng.normal((n_steps, batch_size, n_z))

        self.params = ([self.c0] + self.readHead.params
                       + self.rnn_enc.params + self.Z.params
                       + self.rnn_dec.params + self.X.params
                       + self.writeHead.params)

        def vec2Matrix(v):
            # Turn a length-d vector into a (batch_size, d) matrix.
            # concatenate won't broadcast, so add a zero matrix with the
            # correct number of rows (derived symbolically from `input`).
            t = v.dimshuffle(['x', 0])
            t = T.dot(input.dimshuffle([1, 0])[0].dimshuffle([0, 'x']), t)
            return v + T.zeros_like(t)

        def autoEncode(epsilon, ctm1, stm1_enc, htm1_enc, stm1_dec,
                       htm1_dec, ztm1, x):
            # One read -> encode -> sample -> decode -> write step.
            x_err = x - T.nnet.sigmoid(ctm1)
            rt = self.readHead.read(x, x_err, htm1_dec)
            [s_t_enc, h_t_enc] = self.rnn_enc.recurrence(
                T.concatenate([rt, htm1_dec], axis=1),
                stm1_enc, htm1_enc[-1])
            z_t = self.Z.conditional_sample(h_t_enc, epsilon)
            [s_t_dec, h_t_dec] = self.rnn_dec.recurrence(z_t, stm1_dec,
                                                         htm1_dec)
            c_t = ctm1 + self.writeHead.write(h_t_dec)
            # BUG FIX: the original returned htm1_dec here, so the decoder
            # hidden state was frozen at h0 across all unrolled steps.
            # Return h_t_dec so the decoder recurrence actually advances
            # (matching the sibling `generate` step below).
            return [c_t, s_t_enc, htm1_enc + [h_t_enc],
                    s_t_dec, h_t_dec, ztm1 + [z_t]]

        # Initial states for the auto-encoding unroll; encoder hidden
        # states are accumulated in a list (h0 included) for the latent
        # loss below.
        c_t, s_t_enc, h_t_enc, s_t_dec, h_t_dec, z_t = [
            vec2Matrix(self.c0),
            vec2Matrix(self.rnn_enc.s0),
            [vec2Matrix(self.rnn_enc.h0)],
            vec2Matrix(self.rnn_dec.s0),
            vec2Matrix(self.rnn_dec.h0),
            []]

        # Would like to use scan here, but it runs into errors with
        # computations involving random variables and takes much longer
        # to find the gradient graph — so unroll the loop in Python.
        for i in range(n_steps):
            c_t, s_t_enc, h_t_enc, s_t_dec, h_t_dec, z_t = autoEncode(
                self.randSeq[i], c_t, s_t_enc, h_t_enc,
                s_t_dec, h_t_dec, z_t, input)

        def generate(epsilon, ctm1, stm1_dec, htm1_dec):
            # Pure generation step: drive the decoder from noise alone.
            [s_t_dec, h_t_dec] = self.rnn_dec.recurrence(epsilon, stm1_dec,
                                                         htm1_dec)
            c_t = ctm1 + self.writeHead.write(h_t_dec)
            return [c_t, s_t_dec, h_t_dec]

        c_t2, s_t_dec2, h_t_dec2 = [vec2Matrix(self.c0),
                                    vec2Matrix(self.rnn_dec.s0),
                                    vec2Matrix(self.rnn_dec.h0)]
        for i in range(n_steps):
            c_t2, s_t_dec2, h_t_dec2 = generate(self.randSeq[i], c_t2,
                                                s_t_dec2, h_t_dec2)

        self.h_t_enc = T.stacklists(h_t_enc)
        self.cT = c_t

        # Reconstruction loss plus latent (KL-style) loss, averaged over
        # the batch. NOTE(review): h_t_enc includes the initial state h0;
        # confirm latent_loss is meant to see it.
        self.lossX = T.sum(-self.X.log_conditional_prob(input, self.cT))
        self.lossZ = T.sum(self.Z.latent_loss(self.h_t_enc))
        self.loss = (self.lossX + self.lossZ) / batch_size
        self.test = self.loss

        # Samples from the reconstruction and from the pure-generation
        # canvas, plus the output distribution's mean/variance.
        self.generated_x = self.X.conditional_sample(
            self.cT, rng.normal((batch_size, imgX * imgY)))
        self.generated_x2 = self.X.conditional_sample(
            c_t2, rng.normal((batch_size, imgX * imgY)))
        self.mean = T.dot(self.cT, self.X.w_mean)
        self.var = T.exp(T.dot(self.cT, self.X.w_var))