class DRAW(object):

  def __init__(self, imgX, imgY, input=None, n_hidden_enc=100, n_hidden_dec=100, n_z=100, n_steps=8, batch_size=100, rng=rng):

    #initialize parameters and build the symbolic computation graph

    if input is None:
      input = theano.shared(numpy.zeros((batch_size, imgX*imgY))
                            .astype(theano.config.floatX))

    #learned initial canvas, shared across the batch
    self.c0 = theano.shared(name='c0',
                            value=numpy.random.uniform(-1.0, 1.0, (imgX*imgY))
                                  .astype(theano.config.floatX))

    #encoder LSTM sees the read vector (x and x_err, 2*imgX*imgY values)
    #concatenated with the previous decoder state
    self.rnn_enc = LSTM(n_hidden_dec + 2*imgX*imgY, n_hidden_enc)
    self.rnn_dec = LSTM(n_z, n_hidden_dec)
    #approximate posterior Q(z|h_enc)
    self.Z = RandomVariable(rng, n_in=n_hidden_enc, n_out=n_z)
    self.readHead = ReadHead(n_hidden_enc)
    self.writeHead = WriteHead(imgX, imgY, n_hidden_dec)
    #output distribution P(x|c) with a sigmoid-squashed mean
    self.X = RandomVariable(rng, n_in=imgX*imgY, n_out=imgX*imgY, sigmoid_mean=True)
    #pre-sampled noise, one draw per time step
    self.randSeq = rng.normal((n_steps, batch_size, n_z))

    self.params = [self.c0] + self.readHead.params + self.rnn_enc.params + self.Z.params + self.rnn_dec.params + self.X.params + self.writeHead.params

    #tiles a vector into an n_batches x vector_length matrix:
    #the concatenate operation won't broadcast, so we build a zero matrix
    #with the correct number of rows (one per example, derived from the
    #first column of `input`) and add the vector to it
    def vec2Matrix(v):
      t = v.dimshuffle(['x', 0])
      t = T.dot(input.dimshuffle([1, 0])[0].dimshuffle([0, 'x']), t)
      return v + T.zeros_like(t)
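
    #(an equivalent, arguably simpler formulation, noted here as an untested
    # sketch using Theano's broadcasting allocation:
    #   def vec2Matrix(v):
    #     return T.alloc(v, input.shape[0], v.shape[0])
    #)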

    #one step of the DRAW recurrence: read, encode, sample, decode, write
    def autoEncode(epsilon, ctm1, stm1_enc, htm1_enc, stm1_dec, htm1_dec, ztm1, x):
      x_err = x - T.nnet.sigmoid(ctm1)                    #error image
      rt = self.readHead.read(x, x_err, htm1_dec)
      [s_t_enc, h_t_enc] = self.rnn_enc.recurrence(
                T.concatenate([rt, htm1_dec], axis=1), stm1_enc, htm1_enc[-1])
      z_t = self.Z.conditional_sample(h_t_enc, epsilon)   #z_t ~ Q(z|h_enc)
      [s_t_dec, h_t_dec] = self.rnn_dec.recurrence(z_t, stm1_dec, htm1_dec)
      c_t = ctm1 + self.writeHead.write(h_t_dec)          #additive canvas update
      #pass the new decoder state forward and accumulate the per-step
      #encoder states and latents for the latent loss
      return [c_t, s_t_enc, htm1_enc + [h_t_enc], s_t_dec, h_t_dec, ztm1 + [z_t]]

    #initial states, tiled across the batch; encoder states and latents are
    #kept as per-step lists for the latent loss
    c_t, s_t_enc, h_t_enc, s_t_dec, h_t_dec, z_t = [
          vec2Matrix(self.c0), vec2Matrix(self.rnn_enc.s0),
          [vec2Matrix(self.rnn_enc.h0)], vec2Matrix(self.rnn_dec.s0),
          vec2Matrix(self.rnn_dec.h0), []]

    #we would like to use theano.scan here, but it runs into errors with
    #computations involving the random variables, takes much longer to build
    #the gradient graph, and the growing Python lists above would not fit
    #scan's fixed-arity outputs anyway, so the loop is unrolled instead
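    #(for reference, a scan-based version would take roughly this shape,
    # an untested sketch that assumes the list-valued outputs are first
    # turned into fixed-shape tensors:
    #   outputs, updates = theano.scan(
    #       fn=autoEncode,
    #       sequences=self.randSeq,
    #       outputs_info=[c_t, s_t_enc, h_t_enc, s_t_dec, h_t_dec, None],
    #       non_sequences=input)
    #)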

    for i in range(n_steps):
      c_t,s_t_enc,h_t_enc,s_t_dec,h_t_dec,z_t = autoEncode(self.randSeq[i],c_t,s_t_enc,h_t_enc,s_t_dec,h_t_dec,z_t,input)

    #generation: run the decoder alone, with z_t drawn from the N(0, I) prior
    def generate(epsilon, ctm1, stm1_dec, htm1_dec):
      [s_t_dec, h_t_dec] = self.rnn_dec.recurrence(epsilon, stm1_dec, htm1_dec)
      c_t = ctm1 + self.writeHead.write(h_t_dec)
      return [c_t, s_t_dec, h_t_dec]

    c_t2, s_t_dec2, h_t_dec2 = [vec2Matrix(self.c0), vec2Matrix(self.rnn_dec.s0),
          vec2Matrix(self.rnn_dec.h0)]

    for i in range(n_steps):
      c_t2,s_t_dec2,h_t_dec2 = generate(self.randSeq[i],c_t2,s_t_dec2,h_t_dec2)


    self.h_t_enc = T.stacklists(h_t_enc)
    self.cT = c_t
    #variational bound: negative reconstruction log-likelihood plus the
    #latent (KL) loss over all time steps, averaged over the batch
    self.lossX = T.sum(-self.X.log_conditional_prob(input, self.cT))
    self.lossZ = T.sum(self.Z.latent_loss(self.h_t_enc))
    self.loss = (self.lossX + self.lossZ)/batch_size
    self.test = self.loss  #debugging/monitoring output
    #samples from P(x|c): reconstructions from the autoencoding pass and
    #fresh images from the generative pass
    self.generated_x = self.X.conditional_sample(self.cT, rng.normal((batch_size, imgX*imgY)))
    self.generated_x2 = self.X.conditional_sample(c_t2, rng.normal((batch_size, imgX*imgY)))
    self.mean = T.dot(self.cT, self.X.w_mean)
    self.var = T.exp(T.dot(self.cT, self.X.w_var))
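
    #A minimal usage sketch (untested; `images` is a hypothetical
    #(batch_size, 28*28) floatX array of MNIST digits, and LSTM,
    #RandomVariable, ReadHead, WriteHead and the module-level rng are
    #assumed available as above):
    #  x = theano.shared(images)
    #  draw = DRAW(28, 28, input=x)
    #  grads = T.grad(draw.loss, draw.params)
    #  updates = [(p, p - 1e-3*g) for p, g in zip(draw.params, grads)]
    #  train = theano.function([], draw.loss, updates=updates)
    #  sample = theano.function([], draw.generated_x2)
    #  for epoch in range(100):
    #    print(train())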