import torch.nn as nn

import similarity  # local module of similarity functions; exact import path assumed

# BiDirEncoder is another local module of this repo; UniDirAttention is
# defined later in this section.


class SequenceEncoder(nn.Module):

    def __init__(self, dict_args):
        super(SequenceEncoder, self).__init__()
        self.configuration = dict_args['encoder_configuration']
        self.dropout_rate = dict_args['encoder_dropout_rate']
        self.input_dim = dict_args['encoder_input_dim']

        self.contextvector = None
        self.inputvector = None
        self.recurrent, self.attention = None, None
        self.dropout_layer = nn.Dropout(p=self.dropout_rate)
        #self.linear = nn.Linear(self.input_dim, self.input_dim)

        # Map the configuration string onto the recurrence and attention flags.
        if self.configuration == 'MP':
            self.recurrent, self.attention = False, False
        elif self.configuration == 'LSTM':
            self.recurrent, self.attention = True, False
        elif self.configuration == 'MPAttn':
            self.recurrent, self.attention = False, True
        elif self.configuration == 'LSTMAttn':
            self.recurrent, self.attention = True, True

        if self.recurrent:
            self.hidden_dim = dict_args['encoder_rnn_hdim']
            self.rnn_type = dict_args['encoder_rnn_type']
            self.num_layers = dict_args['encoder_num_layers']
            recurrent_layer_args = {
                'input_dim': self.input_dim,
                'rnn_hdim': self.hidden_dim,
                'rnn_type': self.rnn_type,
                'num_layers': self.num_layers,
                'dropout_rate': self.dropout_rate
            }
            self.recurrent_layer = BiDirEncoder(recurrent_layer_args)

        if self.attention:
            self.projection_dim = dict_args['encoderattn_projection_dim']
            # The attention context is the RNN hidden states when an LSTM
            # is used, and the raw inputs otherwise.
            if self.configuration == 'LSTMAttn':
                self.context_dim = self.hidden_dim
            elif self.configuration == 'MPAttn':
                self.context_dim = self.input_dim
            self.query_dim = dict_args['encoderattn_query_dim']
            similarity_function_args = {
                'sequence1_dim': self.context_dim,
                'sequence2_dim': self.query_dim,
                'projection_dim': self.projection_dim
            }
            self.similarity_function = similarity.LinearProjectionSimilarity(
                similarity_function_args)
            self.attention_function = UniDirAttention(
                {'similarity_function': self.similarity_function})
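# A minimal construction sketch for SequenceEncoder. The key names are the
# ones read above; the numeric values are illustrative assumptions, not
# settings from this repo's configuration.
#
#     encoder = SequenceEncoder({
#         'encoder_configuration': 'LSTMAttn',
#         'encoder_dropout_rate': 0.2,
#         'encoder_input_dim': 300,
#         'encoder_rnn_hdim': 256,
#         'encoder_rnn_type': 'LSTM',
#         'encoder_num_layers': 1,
#         'encoderattn_projection_dim': 128,
#         'encoderattn_query_dim': 256,
#     })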
class UniDirAttention(nn.Module):

    def __init__(self, dict_args):
        super(UniDirAttention, self).__init__()
        # dict_args should contain the arguments for the chosen similarity
        # function as well.
        self.similarity_function_name = dict_args['similarity_function']
        self.similarity_function = None
        if isinstance(self.similarity_function_name, nn.Module):
            # Callers such as SequenceEncoder and VideoFrameEncoder pass an
            # already-constructed similarity module rather than a name;
            # use it directly.
            self.similarity_function = self.similarity_function_name
        elif self.similarity_function_name == 'DotProduct':
            self.similarity_function = similarity.DotProductSimilarity(
                dict_args)
        elif self.similarity_function_name == 'WeightedInputsConcatenation':
            self.similarity_function = similarity.LinearConcatenationSimilarity(
                dict_args)
        elif self.similarity_function_name == 'WeightedInputsDotConcatenation':
            self.similarity_function = similarity.LinearConcatenationDotSimilarity(
                dict_args)
        elif self.similarity_function_name == 'WeightedSumProjection':
            self.similarity_function = similarity.LinearProjectionSimilarity(
                dict_args)
        elif self.similarity_function_name == 'ProjectionSimilaritySharedWeights':
            self.similarity_function = dict_args['similarity_function_pointer']
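# A minimal construction sketch for UniDirAttention when selecting a
# similarity by name. The dimension keys mirror the ones SequenceEncoder
# passes to LinearProjectionSimilarity; the values are assumptions.
#
#     attention = UniDirAttention({
#         'similarity_function': 'WeightedSumProjection',
#         'sequence1_dim': 256,
#         'sequence2_dim': 256,
#         'projection_dim': 128,
#     })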
class BiDirAttention(nn.Module):
    # Attends one sequence (instead of a single word) over another sequence.

    def __init__(self, dict_args):
        super(BiDirAttention, self).__init__()
        self.similarity_function_name = dict_args['similarity_function']
        self.one_shot_attention = dict_args['one_shot_attention']
        self.self_match_attention = dict_args['self_match_attention']
        self.similarity_function = None
        # Same name-to-module dispatch as UniDirAttention above.
        if self.similarity_function_name == 'DotProduct':
            self.similarity_function = similarity.DotProductSimilarity(
                dict_args)
        elif self.similarity_function_name == 'WeightedInputsConcatenation':
            self.similarity_function = similarity.LinearConcatenationSimilarity(
                dict_args)
        elif self.similarity_function_name == 'WeightedInputsDotConcatenation':
            self.similarity_function = similarity.LinearConcatenationDotSimilarity(
                dict_args)
        elif self.similarity_function_name == 'WeightedSumProjection':
            self.similarity_function = similarity.LinearProjectionSimilarity(
                dict_args)
        elif self.similarity_function_name == 'ProjectionSimilaritySharedWeights':
            self.similarity_function = dict_args['similarity_function_pointer']
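# A minimal construction sketch for BiDirAttention, assuming
# DotProductSimilarity needs no extra arguments beyond the shared dict:
#
#     bidir_attention = BiDirAttention({
#         'similarity_function': 'DotProduct',
#         'one_shot_attention': False,
#         'self_match_attention': False,
#     })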
class VideoFrameEncoder(nn.Module):

    def __init__(self, dict_args):
        super(VideoFrameEncoder, self).__init__()
        self.configuration = dict_args['encoder_configuration']
        self.dropout_rate = dict_args['encoder_dropout_rate']
        self.channel_dim = dict_args['frame_channel_dim']
        self.spatial_dim = dict_args['frame_spatial_dim']
        self.projection_dim = dict_args['encoderattn_projection_dim']

        # Optionally reduce the frame channel dimension with a linear layer.
        self.use_linear = dict_args['encoder_linear']
        if self.use_linear:
            self.channelred_dim = dict_args['frame_channelred_dim']
            self.linear = nn.Linear(self.channel_dim, self.channelred_dim)
            self.channel_dim = self.channelred_dim

        self.recurrent = None
        self.spatialattention = None
        self.temporalattention = None
        self.dropout_layer = nn.Dropout(p=self.dropout_rate)

        # Map the configuration string onto the three feature flags.
        if self.configuration == 'LSTMTrackSpatial':
            self.recurrent, self.spatialattention, self.temporalattention = True, True, False
        elif self.configuration == 'LSTMTrackSpatialTemporal':
            self.recurrent, self.spatialattention, self.temporalattention = True, True, True
        elif self.configuration == 'SpatialTemporal':
            self.recurrent, self.spatialattention, self.temporalattention = False, True, True

        if self.recurrent:
            self.hidden_dim = dict_args['encoder_rnn_hdim']
            self.rnn_type = dict_args['encoder_rnn_type']
            if self.rnn_type == 'LSTM':
                self.trackrnn = nn.LSTMCell(self.channel_dim, self.hidden_dim)
            elif self.rnn_type == 'GRU':
                self.trackrnn = nn.GRUCell(self.channel_dim, self.hidden_dim)
            elif self.rnn_type == 'RNN':
                pass  # vanilla RNN cell not implemented

        if self.spatialattention:
            self.query_dim = dict_args['encoderattn_query_dim']
            spatial_similarity_function_args = {
                'sequence1_dim': self.channel_dim,
                'sequence2_dim': self.query_dim,
                'projection_dim': self.projection_dim
            }
            self.spatial_similarity_function = similarity.LinearProjectionSimilarity(
                spatial_similarity_function_args)
            self.spatial_attention_function = UniDirAttention(
                {'similarity_function': self.spatial_similarity_function})

        if self.temporalattention:
            self.query_dim = dict_args['encoderattn_query_dim']
            # Temporal attention runs over the tracking RNN states when
            # recurrence is enabled, otherwise over the raw frame features.
            self.context_dim = self.channel_dim
            if self.recurrent:
                self.context_dim = self.hidden_dim
            temporal_similarity_function_args = {
                'sequence1_dim': self.context_dim,
                'sequence2_dim': self.query_dim,
                'projection_dim': self.projection_dim
            }
            self.temporal_similarity_function = similarity.LinearProjectionSimilarity(
                temporal_similarity_function_args)
            self.temporal_attention_function = UniDirAttention(
                {'similarity_function': self.temporal_similarity_function})
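# A minimal construction sketch for VideoFrameEncoder with the full
# track-then-attend configuration. Values are illustrative assumptions
# (e.g. 512 channels over a 7x7 grid from a CNN backbone).
#
#     frame_encoder = VideoFrameEncoder({
#         'encoder_configuration': 'LSTMTrackSpatialTemporal',
#         'encoder_dropout_rate': 0.2,
#         'frame_channel_dim': 512,
#         'frame_spatial_dim': 49,
#         'encoderattn_projection_dim': 128,
#         'encoder_linear': True,
#         'frame_channelred_dim': 256,
#         'encoder_rnn_hdim': 256,
#         'encoder_rnn_type': 'LSTM',
#         'encoderattn_query_dim': 256,
#     })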