class model_pure_rtdnn(object):
    """Projection-only reconstruction model: strided conv down, then back up.

    The input is projected by a ``DownSampledConvolutionalLayer`` and mapped back
    by an ``OverSampledConvolutionalLayer``; ``cost`` is the mean squared error
    between that reconstruction and the input.

    Args:
        filter_shape: pair ``(encoder_filter_shape, decoder_filter_shape)``.
        input_shape: pair ``(encoder_input_shape, decoder_input_shape)``; the
            leading dimension of the first entry is stored as ``batch_size``.
        rng: random state passed to both layers; defaults to
            ``np.random.RandomState(23455)`` (fixed seed).
        channel: accepted for signature compatibility; not used by this class.
        stride: down/over-sampling stride shared by both layers.
        n_hids: stored on the instance; not otherwise used by this class.
        dec_hid: accepted for signature compatibility; not used by this class.
    """

    def __init__(self, filter_shape, input_shape, rng=None, channel=1, stride=16,
                 n_hids=500, dec_hid='tanh'):
        if rng is None:
            rng = np.random.RandomState(23455)
        self.rng = rng
        self.params = []
        self.batch_size = input_shape[0][0]
        self.n_hids = n_hids
        self.stride = stride

        self.filter_shape = filter_shape
        self.input_shape = input_shape

        # Index 0 of each shape pair configures the encoder projection,
        # index 1 the decoder projection.
        self.enc_proj = DownSampledConvolutionalLayer(
            nrng=rng, filter_shape=filter_shape[0], input_shape=input_shape[0], stride=stride)
        self.dec_proj = OverSampledConvolutionalLayer(
            nrng=rng, filter_shape=filter_shape[1], input_shape=input_shape[1], stride=stride)

        # Aggregated parameter list: encoder projection first, then decoder.
        self.params += self.enc_proj.params
        self.params += self.dec_proj.params

    def cost(self, x):
        """Return the mean squared error between ``recon(x)`` and ``x``."""
        rec_x = self.recon(x)
        mse = ((rec_x - x) ** 2).mean()
        return mse

    def recon(self, x):
        """Project ``x`` down with ``proj`` and map it back with ``proj_T``."""
        x_proj = self.proj(x)
        out_x = self.proj_T(x_proj)
        return out_x

    def proj(self, x):
        """Apply the down-sampling projection layer."""
        return self.enc_proj.fprop(x)

    def proj_T(self, x):
        """Apply the over-sampling (back-)projection layer."""
        return self.dec_proj.fprop(x)


class model(object):
    """Reconstruction model: strided projection, recursive encode/decode, back-projection.

    ``recon`` projects the input with a ``DownSampledConvolutionalLayer``. It
    reshapes the result into a step-major 3-D tensor and runs it through a
    recursive ``'conv'`` encoder. A zero tensor seeded only at step 0 with the
    encoder output is decoded by a recursive ``'deconv'`` decoder, and the result
    is mapped back with an ``OverSampledConvolutionalLayer``. ``cost`` is the mean
    squared error between that reconstruction and the input.

    Args:
        filter_shape: pair ``(encoder_filter_shape, decoder_filter_shape)``.
        input_shape: pair ``(encoder_input_shape, decoder_input_shape)``. The
            first entry's dimension 0 is the batch size, and its dimension 3
            divided by ``stride`` is the number of steps.
        rng: random state passed to all layers; defaults to
            ``np.random.RandomState(23455)`` (fixed seed).
        channel: accepted for signature compatibility; not used by this class.
        stride: down/over-sampling stride of the projection layers.
        n_hids: hidden size of the recursive layers. It is also the channel
            dimension used when reshaping the projection output.
        dec_hid: accepted for signature compatibility; not used by this class
            (both recursive layers are built with ``activation='rect'``).
    """

    def __init__(self, filter_shape, input_shape, rng=None, channel=1, stride=16,
                 n_hids=500, dec_hid='tanh'):
        if rng is None:
            rng = np.random.RandomState(23455)
        self.rng = rng
        self.params = []
        self.batch_size = input_shape[0][0]
        self.n_hids = n_hids
        self.stride = stride

        self.filter_shape = filter_shape
        self.input_shape = input_shape

        # Index 0 of each shape pair configures the encoder projection,
        # index 1 the decoder projection.
        self.enc_proj = DownSampledConvolutionalLayer(
            nrng=rng, filter_shape=filter_shape[0], input_shape=input_shape[0], stride=stride)
        self.enc = RecursiveConvolutionalLayer(
            rng, name='enc', activation='rect', n_hids=n_hids, conv_mode='conv')
        self.dec = RecursiveConvolutionalLayer(
            rng, name='dec', activation='rect', n_hids=n_hids, conv_mode='deconv')
        self.dec_proj = OverSampledConvolutionalLayer(
            nrng=rng, filter_shape=filter_shape[1], input_shape=input_shape[1], stride=stride)

        # Aggregated parameter list (order kept: enc, dec, enc_proj, dec_proj).
        self.params += self.enc.params
        self.params += self.dec.params
        self.params += self.enc_proj.params
        self.params += self.dec_proj.params

    def cost(self, x):
        """Return the mean squared error between ``recon(x)`` and ``x``."""
        rec_x = self.recon(x)
        mse = ((rec_x - x) ** 2).mean()
        return mse

    def _n_steps(self):
        """Return the step count: input dimension 3 floor-divided by ``stride``.

        Floor division keeps this an ``int``. On Python 3, ``/`` would yield a
        float, which is not a valid reshape dimension.
        """
        return self.input_shape[0][3] // self.stride

    def recon(self, x):
        """Reconstruct ``x`` via projection, recursive encode/decode, and back-projection."""
        x_proj = self.proj(x)
        # Reshape the projection to (batch, n_hids, steps). Assumes enc_proj's
        # output holds batch_size * n_hids * steps elements -- TODO confirm.
        x_re = x_proj.reshape((self.batch_size, self.n_hids, self._n_steps()))
        # (batch, n_hids, steps) -> (steps, batch, n_hids): step axis leading.
        x_steps = x_re.dimshuffle(2, 0, 1)
        enc_x = self.encode(x_steps)
        # Decoder input is all zeros except step 0, which receives the encoder's
        # step-0 output. NOTE(review): presumably enc_x[0] summarizes the
        # sequence -- confirm against RecursiveConvolutionalLayer.
        enc_x_pad = T.zeros_like(x_steps)
        enc_x_pad = T.set_subtensor(enc_x_pad[0], enc_x[0])
        rec_x = self.decode(enc_x_pad)
        # (steps, batch, n_hids) -> (batch, n_hids, 1, steps), i.e. back to bc01
        # with a height-1 axis; assumes decode preserves the 3-D step-major layout.
        rec_x = rec_x.dimshuffle(1, 2, 'x', 0)
        return self.proj_T(rec_x)

    def proj(self, x):
        """Apply the down-sampling projection layer."""
        return self.enc_proj.fprop(x)

    def proj_T(self, x):
        """Apply the over-sampling (back-)projection layer."""
        return self.dec_proj.fprop(x)

    def encode(self, x):
        """Run the recursive encoder layer."""
        return self.enc.fprop(x)

    def decode(self, x):
        """Run the recursive decoder layer."""
        return self.dec.fprop(x)