Example #1
from mxnet import nd

def topk(hm, k=100):

    ctx = hm.context
    
    batch_size, cat, height, width = hm.shape

    # Suppress non-peak responses, then take a per-class top-k over the
    # flattened (height * width) heatmap.
    hm = nms(hm)
    hm = nd.reshape(hm, (0, 0, -1))
    topk_scores, topk_idx = nd.topk(hm, k=k, ret_typ='both')
    
    # Recover 2-D peak coordinates from the flattened index: idx = y * width + x.
    topk_y_idx = nd.floor(topk_idx / width)
    topk_y_idx = nd.reshape(topk_y_idx, (0, -1))

    topk_x_idx = topk_idx % width
    topk_x_idx = nd.reshape(topk_x_idx, (0, -1))
    
    # Take an overall top-k across all classes; each class contributed k
    # consecutive candidates, so the class id is the candidate index // k.
    topk_scores = nd.reshape(topk_scores, (0, -1))
    topk_cat_scores, topk_cat_idx = nd.topk(topk_scores, k=k, ret_typ='both')

    cls_id = nd.floor(topk_cat_idx/k)
    
    batch_idx = nd.repeat(nd.arange(batch_size), repeats=k).reshape((1, -1))
    batch_idx = batch_idx.as_in_context(ctx)
    topk_cat_idx = nd.reshape(topk_cat_idx, (1, -1))
    topk_cat_indices = nd.concat(batch_idx, topk_cat_idx, dim=0)

    topk_cat_x_idx = nd.gather_nd(topk_x_idx, topk_cat_indices)
    topk_cat_x_idx = nd.reshape(topk_cat_x_idx, (batch_size, k))

    topk_cat_y_idx = nd.gather_nd(topk_y_idx, topk_cat_indices)
    topk_cat_y_idx = nd.reshape(topk_cat_y_idx, (batch_size, k))
    
    return topk_cat_x_idx, topk_cat_y_idx, cls_id
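A minimal usage sketch for the function above. The `nms` helper is not shown on this page; the 3x3 max-pool peak filter below is an assumed stand-in:

import mxnet as mx
from mxnet import nd

def nms(hm, kernel=3):
    # Keep only local maxima: a peak survives where it equals the 3x3 max.
    hmax = nd.Pooling(hm, kernel=(kernel, kernel), stride=(1, 1),
                      pad=(kernel // 2, kernel // 2), pool_type='max')
    return hm * (hmax == hm)

hm = nd.random.uniform(shape=(2, 80, 128, 128))  # (batch, classes, H, W)
xs, ys, cls_id = topk(hm, k=100)
print(xs.shape, ys.shape, cls_id.shape)          # (2, 100) (2, 100) (2, 100)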
Example #2
from mxnet import nd

def get_pred_result(hm_pred, offset_pred, wh_pred, k=100):
    ctx = hm_pred.context
    batch_size, num_classes, _, _ = hm_pred.shape
    topk_cat_x_idx, topk_cat_y_idx, cls_id = topk(hm_pred, k=k)
    
    # One batch index per detection: repeat each batch id k times (the
    # original repeated by num_classes, which breaks the reshape below
    # whenever num_classes != k).
    batch_index = nd.arange(batch_size)
    batch_indices = nd.repeat(batch_index, repeats=k)
    batch_indices = nd.reshape(batch_indices, (1, batch_size*k))
    batch_indices = batch_indices.as_in_context(ctx)
    
    cls_id = nd.reshape(cls_id, (1, batch_size*k))
    topk_cat_y_idx = nd.reshape(topk_cat_y_idx, (1, batch_size*k))
    topk_cat_x_idx = nd.reshape(topk_cat_x_idx, (1, batch_size*k))
    
    score_indices = nd.concat(batch_indices, cls_id, topk_cat_y_idx, topk_cat_x_idx, dim=0)
    
    scores = nd.gather_nd(hm_pred, score_indices)
    
    # Channel indices 0 and 1 select the two planes of the 2-channel heads.
    fake_idx_0 = nd.zeros((1, batch_size*k), ctx=ctx)
    fake_idx_1 = nd.ones((1, batch_size*k), ctx=ctx)

    fake_indices_0 = nd.concat(batch_indices, fake_idx_0, topk_cat_y_idx, topk_cat_x_idx, dim=0)
    fake_indices_1 = nd.concat(batch_indices, fake_idx_1, topk_cat_y_idx, topk_cat_x_idx, dim=0)
    x_offset = nd.gather_nd(offset_pred, fake_indices_0)
    y_offset = nd.gather_nd(offset_pred, fake_indices_1)

    # Channel order assumed by this example: wh_pred stores h in channel 0
    # and w in channel 1.
    h = nd.gather_nd(wh_pred, fake_indices_0)
    w = nd.gather_nd(wh_pred, fake_indices_1)

    # Refine the integer peak coordinates with the predicted sub-pixel
    # offsets (the original multiplied coordinate by offset before adding,
    # which is not the standard additive refinement).
    topk_cat_x_idx = topk_cat_x_idx + x_offset
    topk_cat_y_idx = topk_cat_y_idx + y_offset

    xmin = topk_cat_x_idx - w/2
    ymin = topk_cat_y_idx - h/2
    xmax = topk_cat_x_idx + w/2
    ymax = topk_cat_y_idx + h/2
    
    xmin = nd.reshape(xmin, (batch_size, k)).expand_dims(axis=-1)
    ymin = nd.reshape(ymin, (batch_size, k)).expand_dims(axis=-1)
    xmax = nd.reshape(xmax, (batch_size, k)).expand_dims(axis=-1)
    ymax = nd.reshape(ymax, (batch_size, k)).expand_dims(axis=-1)
    cls_id = nd.reshape(cls_id, (batch_size, k)).expand_dims(axis=-1)
    scores = nd.reshape(scores, (batch_size, k)).expand_dims(axis=-1)

    results = nd.concat(xmin, ymin, xmax, ymax, cls_id, scores, dim=-1)

    return results
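A hedged end-to-end sketch with random CenterNet-style heads (shapes assumed: a 2-channel offset head and a 2-channel size head on the same grid):

from mxnet import nd

hm_pred = nd.random.uniform(shape=(2, 80, 128, 128))
offset_pred = nd.random.uniform(shape=(2, 2, 128, 128))
wh_pred = nd.random.uniform(shape=(2, 2, 128, 128))
results = get_pred_result(hm_pred, offset_pred, wh_pred, k=100)
print(results.shape)  # (2, 100, 6): xmin, ymin, xmax, ymax, class id, score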
Example #3
import numpy as np
from mxnet import nd

def _gather_beams(tensor_list, beam_indices, batch_size, new_beam_size, cache=None):
    """Gather beams from nested structure of tensors.

    Each tensor represents a batch of beams, where a beam refers to a single
    search state (beam search searches through multiple states in parallel).

    This function gathers the top beams, specified by beam_indices, from
    tensor_list (and, optionally, from a nested cache of tensors).

    Args:
      tensor_list: list of tensors with shape [batch_size, beam_size, ...].
      beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
        value in beam_indices must be in [0, beam_size); values need not be
        unique.
      batch_size: int, size of the batch.
      new_beam_size: int, number of beams to pull from the input tensors.
      cache: optional nested structure (tensor, list, tuple or dict) of
        tensors with the same leading [batch_size, beam_size] shape.

    Returns:
      The gathered tensors (or nested structure, when cache is given) with
        shape [batch_size, new_beam_size, ...]
    """
    ctx = beam_indices.context
    # batch_pos[i, j] = i: the batch index repeated new_beam_size times.
    batch_pos = np.arange(0, batch_size * new_beam_size) // new_beam_size
    batch_pos = nd.array(batch_pos, ctx=ctx, dtype='int32')
    batch_pos = nd.reshape(batch_pos, (batch_size, new_beam_size))
    beam_indices = nd.cast(beam_indices, dtype='int32')

    # Stack into (batch, new_beam, 2) coordinate pairs, then move the
    # coordinate axis to the front as gather_nd expects: (2, batch, new_beam).
    # This replaces the original element-wise copy loops, which also relied
    # on the Python 2-only xrange.
    coordinates = nd.stack(batch_pos, beam_indices, axis=2)
    coordinates_new = nd.transpose(coordinates, axes=(2, 0, 1))

    if cache is None:
        for i in range(len(tensor_list)):
            tensor_list[i] = nd.gather_nd(tensor_list[i], coordinates_new)
        return tensor_list
    else:
        # map_structure (as in TF's nest utility) applies the gather to
        # every tensor inside the nested cache.
        cache = map_structure(lambda t: nd.gather_nd(t, coordinates_new),
                              cache)
        return cache
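A small sketch of the cache-less path (values chosen by hand):

import numpy as np
from mxnet import nd

seqs = nd.arange(24).reshape((2, 3, 4))   # [batch=2, beam=3, length=4]
beam_idx = nd.array([[1, 0], [2, 1]])     # keep 2 of the 3 beams per item
out = _gather_beams([seqs], beam_idx, batch_size=2, new_beam_size=2)
print(out[0].shape)                       # (2, 2, 4)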
Example #4
    def forward(self, x, padding=None):
        ctx = x.context
        batch_size = x.shape[0]
        length = x.shape[1]
        if padding is not None:
            # Flatten padding to [batch_size * length]
            pad_mask = nd.reshape(padding, (-1,))
            nonpad_ids = nd.array(np.where(pad_mask.asnumpy() < 1e-9), ctx=ctx)

            # Reshape x to [batch_size*length, hidden_size] to remove padding
            x = nd.reshape(x, (-1, self.hidden_size))
            x = nd.gather_nd(x, indices=nonpad_ids)

            # Reshape x from 2 dimensions to 3 dimensions
            x = nd.expand_dims(x, axis=0)

        output = self.filter_dense_layer(x)
        if self.train:
            output = self.dropout(output)
        output = self.output_dense_layer(output)

        if padding is not None:
            output = nd.squeeze(output, axis=0)
            output = nd.scatter_nd(data=output,
                                   indices=nonpad_ids,
                                   shape=(batch_size * length,
                                          self.hidden_size))
            output = nd.reshape(output,
                                shape=(batch_size, length, self.hidden_size))

        return output
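The remove-then-restore pattern above can be demonstrated standalone (all shapes below are made up for illustration):

import numpy as np
from mxnet import nd

x = nd.arange(24).reshape((2, 3, 4))        # [batch=2, length=3, hidden=4]
padding = nd.array([[0, 0, 1], [0, 1, 1]])  # 1 marks padded positions
flat = x.reshape((-1, 4))                   # [batch*length, hidden]
ids = nd.array(np.where(padding.asnumpy().reshape(-1) < 1e-9))
compact = nd.gather_nd(flat, ids)           # only the 3 real tokens survive
restored = nd.scatter_nd(data=compact, indices=ids, shape=(6, 4))
print(compact.shape, restored.shape)        # (3, 4) (6, 4)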
Example #5
    def _likelihood(self, init, append, connect, end, action_0, actions,
                    iw_ids, log_p_sigma, batch_size, iw_size):

        # decompose action:
        action_type, node_type, edge_type, append_pos, connect_pos = \
            actions[:, 0], actions[:, 1], actions[:, 2], actions[:, 3], actions[:, 4]
        # Masked log: contributes log(x) only where the mask is 1, else 0.
        _log_mask = lambda _x, _mask: _mask * nd.log(_x + 1e-10) + (
            1 - _mask) * nd.zeros_like(_x)

        # init
        init = init.reshape([batch_size * iw_size, self.N_A])
        index = nd.stack(nd.arange(action_0.shape[0],
                                   ctx=action_0.context,
                                   dtype='int32'),
                         action_0,
                         axis=0)
        loss_init = nd.log(nd.gather_nd(init, index) + 1e-10)

        # end
        loss_end = _log_mask(end, nd.cast(action_type == 2, 'float32'))

        # append
        index = nd.stack(append_pos, node_type, edge_type, axis=0)
        loss_append = _log_mask(nd.gather_nd(append, index),
                                nd.cast(action_type == 0, 'float32'))

        # connect
        index = nd.stack(connect_pos, edge_type, axis=0)
        loss_connect = _log_mask(nd.gather_nd(connect, index),
                                 nd.cast(action_type == 1, 'float32'))

        # sum up results
        log_p_x = loss_end + loss_append + loss_connect
        log_p_x = fn.squeeze(
            fn.SegmentSumFn(iw_ids,
                            batch_size * iw_size)(fn.unsqueeze(log_p_x, -1)),
            -1)
        log_p_x = log_p_x + loss_init

        # reshape
        log_p_x = log_p_x.reshape([batch_size, iw_size])
        log_p_sigma = log_p_sigma.reshape([batch_size, iw_size])
        l = log_p_x - log_p_sigma
        l = fn.logsumexp(l, axis=1) - math.log(float(iw_size))
        return l
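The last few lines compute an importance-weighted likelihood: with iw_size samples per item, l = logsumexp(log p(x) - log q(x)) - log(iw_size). A tiny numeric check of that identity, with fn.logsumexp swapped for a naive version (safe here because the values are small):

import math
from mxnet import nd

log_w = nd.array([[-1.0, -2.0, -0.5]])                 # iw_size = 3
l = nd.log(nd.exp(log_w).sum(axis=1)) - math.log(3.0)  # naive logsumexp
print(l)                                               # approx [-0.9945]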
Example #6
    def learn(self, experiences, gamma):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences: tuple of (s, a, r, s', done, snake_id, turn_count, snake_health) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones, snake_id, turn_count, snake_health = experiences

        # Get max predicted Q values (for next states) from target model
        with autograd.predict_mode():
            if self.qnetwork_target.take_additional_forward_arguments:
                Q_targets_next = self.qnetwork_target(
                    next_states, snake_id, turn_count,
                    snake_health).max(1).expand_dims(1)
            else:
                Q_targets_next = self.qnetwork_target(next_states).max(
                    1).expand_dims(1)

        # Compute Q targets for current states
        dones = dones.astype(np.float32)
        Q_targets = rewards[:, -1].expand_dims(1) + (
            gamma * Q_targets_next * (1 - dones[:, -1].expand_dims(1)))

        # Get expected Q values from local model
        last_action = actions[:, -1].expand_dims(1)
        # Pair each row index with the action taken in that row; these are
        # constant indices, so no gradient needs to be attached (the original
        # used an undefined global ctx and an unnecessary attach_grad()).
        action_indices = nd.arange(last_action.shape[0],
                                   ctx=last_action.context).expand_dims(1)
        last_actions = nd.concat(action_indices, last_action, dim=1)

        with autograd.record():
            if self.qnetwork_local.take_additional_forward_arguments:
                predicted_actions = self.qnetwork_local(
                    states, snake_id, turn_count, snake_health)
            else:
                predicted_actions = self.qnetwork_local(states)

            Q_expected = nd.gather_nd(predicted_actions, last_actions.T)

            # Compute loss
            loss = self.loss_function(Q_expected, Q_targets)

        # Minimize the loss
        loss.backward()
        self.trainer.step(Q_expected.shape[0])

        # ------------------- update target network ------------------- #
        self.soft_update(self.tau)
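The gather step at the heart of learn() can be checked in isolation: a (2, batch) index array picks Q(s, a) for each row's chosen action.

from mxnet import nd

q = nd.array([[0.1, 0.9], [0.7, 0.3]])  # [batch, num_actions]
acts = nd.array([[1.0], [0.0]])         # actions taken
rows = nd.arange(2).expand_dims(1)
pairs = nd.concat(rows, acts, dim=1)    # [[0, 1], [1, 0]]
print(nd.gather_nd(q, pairs.T))         # [0.9, 0.7]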
Example #7
    def _rnn_train(self, X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum):
        # Mean node embedding per graph: segment-sum over nodes / node count.
        X_avg = fn.SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(
            fn.unsqueeze(NX, 1), 'float32')
        # Embedding of the most recently added node in each graph.
        X_curr = nd.take(X, indices=NX_cum - 1)
        X = nd.concat(X_avg, X_curr, dim=1)

        # rnn
        X = nd.take(
            X,
            indices=graph_to_rnn)  # batch_size, iw_size, length, num_features
        batch_size, iw_size, length, num_features = X.shape
        X = X.reshape([batch_size * iw_size, length, num_features])
        X = self.rnn(X)

        X = X.reshape([batch_size, iw_size, length, -1])
        X = nd.gather_nd(X, indices=rnn_to_graph)

        return X
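For reference, gather_nd with a (3, N) index selects entries of a 4-D tensor by (batch, iw, step) triples, which is what rnn_to_graph supplies above (shapes here are hypothetical):

from mxnet import nd

X = nd.arange(48).reshape((2, 2, 3, 4))
idx = nd.array([[0, 1], [1, 0], [2, 2]])  # two (batch, iw, step) triples
print(nd.gather_nd(X, idx).shape)         # (2, 4)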
Example #8
        label = batch.label[0]
        labels_2d = nd.expand_dims(label, axis=-1)
        pts_fts = batch.data[0]
        bs = pts_fts.shape[0]
        points2 = nd.slice(pts_fts, begin=(0, 0, 0), end=(None, None, 3))
        #features2 = nd.slice(pts_fts, begin=(0,0,3), end= (None, None, None))

        offset = int(random.gauss(0, setting.sample_num // 8))
        offset = max(offset, -setting.sample_num // 4)
        offset = min(offset, setting.sample_num // 4)
        sample_num_train = setting.sample_num + offset

        indices = get_indices(batch_size_train, sample_num_train, point_num)
        indices_nd = nd.array(indices, dtype=np.int32)
        # gather_nd expects the index axis first, (2, batch, sample_num),
        # matching the transpose in the commented-out features line below.
        points_sampled = nd.gather_nd(points2,
                                      indices=nd.transpose(indices_nd, (2, 0, 1)))
        #features_sampled = nd.gather_nd(features2, indices=nd.transpose(indices_nd, (2, 0, 1)))

        xforms_np, rotations_np = get_xforms(
            batch_size_train,
            rotation_range=setting.rotation_range,
            order=setting.order)
        points_xformed = nd.batch_dot(points_sampled,
                                      nd.array(xforms_np),
                                      name='points_xformed')
        points_augmented = augment(points_sampled, nd.array(xforms_np),
                                   setting.jitter)
        features_augmented = None

        var = mx.sym.var('data',
                         shape=(batch_size_train // len(ctx), sample_num_train,