class PrednetModel(nn.Module):
    """
    Build the Prednet model.

    For each 0-based layer the model registers:
      * one ``DiscriminativeCell`` as ``discriminator_<layer+1>``;
      * ``DIRECTION`` ``GenerativeCell``s as ``generator_<layer+1>_<d>``.
    A single ``Attention`` module (``self.action``) maps the per-direction
    generator outputs of a layer (plus ``action_in``) to that layer's
    ``up_state``.

    The submodule attribute names define the ``state_dict`` keys, so they must
    not be renamed.
    """

    def __init__(self, error_size_list, num_of_layers):
        """
        Args:
            error_size_list: indexed by layer; entry ``layer`` is passed to that
                layer's GenerativeCells as ``error_init_size``. Needs at least
                ``num_of_layers`` entries.
            num_of_layers: number of stacked layers.
        """
        super(PrednetModel, self).__init__()
        self.number_of_layers = num_of_layers
        top = num_of_layers - 1

        for layer in range(num_of_layers):
            setattr(self, self._discriminator_name(layer), DiscriminativeCell(
                input_size={'input': IN_LAYER_SIZE[layer],
                            'state': OUT_LAYER_SIZE[layer]},
                hidden_size=OUT_LAYER_SIZE[layer],
                first=(not layer),
            ))

            # The top layer has no layer above it, so its top-down input size is 0.
            up_state_size = OUT_LAYER_SIZE[layer + 1] if layer != top else 0
            for d in range(DIRECTION):
                setattr(self, self._generator_name(layer, d), GenerativeCell(
                    input_size={'error': ERR_LAYER_SIZE[layer],
                                'up_state': up_state_size},
                    hidden_size=OUT_LAYER_SIZE[layer],
                    error_init_size=error_size_list[layer],
                ))

        # Build the attention module once. It used to be re-created inside
        # every forward() call, which re-initialised its weights each step and
        # kept its parameters out of any optimizer created from
        # model.parameters() after construction.
        self.action = Attention(N_HIDDEN)
        # Same device placement as before: on the GPU whenever one is available.
        if torch.cuda.is_available():
            self.action = self.action.cuda()

    @staticmethod
    def _discriminator_name(layer):
        """Return the attribute name of the discriminator for 0-based ``layer``."""
        return 'discriminator_' + str(layer + 1)

    @staticmethod
    def _generator_name(layer, direction):
        """Return the attribute name of the generator for 0-based ``layer`` and ``direction``."""
        return 'generator_' + str(layer + 1) + '_' + str(direction)

    def forward(self, bottom_up_input, error, state, action_in):
        """
        Run a top-down generative pass followed by a bottom-up discriminative pass.

        Args:
            bottom_up_input: input fed to the first (layer 0) discriminator.
            error: list indexed by layer. ``error[layer]`` is read by that
                layer's generators, then overwritten IN PLACE with the new
                discriminator output.
            state: ``state[d][layer]`` is the generator state for direction
                ``d``. It is overwritten IN PLACE with the generator output.
            action_in: extra input passed to the attention module for every layer.

        Returns:
            ``(error, state, up_state)``:
              * ``error`` and ``state`` are the same list objects that were
                passed in, now updated;
              * ``up_state`` is a new list with the attention output per layer.
        """
        top = self.number_of_layers - 1
        up_state = [None] * self.number_of_layers

        # Generative branch: walk from the top layer down, so that each layer
        # can use the up_state just produced by the layer above it.
        for layer in reversed(range(self.number_of_layers)):
            top_down = up_state[layer + 1] if layer < top else None
            for d in range(DIRECTION):
                generator = getattr(self, self._generator_name(layer, d))
                state[d][layer] = generator(error[layer], top_down, state[d][layer])
            # Feed element [0] of every direction's state for this layer to the
            # attention module (presumably the hidden state of an (h, c) pair;
            # TODO confirm against GenerativeCell).
            up_state[layer] = self.action(
                [direction_state[layer][0] for direction_state in state],
                action_in)

        # Discriminative branch: walk bottom-up. Layer 0 consumes the raw
        # input; every other layer consumes the error just computed below it.
        for layer in range(self.number_of_layers):
            below = bottom_up_input if layer == 0 else error[layer - 1]
            discriminator = getattr(self, self._discriminator_name(layer))
            error[layer] = discriminator(below, up_state[layer])

        return error, state, up_state