def call(self, argument, mask=None):
    """Iteratively refine test and support embeddings with attention.

    For ``self.max_depth`` rounds, attention read-outs are computed for both
    the support set and the test set, and each read-out is fed through its
    own LSTM step (``self.support_lstm`` / ``self.test_lstm``) to update an
    additive correction term. The corrections are added to the raw inputs
    to form the outputs.

    Parameters
    ----------
    argument: list
      List of two tensors (X, Xp). X should be of shape (n_test, n_feat) and
      Xp should be of shape (n_support, n_feat) where n_test is the size of
      the test set, n_support that of the support set, and n_feat is the
      number of per-atom features.
    mask:
      Accepted for interface compatibility; not used by this method.

    Returns
    -------
    list
      Two tensors of the same shape as the inputs, i.e.
      [(n_test, n_feat), (n_support, n_feat)]: the test features plus the
      final test correction, and the support features plus the final
      support correction.
    """
    # NOTE(review): build() presumably creates the *_init tensors and the two
    # LSTM steps read below -- confirm against its definition.
    self.build()
    test_feats, support_feats = argument

    # Starting values for the additive corrections and the LSTM states.
    test_delta = self.p_init
    support_delta = self.q_init
    support_lstm_states = self.support_states_init
    test_lstm_states = self.test_states_init

    # The refined view of the support set starts out as the raw support set.
    refined_support = support_feats

    for _ in range(self.max_depth):
        # Support side: score the corrected, refined support set against the
        # raw support set, then take the softmax-weighted combination of the
        # raw support set.
        support_weights = tf.nn.softmax(
            cos(refined_support + support_delta, support_feats))
        support_readout = model_ops.dot(support_weights, support_feats)

        # Test side: same procedure, but attending over the refined support
        # set from the previous round (it is only replaced at the end of
        # this iteration).
        test_weights = tf.nn.softmax(
            cos(test_feats + test_delta, refined_support))
        test_readout = model_ops.dot(test_weights, refined_support)

        # Advance the support LSTM on [correction, read-out].
        support_lstm_in = model_ops.concatenate(
            [support_delta, support_readout], axis=1)
        support_delta, support_lstm_states = self.support_lstm(
            [support_lstm_in] + support_lstm_states)

        # Advance the test LSTM on [correction, read-out].
        test_lstm_in = model_ops.concatenate(
            [test_delta, test_readout], axis=1)
        test_delta, test_lstm_states = self.test_lstm(
            [test_lstm_in] + test_lstm_states)

        # The support read-out becomes the refined support set for the next
        # round.
        refined_support = support_readout

    return [test_feats + test_delta, support_feats + support_delta]
def call(self, x_xp, mask=None):
    """Refine test embeddings by attending over the support set.

    On every invocation this method creates a fresh ``LSTMStep`` and zero
    initial tensors, stores them on the instance, and then runs
    ``self.max_depth`` rounds of attention followed by an LSTM step. Only the
    test features receive an additive correction; the support features are
    returned untouched.

    Parameters
    ----------
    x_xp: list
      List of two tensors (X, Xp). X should be of shape (n_test, n_feat) and
      Xp should be of shape (n_support, n_feat) where n_test is the size of
      the test set, n_support that of the support set, and n_feat is the
      number of per-atom features.
    mask:
      Accepted for interface compatibility; not used by this method.

    Returns
    -------
    list
      Two tensors of the same shape as the inputs, i.e.
      [(n_test, n_feat), (n_support, n_feat)]: the corrected test features
      and the unchanged support features.
    """
    # First entry is the test set, second entry is the support set.
    test_feats, support_feats = x_xp

    # --- Set up the LSTM step and the initial tensors on the instance ---
    feat_dim = self.n_feat
    self.lstm = LSTMStep(feat_dim, 2 * feat_dim)
    self.q_init = model_ops.zeros([self.n_test, feat_dim])
    self.r_init = model_ops.zeros([self.n_test, feat_dim])
    self.states_init = self.lstm.get_initial_states([self.n_test, feat_dim])
    self.trainable_weights = [self.q_init, self.r_init]

    # --- Iterative attention refinement ---
    correction = self.q_init
    lstm_states = self.states_init

    for _ in range(self.max_depth):
        # Eqn (4), appendix A.1 of the Matching Networks paper: score the
        # corrected test features against the support set, normalise with a
        # softmax, and read out a weighted combination of the support set.
        attn_weights = tf.nn.softmax(
            cos(test_feats + correction, support_feats))
        readout = model_ops.dot(attn_weights, support_feats)

        # Feed [correction, read-out] through the LSTM step to obtain the
        # next correction and LSTM states.
        lstm_in = model_ops.concatenate([correction, readout], axis=1)
        correction, lstm_states = self.lstm([lstm_in] + lstm_states)

    return [test_feats + correction, support_feats]