def _rnn_test(self, X, NX, NX_rep, NX_cum, h):
    """Run a single RNN step at sampling time (one partition per molecule).

    Builds a per-molecule input by concatenating the segment mean of the
    node features with the features of the most recently added node, feeds
    it through the recurrent cell for one time step, and returns the step
    output together with the updated hidden state.
    """
    # Segment mean: per-molecule sum of node features divided by node count.
    seg_sum = fn.SegmentSumFn(NX_rep, NX.shape[0])(X)
    mean_feat = seg_sum / nd.cast(fn.unsqueeze(NX, 1), 'float32')
    # Features of the last node in each molecule (NX_cum points past the end).
    last_feat = nd.take(X, indices=NX_cum - 1)
    step_in = nd.concat(mean_feat, last_feat, dim=1)  # size: [NX, F_in * 2]
    # Single time step: add a length-1 time axis, run the RNN, strip it again.
    step_out, h = self.rnn(fn.unsqueeze(step_in, axis=1), h)
    return fn.squeeze(step_out, axis=1), h
def forward(self, X, NX, NX_rep, X_end=None):
    """Compute action probabilities (append, connect, end) for each molecule.

    Node scores and the molecule-level "end" score are exponentiated logits
    normalised jointly over each segment (a softmax across all nodes of a
    molecule plus its end score).  With ``self.k > 1`` the k score
    components are mixed with weights ``pi``; otherwise the single
    component is used directly.
    """
    # Molecule summary: segment mean of node features, unless supplied.
    if X_end is None:
        totals = fn.SegmentSumFn(NX_rep, NX.shape[0])(X)
        X_end = totals / nd.cast(fn.unsqueeze(NX, 1), 'float32')
    # Append the molecule summary to every node's feature vector.
    node_in = nd.concat(X, X_end[NX_rep, :], dim=1)

    hidden = nd.relu(self.linear_h(node_in)).reshape([-1, self.F_h])
    hidden_end = nd.relu(self.linear_h_t(X_end)).reshape([-1, self.F_h])

    # Unnormalised scores with k mixture components each.
    scores = nd.exp(self.linear_x(hidden)).reshape(
        [-1, self.k, self.N_B + self.N_B * self.N_A])
    scores_end = nd.exp(self.linear_x_t(hidden_end)).reshape([-1, self.k, 1])

    # Per-molecule normaliser: all node scores plus the end score.
    norm = nd.sum(fn.SegmentSumFn(NX_rep, NX.shape[0])(scores),
                  -1, keepdims=True) + scores_end
    probs = scores / norm[NX_rep, :, :]
    probs_end = scores_end / norm

    if self.k > 1:
        # Mix the k components with softmax weights derived from X_end.
        pi = fn.unsqueeze(nd.softmax(self.linear_pi(X_end), axis=1), -1)
        probs = nd.sum(probs * pi[NX_rep, :, :], axis=1)
        probs_end = nd.sum(probs_end * pi, axis=1)
    else:
        probs = fn.squeeze(probs, 1)
        probs_end = fn.squeeze(probs_end, 1)

    # Split node probabilities: first N_B columns are "connect",
    # the remainder is "append" reshaped to one row per atom type.
    connect = probs[:, :self.N_B]
    append = probs[:, self.N_B:].reshape([-1, self.N_A, self.N_B])
    end = fn.squeeze(probs_end, -1)
    return append, connect, end
def _likelihood(self, init, append, connect, end,
                action_0, actions, iw_ids, log_p_sigma,
                batch_size, iw_size):
    """Importance-weighted estimate of the sequence log-likelihood.

    Accumulates the log-probability of the initial atom choice plus every
    append/connect/end action along a trajectory, subtracts the proposal
    log-density ``log_p_sigma``, and reduces the ``iw_size`` importance
    samples per batch element with a log-sum-exp (IWAE-style bound).

    Parameters
    ----------
    init : distribution over the first atom type, reshaped here to
        ``[batch_size * iw_size, N_A]``.
    append, connect, end : per-step output distributions (as produced by
        ``forward``).
    action_0 : index of the first atom type for each trajectory.
    actions : packed action matrix whose columns are
        (action_type, node_type, edge_type, append_pos, connect_pos);
        action_type 0 = append, 1 = connect, 2 = end.
    iw_ids : segment ids mapping each step to its trajectory.
    log_p_sigma : proposal log-density per importance sample.
    batch_size, iw_size : batch and importance-sample sizes.

    Returns
    -------
    One importance-weighted log-likelihood value per batch element.
    """
    # Decompose the packed action matrix into its five columns.
    action_type, node_type, edge_type, append_pos, connect_pos = \
        actions[:, 0], actions[:, 1], actions[:, 2], actions[:, 3], actions[:, 4]

    def _log_mask(_x, _mask):
        # log(x) on masked-in entries, exact zero elsewhere; 1e-10 guards
        # against log(0).  (The original's extra term
        # ``(1 - _mask) * nd.zeros_like(_x)`` was identically zero and has
        # been dropped; a def replaces the named lambda per PEP 8 E731.)
        return _mask * nd.log(_x + 1e-10)

    # Initial atom: gather the probability of the chosen type per trajectory.
    init = init.reshape([batch_size * iw_size, self.N_A])
    index = nd.stack(
        nd.arange(action_0.shape[0], ctx=action_0.context, dtype='int32'),
        action_0, axis=0)
    loss_init = nd.log(nd.gather_nd(init, index) + 1e-10)

    # Termination steps (action_type == 2).
    loss_end = _log_mask(end, nd.cast(action_type == 2, 'float32'))

    # Append steps (action_type == 0): indexed by position, atom and bond type.
    index = nd.stack(append_pos, node_type, edge_type, axis=0)
    loss_append = _log_mask(nd.gather_nd(append, index),
                            nd.cast(action_type == 0, 'float32'))

    # Connect steps (action_type == 1): indexed by position and bond type.
    index = nd.stack(connect_pos, edge_type, axis=0)
    loss_connect = _log_mask(nd.gather_nd(connect, index),
                             nd.cast(action_type == 1, 'float32'))

    # Sum the per-step terms within each trajectory, then add the init term.
    log_p_x = loss_end + loss_append + loss_connect
    log_p_x = fn.squeeze(
        fn.SegmentSumFn(iw_ids, batch_size * iw_size)(
            fn.unsqueeze(log_p_x, -1)), -1)
    log_p_x = log_p_x + loss_init

    # IWAE bound: log of the mean importance weight across iw_size samples.
    log_p_x = log_p_x.reshape([batch_size, iw_size])
    log_p_sigma = log_p_sigma.reshape([batch_size, iw_size])
    log_w = log_p_x - log_p_sigma
    return fn.logsumexp(log_w, axis=1) - math.log(float(iw_size))
def _rnn_train(self, X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum):
    """Run the RNN over full trajectories at training time.

    Builds the same per-step input as ``_rnn_test`` (segment mean of node
    features concatenated with the latest node's features), rearranges the
    steps into ``[batch, iw, length]`` order via ``graph_to_rnn``, runs the
    whole sequence through the RNN, and scatters the outputs back to graph
    order with ``rnn_to_graph``.
    """
    # Segment mean: per-molecule sum of node features divided by node count.
    seg_sum = fn.SegmentSumFn(NX_rep, NX.shape[0])(X)
    mean_feat = seg_sum / nd.cast(fn.unsqueeze(NX, 1), 'float32')
    # Features of the last node in each molecule (NX_cum points past the end).
    last_feat = nd.take(X, indices=NX_cum - 1)
    step_in = nd.concat(mean_feat, last_feat, dim=1)

    # Gather into [batch_size, iw_size, length, num_features] RNN layout.
    rnn_in = nd.take(step_in, indices=graph_to_rnn)
    batch_size, iw_size, length, num_features = rnn_in.shape
    rnn_out = self.rnn(
        rnn_in.reshape([batch_size * iw_size, length, num_features]))

    # Restore the per-sample axis, then scatter back to graph order.
    rnn_out = rnn_out.reshape([batch_size, iw_size, length, -1])
    return nd.gather_nd(rnn_out, indices=rnn_to_graph)