def __init__(self, opt):
    super(MixerModel, self).__init__()
    self.vocab_size = opt.vocab_size
    self.input_encoding_size = opt.input_encoding_size
    self.rnn_type = opt.rnn_type
    self.rnn_size = opt.rnn_size
    self.num_layers = opt.num_layers
    self.drop_prob_lm = opt.drop_prob_lm
    self.seq_length = opt.seq_length
    self.fc_feat_size = opt.fc_feat_size
    self.att_feat_size = opt.att_feat_size
    self.att_size = opt.att_size
    self.batch_size = 80  # NOTE: hard-coded batch size

    # LSTM core with double attention over the image features
    self.core = LSTM.LSTM_DOUBLE_ATT_TOP(self.input_encoding_size, self.vocab_size + 1,
                                         self.rnn_size, self.att_size,
                                         dropout=self.drop_prob_lm)

    # word embedding: self.vocab_size + 1 -> self.input_encoding_size
    self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)

    # (batch_size * fc_feat_size) -> (batch_size * input_encoding_size)
    self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
    self.att_embed = nn.Linear(self.att_feat_size, self.input_encoding_size)
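# A minimal, self-contained sketch of the two feature projections above, run on
# dummy tensors. The concrete sizes (2048-d CNN features, 49 spatial locations,
# 512-d encoding, batch of 80) are illustrative assumptions, not values taken
# from this repo:
#
#     import torch
#     import torch.nn as nn
#
#     fc_feat_size, att_feat_size, att_size, encoding_size = 2048, 2048, 49, 512
#     img_embed = nn.Linear(fc_feat_size, encoding_size)
#     att_embed = nn.Linear(att_feat_size, encoding_size)
#
#     fc_feats = torch.randn(80, fc_feat_size)              # one global feature per image
#     att_feats = torch.randn(80, att_size, att_feat_size)  # one feature per spatial location
#
#     print(img_embed(fc_feats).shape)   # torch.Size([80, 512])
#     print(att_embed(att_feats).shape)  # torch.Size([80, 49, 512]); Linear maps the last dim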
def __init__(self, opt):
    super(BiShowAttenTellModel, self).__init__()
    self.vocab_size = opt.vocab_size
    self.input_encoding_size = opt.input_encoding_size
    self.rnn_type = opt.rnn_type
    self.rnn_size = opt.rnn_size
    self.num_layers = opt.num_layers
    self.drop_prob_lm = opt.drop_prob_lm
    self.seq_length = opt.seq_length
    self.fc_feat_size = opt.fc_feat_size
    self.att_feat_size = opt.att_feat_size
    self.att_size = opt.att_size
    self.output_size = self.vocab_size + 1

    # LSTM
    # self.core = nn.LSTM(self.input_encoding_size, self.rnn_size, self.num_layers,
    #                     bias=False, dropout=self.drop_prob_lm)
    # Two attention LSTM cores (core / core1), presumably one per decoding direction.
    if self.rnn_type == "LSTM_SOFT_ATT":
        self.core = LSTM.LSTM_SOFT_ATT_TOP(self.input_encoding_size, self.output_size,
                                           self.rnn_size, self.att_size,
                                           dropout=self.drop_prob_lm)
        self.core1 = LSTM.LSTM_SOFT_ATT_TOP(self.input_encoding_size, self.output_size,
                                            self.rnn_size, self.att_size,
                                            dropout=self.drop_prob_lm)
    elif self.rnn_type == "LSTM_DOUBLE_ATT":
        self.core = LSTM.LSTM_DOUBLE_ATT_TOP(self.input_encoding_size, self.output_size,
                                             self.rnn_size, self.att_size,
                                             dropout=self.drop_prob_lm)
        self.core1 = LSTM.LSTM_DOUBLE_ATT_TOP(self.input_encoding_size, self.output_size,
                                              self.rnn_size, self.att_size,
                                              dropout=self.drop_prob_lm)
    else:
        raise Exception("rnn type not supported: {}".format(self.rnn_type))

    # word embedding: self.vocab_size + 1 -> self.input_encoding_size
    self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
    # note: image features are projected to rnn_size here, not input_encoding_size
    self.img_embed = nn.Linear(self.fc_feat_size, self.rnn_size)
    self.att_embed = nn.Linear(self.att_feat_size, self.rnn_size)
    # final projection from hidden state to vocabulary logits
    self.proj = nn.Linear(self.rnn_size, self.output_size)
    # self.relu = nn.RReLU(inplace=True)
    self.relu = nn.PReLU()
    self.init_weight()
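# For reference, a generic "Show, Attend and Tell"-style soft-attention step of the
# kind a cell like LSTM_SOFT_ATT_TOP typically wraps: score each spatial feature
# against the hidden state, softmax the scores, and take the weighted sum as the
# context vector. This is an illustrative sketch only, not this repo's
# implementation; all names and sizes below are assumptions.
#
#     import torch
#     import torch.nn as nn
#     import torch.nn.functional as F
#
#     class SoftAttention(nn.Module):
#         def __init__(self, rnn_size, att_hid_size=512):
#             super(SoftAttention, self).__init__()
#             self.h2att = nn.Linear(rnn_size, att_hid_size)  # project hidden state
#             self.a2att = nn.Linear(rnn_size, att_hid_size)  # project attention features
#             self.alpha = nn.Linear(att_hid_size, 1)         # scalar score per location
#
#         def forward(self, h, att_feats):
#             # h: (batch, rnn_size); att_feats: (batch, att_size, rnn_size)
#             scores = self.alpha(torch.tanh(
#                 self.h2att(h).unsqueeze(1) + self.a2att(att_feats))).squeeze(-1)
#             weights = F.softmax(scores, dim=1)                    # (batch, att_size)
#             context = (weights.unsqueeze(-1) * att_feats).sum(1)  # weighted feature sum
#             return context, weights
#
#     attn = SoftAttention(rnn_size=512)
#     ctx, w = attn(torch.randn(4, 512), torch.randn(4, 49, 512))
#     print(ctx.shape, w.shape)  # torch.Size([4, 512]) torch.Size([4, 49])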