Ejemplo n.º 1
0
    def _rnn_test(self, X, NX, NX_rep, NX_cum, h):
        """Run a single ``self.rnn`` step per segment during decoding.

        Each segment (one per molecule) is summarized by concatenating the
        segment mean of ``X`` with the row of ``X`` at ``NX_cum - 1``. The
        summary is fed to ``self.rnn`` as a length-1 sequence.

        Args:
            X: row features, grouped into segments by ``NX_rep``.
            NX: divisor for the segment mean; ``NX.shape[0]`` is the number
                of segments (presumably the row count of each segment).
            NX_rep: segment id of every row of ``X``.
            NX_cum: offsets; ``NX_cum - 1`` picks one row of ``X`` per
                segment (presumably the latest one if NX_cum is a cumulative
                count -- confirm).
            h: recurrent state handed to ``self.rnn``.

        Returns:
            ``(output, h)``: the RNN output with its length-1 time axis
            removed, and the state returned by ``self.rnn``.
        """
        num_segments = NX.shape[0]
        counts = nd.cast(fn.unsqueeze(NX, 1), 'float32')
        seg_mean = fn.SegmentSumFn(NX_rep, num_segments)(X) / counts
        last_rows = nd.take(X, indices=NX_cum - 1)

        # One summary row per segment, viewed as a sequence of length one.
        step_in = fn.unsqueeze(nd.concat(seg_mean, last_rows, dim=1), axis=1)
        step_out, h = self.rnn(step_in, h)
        return fn.squeeze(step_out, axis=1), h
Ejemplo n.º 2
0
    def forward(self, *inputs):
        """Dispatch on ``self.mode``.

        Modes:
            ``'loss'`` / ``'likelihood'``: ``inputs`` must be the 16 values
                ``(X, A, iw_ids, last_append_mask, NX, NX_rep, action_0,
                actions, log_p, batch_size, iw_size, graph_to_rnn,
                rnn_to_graph, NX_cum, c, ids)``. ``self._policy_0(c)`` is
                unsqueezed at axis 1 and tiled with reps ``[1, iw_size, 1]``.
                The per-step outputs come from ``self._policy``, and both
                are scored by ``self._likelihood``. ``'likelihood'`` returns
                that result unchanged; ``'loss'`` returns its negated mean.
            ``'decode_0'``: every value in ``inputs`` is forwarded to
                ``self._policy_0``.
            ``'decode_step'``: ``inputs`` must be the 9 values ``(X, A, NX,
                NX_rep, last_append_mask, NX_cum, h, c, ids)``, forwarded
                to ``self._decode_step``.

        Raises:
            ValueError: if ``self.mode`` is none of the modes above; the
                message names the offending mode.
        """
        mode = self.mode
        if mode in ('loss', 'likelihood'):
            (X, A, iw_ids, last_append_mask,
             NX, NX_rep, action_0, actions, log_p,
             batch_size, iw_size,
             graph_to_rnn, rnn_to_graph, NX_cum,
             c, ids) = inputs

            init = nd.tile(fn.unsqueeze(self._policy_0(c), axis=1),
                           [1, iw_size, 1])
            append, connect, end = self._policy(X, A, NX, NX_rep,
                                                last_append_mask, graph_to_rnn,
                                                rnn_to_graph, NX_cum, c, ids)
            likelihood = self._likelihood(init, append, connect, end,
                                          action_0, actions, iw_ids, log_p,
                                          batch_size, iw_size)
            if mode == 'likelihood':
                return likelihood
            return -likelihood.mean()
        if mode == 'decode_0':
            return self._policy_0(*inputs)
        if mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids = inputs
            return self._decode_step(X, A, NX, NX_rep, last_append_mask,
                                     NX_cum, h, c, ids)
        # Previously a bare ValueError; name the bad mode so the caller can
        # see what was misconfigured.
        raise ValueError(
            'Unknown mode {!r}; expected one of: loss, likelihood, '
            'decode_0, decode_step'.format(mode))
Ejemplo n.º 3
0
    def forward(self, X, NX, NX_rep, X_end=None):
        """Compute per-row append/connect probabilities and per-segment end probabilities.

        The rows of ``X`` are grouped into segments by ``NX_rep`` (presumably
        one segment per molecule -- confirm). Within each segment, the
        exponentiated append/connect scores of every row and the segment's
        end score share one denominator, so the whole segment's action space
        is normalized together. When ``self.k > 1`` this is done for ``k``
        components, which are mixed with weights ``pi`` produced by
        ``self.linear_pi`` (softmax over axis 1).

        Args:
            X: row features; rows are assigned to segments by ``NX_rep``.
            NX: divisor for the segment mean; ``NX.shape[0]`` is the number of
                segments (presumably the row count of each segment).
            NX_rep: segment id of every row of ``X``.
            X_end: optional per-segment features. If omitted, the segment
                mean of ``X`` is used.

        Returns:
            ``(append, connect, end)``:
            - ``append``: the columns after the first ``N_B``, reshaped to
              ``[-1, N_A, N_B]``.
            - ``connect``: the first ``N_B`` columns of each row.
            - ``end``: one value per segment (trailing axis squeezed).
        """
        # segment mean for X: segment sum divided by NX (cast to float32)
        if X_end is None:
            X_end = fn.SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(
                fn.unsqueeze(NX, 1), 'float32')
        # append each row's segment-level features (gathered by NX_rep) to it
        X = nd.concat(X, X_end[NX_rep, :], dim=1)

        # hidden layers; the reshape to [-1, F_h] suggests linear_h /
        # linear_h_t may emit k * F_h features per row -- TODO confirm
        X_h = nd.relu(self.linear_h(X)).reshape([-1, self.F_h])
        X_h_end = nd.relu(self.linear_h_t(X_end)).reshape([-1, self.F_h])

        # unnormalized scores exp(logits): [-1, k, N_B + N_B * N_A] per row
        # and [-1, k, 1] per segment.
        # NOTE(review): exp is applied to raw logits with no max-subtraction,
        # so large logits can overflow to inf -- consider a stabilized form.
        X_x = nd.exp(self.linear_x(X_h)).reshape(
            [-1, self.k, self.N_B + self.N_B * self.N_A])
        X_x_end = nd.exp(self.linear_x_t(X_h_end)).reshape([-1, self.k, 1])

        # shared denominator per segment and component: segment-summed row
        # scores, summed over the action axis (keepdims keeps a trailing
        # size-1 axis matching X_x_end), plus the segment's end score
        X_sum = nd.sum(fn.SegmentSumFn(NX_rep, NX.shape[0])(X_x),
                       -1,
                       keepdims=True) + X_x_end
        # broadcast each segment's denominator back to its rows
        X_sum_gathered = X_sum[NX_rep, :, :]

        X_softmax = X_x / X_sum_gathered
        X_softmax_end = X_x_end / X_sum

        if self.k > 1:
            # mixture weights per segment, softmax over the k components
            pi = fn.unsqueeze(nd.softmax(self.linear_pi(X_end), axis=1), -1)
            pi_gathered = pi[NX_rep, :, :]

            # weighted sum over the k axis
            X_softmax = nd.sum(X_softmax * pi_gathered, axis=1)
            X_softmax_end = nd.sum(X_softmax_end * pi, axis=1)
        else:
            # single component: just drop the k axis
            X_softmax = fn.squeeze(X_softmax, 1)
            X_softmax_end = fn.squeeze(X_softmax_end, 1)

        # generate output: first N_B columns are connect, the rest are append
        connect, append = X_softmax[:, :self.N_B], X_softmax[:, self.N_B:]
        append = append.reshape([-1, self.N_A, self.N_B])
        end = fn.squeeze(X_softmax_end, -1)

        return append, connect, end
Ejemplo n.º 4
0
    def _likelihood(self, init, append, connect, end, action_0, actions,
                    iw_ids, log_p_sigma, batch_size, iw_size):
        """Return an importance-weighted log-likelihood estimate per batch element.

        Args:
            init: first-action probabilities; reshaped to
                ``[batch_size * iw_size, N_A]``.
            append, connect, end: per-step probability tensors, read with
                ``nd.gather_nd`` (``end`` is used element-wise).
            action_0: chosen first action of each sample.
            actions: per-step records whose columns are ``(action_type,
                node_type, edge_type, append_pos, connect_pos)``.
                ``action_type`` selects which term counts for a step:
                0 -> append, 1 -> connect, 2 -> end.
            iw_ids: sample id of every step; step terms are summed per id.
            log_p_sigma: per-sample log-probability that is subtracted from
                log p(x) (presumably the proposal/ordering log-probability);
                reshaped to ``[batch_size, iw_size]``.
            batch_size, iw_size: batch elements and samples per element.

        Returns:
            ``logsumexp(log p(x) - log_p_sigma, axis=1) - log(iw_size)``,
            i.e. a log-mean-exp over the ``iw_size`` axis.
        """

        def masked_log(probs, mask):
            # log(probs) where mask is 1, zero elsewhere; 1e-10 guards log(0).
            return mask * nd.log(probs + 1e-10) + (1 - mask) * nd.zeros_like(
                probs)

        action_type, node_type, edge_type, append_pos, connect_pos = [
            actions[:, col] for col in range(5)
        ]
        num_samples = batch_size * iw_size

        # first action: gather init[i, action_0[i]] for each sample i
        init_table = init.reshape([num_samples, self.N_A])
        row_ids = nd.arange(action_0.shape[0], ctx=action_0.context,
                            dtype='int32')
        log_p_init = nd.log(
            nd.gather_nd(init_table, nd.stack(row_ids, action_0, axis=0)) +
            1e-10)

        # per-step terms, each counted only on steps of its action_type
        log_p_end = masked_log(end, nd.cast(action_type == 2, 'float32'))
        append_index = nd.stack(append_pos, node_type, edge_type, axis=0)
        log_p_append = masked_log(nd.gather_nd(append, append_index),
                                  nd.cast(action_type == 0, 'float32'))
        connect_index = nd.stack(connect_pos, edge_type, axis=0)
        log_p_connect = masked_log(nd.gather_nd(connect, connect_index),
                                   nd.cast(action_type == 1, 'float32'))

        # sum step terms per sample (grouped by iw_ids), add the first action
        step_log_p = log_p_end + log_p_append + log_p_connect
        sum_by_sample = fn.SegmentSumFn(iw_ids, num_samples)
        log_p_x = fn.squeeze(sum_by_sample(fn.unsqueeze(step_log_p, -1)),
                             -1) + log_p_init

        # log-mean-exp of the log importance weights over iw_size
        log_weights = (log_p_x.reshape([batch_size, iw_size]) -
                       log_p_sigma.reshape([batch_size, iw_size]))
        return fn.logsumexp(log_weights, axis=1) - math.log(float(iw_size))
Ejemplo n.º 5
0
    def _rnn_train(self, X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum):
        """Run ``self.rnn`` over all steps in one call (no explicit state).

        Each segment (selected by ``NX_rep``) is summarized by concatenating
        the segment mean of ``X`` with the row at ``NX_cum - 1``. The
        summaries are arranged into sequences with ``graph_to_rnn``, passed
        through ``self.rnn``, and mapped back with ``rnn_to_graph``.

        Args:
            X: row features, grouped into segments by ``NX_rep``.
            NX: divisor for the segment mean; ``NX.shape[0]`` is the number
                of segments (presumably the row count of each segment).
            NX_rep: segment id of every row of ``X``.
            graph_to_rnn: indices for ``nd.take`` that arrange summaries as
                ``[batch_size, iw_size, length, features]``.
            rnn_to_graph: indices for ``nd.gather_nd`` into the RNN output.
            NX_cum: offsets; ``NX_cum - 1`` picks one row of ``X`` per segment.

        Returns:
            The RNN output, reshaped to ``[batch_size, iw_size, length, -1]``
            and gathered with ``rnn_to_graph``.
        """
        counts = nd.cast(fn.unsqueeze(NX, 1), 'float32')
        seg_mean = fn.SegmentSumFn(NX_rep, NX.shape[0])(X) / counts
        last_rows = nd.take(X, indices=NX_cum - 1)
        summaries = nd.concat(seg_mean, last_rows, dim=1)

        # arrange into sequences; the 4-way unpack below requires the result
        # to be [batch_size, iw_size, length, num_features]
        sequences = nd.take(summaries, indices=graph_to_rnn)
        batch_size, iw_size, length, num_features = sequences.shape
        rnn_out = self.rnn(
            sequences.reshape([batch_size * iw_size, length, num_features]))

        # restore the (batch, iw) axes, then map back with rnn_to_graph
        rnn_out = rnn_out.reshape([batch_size, iw_size, length, -1])
        return nd.gather_nd(rnn_out, indices=rnn_to_graph)
Ejemplo n.º 6
0
 def unsqueeze(self, *positional, **options):
     """Method-style wrapper: return ``F.unsqueeze(self, ...)``.

     Every extra positional and keyword argument is forwarded unchanged;
     ``F`` is looked up in the enclosing module's scope.
     """
     return F.unsqueeze(self, *positional, **options)