    def __init__(self,
                 hidden_size,
                 output_size=1,
                 n_layers=1,
                 use_cuda=True,
                 attn_type='dot',
                 multi_head=1,
                 context_type=None):
        super(Encoder_Decoder, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.use_cuda = use_cuda
        self.output_size = output_size
        self.bidirectional = True
        self.multi_head = multi_head

        if self.bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

        # encoder: project the 1024-dim appearance features to hidden_size,
        # concatenate 96 extra feature dims, then encode with a bank of 80
        # bidirectional GRUs
        self.appear_linear = nn.Linear(1024, hidden_size)
        self.feature_linear = nn.Linear(hidden_size + 96, hidden_size)
        self.encoder_rnn = nn.ModuleList([
            nn.GRU(hidden_size, hidden_size, self.n_layers, bidirectional=self.bidirectional)
            for _ in range(80)
        ])
        
        # decoder: matching bank of 80 bidirectional GRUs; the output layer maps
        # the 2*hidden_size (forward + backward) state to a single sigmoid score
        self.decoder_rnn = nn.ModuleList([
            nn.GRU(hidden_size, hidden_size, self.n_layers, bidirectional=self.bidirectional)
            for _ in range(80)
        ])
        self.out = nn.Linear(2 * hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

        # attention: one GlobalAttention module per head; a linear reduction maps
        # the concatenated head outputs back to 2*hidden_size
        self.atten = nn.ModuleList([
            GlobalAttention(hidden_size * 2, attn_type=attn_type)
            for _ in range(self.multi_head)
        ])
        self.reduction = nn.Linear(hidden_size * 2 * self.multi_head, hidden_size * 2)
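        # Intended multi-head flow (an assumption read off the layer shapes; the
        # forward pass is not shown in this section): each head attends over the
        # encoder outputs and returns a 2*hidden_size context, the per-head
        # contexts are concatenated to 2*hidden_size*multi_head, and
        # self.reduction projects them back down to 2*hidden_size.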
        # context gate (only constructed when a context_type is given)
        self.context_type = context_type
        if self.context_type is not None:
            self.context_gate = context_gate_factory(self.context_type,
                                                     hidden_size,
                                                     hidden_size * 2,
                                                     hidden_size * 2,
                                                     hidden_size * 2)
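        # Usage sketch (illustrative, not from the original code; hidden_size=256
        # and multi_head=4 are arbitrary example values):
        #
        #   model = Encoder_Decoder(hidden_size=256, attn_type='dot', multi_head=4)
        #   if model.use_cuda and torch.cuda.is_available():
        #       model = model.cuda()
        #
        # appear_linear expects 1024-dim appearance vectors per step, and
        # feature_linear expects those projections concatenated with 96 extra dims.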