def __init__(self, opt):
    """Build the SCST captioning model from the options namespace *opt*.

    Copies the relevant hyper-parameters onto the instance, optionally
    builds an LSTM-based attention layer, selects the recurrent core
    according to ``opt.rnn_type``, and creates the word/image embedding
    layers before initializing the weights.

    Args:
        opt: configuration object carrying vocab/feature sizes, RNN
            settings and dropout probability (see attribute copies below).

    Raises:
        Exception: if ``opt.rnn_type`` names an unsupported core.
    """
    super(SCSTModel, self).__init__()
    self.vocab_size = opt.vocab_size
    self.input_encoding_size = opt.input_encoding_size
    self.rnn_type = opt.rnn_type
    self.rnn_size = opt.rnn_size
    self.num_layers = opt.num_layers
    self.drop_prob_lm = opt.drop_prob_lm
    self.seq_length = opt.seq_length
    self.fc_feat_size = opt.fc_feat_size
    self.att_feat_size = opt.att_feat_size
    self.att_size = opt.att_size
    # Each image contributes seq_per_img sampled sequences per batch.
    self.batch_size = opt.batch_size * opt.seq_per_img
    self.rnn_atten = opt.rnn_atten

    # Optional LSTM-based attention layer over the recurrent state.
    if self.rnn_atten == "ATT_LSTM":
        self.atten = LSTM.LSTM_ATTEN_LAYER(self.rnn_size)

    # Select the recurrent core. Input is the word embedding concatenated
    # with the image embedding (hence input_encoding_size * 2); output
    # covers the vocabulary plus one extra index (EOS/padding token).
    if self.rnn_type == "LSTM":
        self.core = LSTM.LSTM(self.input_encoding_size * 2, self.vocab_size + 1,
                              self.rnn_size, dropout=self.drop_prob_lm)
    elif self.rnn_type == "LSTM_SOFT_ATT":
        self.core = LSTM.LSTM_SOFT_ATT(self.input_encoding_size * 2, self.vocab_size + 1,
                                       self.rnn_size, self.att_size, dropout=self.drop_prob_lm)
    elif self.rnn_type == "LSTM_DOUBLE_ATT":
        self.core = LSTM.LSTM_DOUBLE_ATT(self.input_encoding_size * 2, self.vocab_size + 1,
                                         self.rnn_size, self.att_size, dropout=self.drop_prob_lm)
    else:
        # Fail fast instead of leaving self.core unset (which would only
        # surface later as an AttributeError); mirrors get_lstm's error.
        raise Exception("rnn type not supported: {}".format(self.rnn_type))

    # self.vocab_size + 1 -> self.input_encoding_size
    self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
    self.embed_tc = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)

    # (batch_size * fc_feat_size) -> (batch_size * input_encoding_size)
    self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
    self.att_embed = nn.Linear(self.att_feat_size, self.input_encoding_size)

    self.init_weights()
def get_lstm(opt):
    """Return the recurrent core module selected by ``opt.rnn_type``.

    The mapping below associates each supported ``rnn_type`` string with a
    zero-argument builder; the builder is only invoked for the selected
    type, so option attributes used by other cores need not exist on *opt*.

    Raises:
        Exception: if ``opt.rnn_type`` is not a supported core type.
    """
    print('rnn_type ', opt.rnn_type)

    builders = {
        "LSTM": lambda: LSTM.LSTM(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.drop_prob_lm),
        "LSTM_SOFT_ATT": lambda: LSTM.LSTM_SOFT_ATT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.att_size, opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT": lambda: LSTM.LSTM_DOUBLE_ATT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.att_size, opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK": lambda: LSTM.LSTM_SOFT_ATT_STACK(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK": lambda: LSTM.LSTM_DOUBLE_ATT_STACK(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL": lambda: LSTM.LSTM_DOUBLE_ATT_STACK_PARALLEL(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_POLICY": lambda: LSTM.LSTM_DOUBLE_ATT_STACK_PARALLEL_POLICY(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_BN": lambda: LSTM.LSTM_DOUBLE_ATT_STACK_PARALLEL_BN(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_BN_RELU": lambda: LSTM.LSTM_DOUBLE_ATT_STACK_PARALLEL_BN_RELU(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_DROPOUT": lambda: LSTM.LSTM_DOUBLE_ATT_STACK_PARALLEL_DROPOUT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_DROPOUT_SET": lambda: LSTM.LSTM_DOUBLE_ATT_STACK_PARALLEL_DROPOUT_SET(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.rnn_size_list, opt.att_size,
            dropout=opt.drop_prob_lm),
        "GRU_DOUBLE_ATT_STACK_PARALLEL_DROPOUT": lambda: GRU.GRU_DOUBLE_ATT_STACK_PARALLEL_DROPOUT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_IT_ATT": lambda: LSTM1.LSTM_IT_ATT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.att_size, opt.drop_prob_lm, opt.num_layers,
            opt.word_input_layer, opt.att_input_layer),
        "LSTM_IT_ATT_COMBINE": lambda: LSTM1.LSTM_IT_ATT_COMBINE(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.att_size, opt.drop_prob_lm, opt.num_layers,
            opt.word_input_layer, opt.att_input_layer),
        "FO_IT_ATT_COMBINE": lambda: LSTM1.FO_IT_ATT_COMBINE(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.att_size, opt.drop_prob_lm, opt.num_layers,
            opt.word_input_layer, opt.att_input_layer),
        "CONV_IT_ATT_COMBINE": lambda: LSTM1.CONV_IT_ATT_COMBINE(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.att_size, opt.drop_prob_lm, opt.num_layers,
            opt.word_input_layer, opt.att_input_layer),
        "CONV_LSTM": lambda: LSTM1.CONV_LSTM(
            opt.input_encoding_size, opt.vocab_size + 1, opt.rnn_size,
            opt.drop_prob_lm, opt.num_layers, opt.block_num, opt.use_proj_mul),
        # NOTE: the "_NEW" type deliberately builds LSTM1's (not LSTM's)
        # LSTM_DOUBLE_ATT_STACK_PARALLEL, as in the original dispatch.
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_NEW": lambda: LSTM1.LSTM_DOUBLE_ATT_STACK_PARALLEL(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT": lambda: LSTM1.LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_NEW": lambda: LSTM1.LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_NEW(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT": lambda: LSTM2.LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_WITH_BU": lambda: LSTM2.LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_WITH_BU(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.bu_size,
            dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_NEW": lambda: LSTM2.LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_NEW(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_LSTM_MUL": lambda: LSTM2.LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_LSTM_MUL(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.drop_prob_lm,
            opt.block_num),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_A": lambda: LSTM2.LSTM_DOUBLE_ATT_STACK_PARALLEL_A(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL": lambda: LSTM2.LSTM_SOFT_ATT_STACK_PARALLEL(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_WITH_WEIGHT": lambda: LSTM2.LSTM_SOFT_ATT_STACK_PARALLEL_WITH_WEIGHT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_WITH_MUL_WEIGHT": lambda: LSTM2.LSTM_SOFT_ATT_STACK_PARALLEL_WITH_MUL_WEIGHT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_WITH_WEIGHT": lambda: LSTM2.LSTM_DOUBLE_ATT_STACK_PARALLEL_MUL_OUT_ATT_WITH_WEIGHT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_WITH_WEIGHT_SPP": lambda: LSTM3.LSTM_SOFT_ATT_STACK_PARALLEL_WITH_WEIGHT_SPP(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.pool_size,
            opt.spp_num, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_SPP": lambda: LSTM3.LSTM_SOFT_ATT_STACK_PARALLEL_SPP(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.pool_size,
            opt.spp_num, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_MEMORY": lambda: LSTM4.LSTM_SOFT_ATT_STACK_PARALLEL_MEMORY(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.memory_num_hop,
            dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_NO_MEMORY": lambda: LSTM4.LSTM_SOFT_ATT_STACK_PARALLEL_NO_MEMORY(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_WITH_WEIGHT_BU": lambda: LSTM5.LSTM_SOFT_ATT_STACK_PARALLEL_WITH_WEIGHT_BU(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.bu_size,
            dropout=opt.drop_prob_lm),
        "LSTM_C_S_ATT_STACK_PARALLEL_WITH_WEIGHT_BU": lambda: LSTM5.LSTM_C_S_ATT_STACK_PARALLEL_WITH_WEIGHT_BU(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.bu_size,
            dropout=opt.drop_prob_lm),
        "LSTM_WITH_TOP_DOWN_ATTEN": lambda: LSTM6.LSTM_WITH_TOP_DOWN_ATTEN(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, opt.bu_size,
            opt.bu_num, dropout=opt.drop_prob_lm),
        "LSTM_SOFT_ATT_STACK_PARALLEL_WITH_FC_WEIGHT": lambda: LSTM2.LSTM_SOFT_ATT_STACK_PARALLEL_WITH_FC_WEIGHT(
            opt.input_encoding_size, opt.vocab_size + 1, opt.num_layers,
            opt.num_parallels, opt.rnn_size, opt.att_size, dropout=opt.drop_prob_lm),
    }

    if opt.rnn_type not in builders:
        raise Exception("rnn type not supported: {}".format(opt.rnn_type))
    return builders[opt.rnn_type]()