# Shared imports for the snippets below (they all assume the MXNet NDArray API).
import math
import random

import numpy as np
import mxnet as mx
from mxnet import nd, autograd


def topk(hm, k=100):
    """Pick the top-k peaks from a batch of class heatmaps.

    hm has shape (batch_size, num_classes, height, width); `nms` is a peak
    filter defined elsewhere.
    """
    ctx = hm.context
    batch_size, cat, height, width = hm.shape
    hm = nms(hm)
    # Flatten each heatmap to (batch_size, num_classes, height * width).
    hm = nd.reshape(hm, (0, 0, -1))
    topk_scores, topk_idx = nd.topk(hm, k=k, ret_typ='both')
    # Decompose the flat index into row (y) and column (x) coordinates.
    topk_y_idx = nd.floor(topk_idx / width)
    topk_y_idx = nd.reshape(topk_y_idx, (0, -1))
    topk_x_idx = topk_idx % width
    topk_x_idx = nd.reshape(topk_x_idx, (0, -1))
    # Rank the per-class candidates jointly across all classes.
    topk_scores = nd.reshape(topk_scores, (0, -1))
    topk_cat_scores, topk_cat_idx = nd.topk(topk_scores, k=k, ret_typ='both')
    cls_id = nd.floor(topk_cat_idx / k)
    # Build (batch, candidate) index pairs for gather_nd.
    batch_idx = nd.repeat(nd.arange(batch_size), repeats=k).reshape((1, -1))
    batch_idx = batch_idx.as_in_context(ctx)
    topk_cat_idx = nd.reshape(topk_cat_idx, (1, -1))
    topk_cat_indices = nd.concat(batch_idx, topk_cat_idx, dim=0)
    topk_cat_x_idx = nd.gather_nd(topk_x_idx, topk_cat_indices)
    topk_cat_x_idx = nd.reshape(topk_cat_x_idx, (batch_size, k))
    topk_cat_y_idx = nd.gather_nd(topk_y_idx, topk_cat_indices)
    topk_cat_y_idx = nd.reshape(topk_cat_y_idx, (batch_size, k))
    return topk_cat_x_idx, topk_cat_y_idx, cls_id
def get_pred_result(hm_pred, offset_pred, wh_pred, k=100):
    """Decode heatmap, offset and size predictions into (xmin, ymin, xmax, ymax, cls, score)."""
    ctx = hm_pred.context
    batch_size, num_classes, _, _ = hm_pred.shape
    topk_cat_x_idx, topk_cat_y_idx, cls_id = topk(hm_pred, k=k)
    # One (batch, class, y, x) index column per kept candidate.
    batch_index = nd.arange(batch_size)
    batch_indices = nd.repeat(batch_index, repeats=k)
    batch_indices = nd.reshape(batch_indices, (1, batch_size * k))
    batch_indices = batch_indices.as_in_context(ctx)
    cls_id = nd.reshape(cls_id, (1, batch_size * k))
    topk_cat_y_idx = nd.reshape(topk_cat_y_idx, (1, batch_size * k))
    topk_cat_x_idx = nd.reshape(topk_cat_x_idx, (1, batch_size * k))
    score_indices = nd.concat(batch_indices, cls_id, topk_cat_y_idx,
                              topk_cat_x_idx, dim=0)
    scores = nd.gather_nd(hm_pred, score_indices)
    # The offset and size heads have two channels; index channel 0 and channel 1.
    fake_idx_0 = nd.zeros((1, batch_size * k), ctx=ctx)
    fake_idx_1 = nd.ones((1, batch_size * k), ctx=ctx)
    fake_indices_0 = nd.concat(batch_indices, fake_idx_0, topk_cat_y_idx,
                               topk_cat_x_idx, dim=0)
    fake_indices_1 = nd.concat(batch_indices, fake_idx_1, topk_cat_y_idx,
                               topk_cat_x_idx, dim=0)
    x_offset = nd.gather_nd(offset_pred, fake_indices_0)
    y_offset = nd.gather_nd(offset_pred, fake_indices_1)
    h = nd.gather_nd(wh_pred, fake_indices_0)
    w = nd.gather_nd(wh_pred, fake_indices_1)
    # Refine the integer peak locations with the predicted sub-pixel offsets.
    topk_cat_x_idx = nd.broadcast_add(topk_cat_x_idx, x_offset)
    topk_cat_y_idx = nd.broadcast_add(topk_cat_y_idx, y_offset)
    # Convert center / size to corner coordinates.
    xmin = topk_cat_x_idx - w / 2
    ymin = topk_cat_y_idx - h / 2
    xmax = topk_cat_x_idx + w / 2
    ymax = topk_cat_y_idx + h / 2
    xmin = nd.reshape(xmin, (batch_size, k)).expand_dims(axis=-1)
    ymin = nd.reshape(ymin, (batch_size, k)).expand_dims(axis=-1)
    xmax = nd.reshape(xmax, (batch_size, k)).expand_dims(axis=-1)
    ymax = nd.reshape(ymax, (batch_size, k)).expand_dims(axis=-1)
    cls_id = nd.reshape(cls_id, (batch_size, k)).expand_dims(axis=-1)
    scores = nd.reshape(scores, (batch_size, k)).expand_dims(axis=-1)
    results = nd.concat(xmin, ymin, xmax, ymax, cls_id, scores, dim=-1)
    return results
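# Usage sketch (not from the original code): the decoder above relies on an
# `nms` helper that keeps only local heatmap maxima. The max-pool based
# stand-in below and the dummy shapes are assumptions, just to show the call
# pattern end to end.
def nms(hm, kernel=3):
    # Suppress non-peak locations by comparing each cell with its neighbourhood max.
    pad = kernel // 2
    hmax = nd.Pooling(hm, kernel=(kernel, kernel), pool_type='max',
                      stride=(1, 1), pad=(pad, pad))
    return hm * (hmax == hm)


hm_pred = nd.random.uniform(shape=(2, 80, 128, 128))     # class heatmaps
offset_pred = nd.random.uniform(shape=(2, 2, 128, 128))  # sub-pixel offsets
wh_pred = nd.random.uniform(shape=(2, 2, 128, 128))      # box height / width
results = get_pred_result(hm_pred, offset_pred, wh_pred, k=100)
# results has shape (2, 100, 6): xmin, ymin, xmax, ymax, cls_id, score per box.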
def _gather_beams(nested, beam_indices, batch_size, new_beam_size, cache=None):
    """Gather beams from a nested structure of tensors.

    Each tensor in `nested` represents a batch of beams, where a beam is a single
    search state (beam search explores multiple states in parallel). This function
    gathers the top beams, specified by beam_indices, from the nested tensors.

    Args:
        nested: list of tensors with shape [batch_size, beam_size, ...].
        beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
            value must lie in [0, beam_size); values are not necessarily unique.
        batch_size: int, size of the batch.
        new_beam_size: int, number of beams to pull from the nested tensors.
        cache: optional nested structure (tensor, list, tuple or dict) to gather
            from instead of `nested`.

    Returns:
        Nested structure containing tensors with shape
        [batch_size, new_beam_size, ...].
    """
    ctx = beam_indices.context
    # Batch coordinate of every gathered beam: [0, 0, ..., 1, 1, ...].
    batch_pos = np.arange(batch_size * new_beam_size) // new_beam_size
    batch_pos = nd.array(batch_pos, ctx=ctx, dtype='int32')
    batch_pos = nd.reshape(batch_pos, (batch_size, new_beam_size))
    beam_indices = nd.cast(beam_indices, dtype='int32')
    # (batch, beam) coordinate pairs, rearranged to the (2, batch, new_beam)
    # layout that gather_nd expects.
    coordinates = nd.stack(batch_pos, beam_indices, axis=2)
    coordinates_new = nd.transpose(coordinates, axes=(2, 0, 1))
    if cache is None:
        return [nd.gather_nd(t, coordinates_new) for t in nested]
    # map_structure (defined elsewhere) applies a function to every tensor in `cache`.
    return map_structure(lambda t: nd.gather_nd(t, coordinates_new), cache)
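# Usage sketch (assumption, not from the original code): keep the top-2 of 4
# beams for each of 2 batch elements.
seqs = nd.arange(2 * 4 * 3).reshape((2, 4, 3))   # [batch=2, beam_size=4, depth=3]
beam_idx = nd.array([[3, 0], [1, 2]])            # new_beam_size = 2
top_seqs = _gather_beams([seqs], beam_idx, batch_size=2, new_beam_size=2)[0]
# top_seqs has shape (2, 2, 3): beams 3 and 0 of the first item, 1 and 2 of the second.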
def forward(self, x, padding=None):
    ctx = x.context
    batch_size = x.shape[0]
    length = x.shape[1]
    if padding is not None:
        # Flatten padding to [batch_size * length].
        pad_mask = nd.reshape(padding, (-1,))
        nonpad_ids = nd.array(np.where(pad_mask.asnumpy() < 1e-9), ctx=ctx)
        # Reshape x to [batch_size * length, hidden_size] and drop the padded rows.
        x = nd.reshape(x, (-1, self.hidden_size))
        x = nd.gather_nd(x, indices=nonpad_ids)
        # Restore a (dummy) batch dimension so the dense layers see a 3-D input.
        x = nd.expand_dims(x, axis=0)
    output = self.filter_dense_layer(x)
    if self.train:
        output = self.dropout(output)
    output = self.output_dense_layer(output)
    if padding is not None:
        output = nd.squeeze(output, axis=0)
        # Scatter the non-padded rows back into the full [batch_size * length] layout.
        output = nd.scatter_nd(data=output, indices=nonpad_ids,
                               shape=(batch_size * length, self.hidden_size))
        output = nd.reshape(output,
                            shape=(batch_size, length, self.hidden_size))
    return output
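# Pattern sketch (assumption, not from the original code): the remove-padding /
# scatter-back trick used above, in isolation. Gather the non-pad rows, process
# them, then scatter the results back into the padded layout.
rows = nd.arange(12).reshape((4, 3))                  # 4 positions, hidden size 3
pad_mask = nd.array([0, 1, 0, 1])                     # 1 = padding
keep = nd.array(np.where(pad_mask.asnumpy() < 1e-9))  # shape (1, num_kept)
kept_rows = nd.gather_nd(rows, keep)                  # (2, 3): rows 0 and 2
restored = nd.scatter_nd(data=kept_rows, indices=keep, shape=(4, 3))
# `restored` equals `rows` at the kept positions and zeros at the padded ones.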
def _likelihood(self, init, append, connect, end, action_0, actions, iw_ids,
                log_p_sigma, batch_size, iw_size):
    # Decompose the action tuple.
    action_type, node_type, edge_type, append_pos, connect_pos = \
        actions[:, 0], actions[:, 1], actions[:, 2], actions[:, 3], actions[:, 4]
    # log(x) where the mask is on, zero elsewhere.
    _log_mask = lambda _x, _mask: _mask * nd.log(_x + 1e-10) + \
        (1 - _mask) * nd.zeros_like(_x)

    # init
    init = init.reshape([batch_size * iw_size, self.N_A])
    index = nd.stack(nd.arange(action_0.shape[0], ctx=action_0.context,
                               dtype='int32'),
                     action_0, axis=0)
    loss_init = nd.log(nd.gather_nd(init, index) + 1e-10)

    # end
    loss_end = _log_mask(end, nd.cast(action_type == 2, 'float32'))

    # append
    index = nd.stack(append_pos, node_type, edge_type, axis=0)
    loss_append = _log_mask(nd.gather_nd(append, index),
                            nd.cast(action_type == 0, 'float32'))

    # connect
    index = nd.stack(connect_pos, edge_type, axis=0)
    loss_connect = _log_mask(nd.gather_nd(connect, index),
                             nd.cast(action_type == 1, 'float32'))

    # Sum the per-step terms within each importance-weighted sample.
    log_p_x = loss_end + loss_append + loss_connect
    log_p_x = fn.squeeze(
        fn.SegmentSumFn(iw_ids, batch_size * iw_size)(fn.unsqueeze(log_p_x, -1)),
        -1)
    log_p_x = log_p_x + loss_init

    # Importance-weighted log-likelihood estimate.
    log_p_x = log_p_x.reshape([batch_size, iw_size])
    log_p_sigma = log_p_sigma.reshape([batch_size, iw_size])
    l = log_p_x - log_p_sigma
    l = fn.logsumexp(l, axis=1) - math.log(float(iw_size))
    return l
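# Pattern sketch (assumption, not from the original code): the stacked-index
# gather_nd plus masked-log trick used for the "append" term above, with toy
# tensors. Names and shapes here are illustrative only.
probs = nd.random.uniform(shape=(5, 4, 3))   # (position, node_type, edge_type)
append_pos = nd.array([0, 3])
node_type = nd.array([1, 2])
edge_type = nd.array([2, 0])
picked = nd.gather_nd(probs, nd.stack(append_pos, node_type, edge_type, axis=0))
mask = nd.array([1.0, 0.0])                  # only the first action is an "append"
loss_append = mask * nd.log(picked + 1e-10)  # zero contribution where masked out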
def learn(self, experiences, gamma):
    """Update value parameters using the given batch of experience tuples.

    Params
    ======
        experiences: tuple of (s, a, r, s', done, snake_id, turn_count, snake_health) tuples
        gamma (float): discount factor
    """
    states, actions, rewards, next_states, dones, snake_id, turn_count, \
        snake_health = experiences

    # Get max predicted Q values (for next states) from the target model.
    with autograd.predict_mode():
        if self.qnetwork_target.take_additional_forward_arguments:
            Q_targets_next = self.qnetwork_target(
                next_states, snake_id, turn_count,
                snake_health).max(1).expand_dims(1)
        else:
            Q_targets_next = self.qnetwork_target(next_states).max(1).expand_dims(1)

    # Compute Q targets for the current states.
    dones = dones.astype(np.float32)
    Q_targets = rewards[:, -1].expand_dims(1) + (
        gamma * Q_targets_next * (1 - dones[:, -1].expand_dims(1)))

    # Build (row, action) index pairs for the actions that were actually taken.
    last_action = actions[:, -1].expand_dims(1)
    action_indices = nd.arange(last_action.shape[0],
                               ctx=last_action.context).expand_dims(1)
    last_actions = nd.concat(action_indices, last_action, dim=1)

    with autograd.record():
        # Get expected Q values from the local model.
        if self.qnetwork_local.take_additional_forward_arguments:
            predicted_actions = self.qnetwork_local(states, snake_id, turn_count,
                                                    snake_health)
        else:
            predicted_actions = self.qnetwork_local(states)
        Q_expected = nd.gather_nd(predicted_actions, last_actions.T)
        # Compute the loss between expected and target Q values.
        loss = self.loss_function(Q_expected, Q_targets)

    # Minimize the loss.
    loss.backward()
    self.trainer.step(Q_expected.shape[0])

    # ------------------- update target network ------------------- #
    self.soft_update(self.tau)
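# Usage sketch (assumption, not from the original code): the gather_nd call
# above selects, for every sample, the Q-value of the action actually taken.
# Indices are stacked as (row, action) pairs and transposed for gather_nd.
q_values = nd.array([[0.1, 0.9, 0.3],
                     [0.4, 0.2, 0.8]])           # (batch, num_actions)
taken = nd.array([[1], [2]])                     # action chosen per sample
rows = nd.arange(q_values.shape[0]).expand_dims(1)
idx = nd.concat(rows, taken, dim=1)              # (batch, 2)
q_taken = nd.gather_nd(q_values, idx.T)          # -> [0.9, 0.8]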
def _rnn_train(self, X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum):
    # Average node features per graph, and take the feature of the newest node.
    X_avg = fn.SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(
        fn.unsqueeze(NX, 1), 'float32')
    X_curr = nd.take(X, indices=NX_cum - 1)
    X = nd.concat(X_avg, X_curr, dim=1)

    # Re-order the per-graph features into the padded RNN layout:
    # (batch_size, iw_size, length, num_features).
    X = nd.take(X, indices=graph_to_rnn)
    batch_size, iw_size, length, num_features = X.shape
    X = X.reshape([batch_size * iw_size, length, num_features])
    X = self.rnn(X)
    X = X.reshape([batch_size, iw_size, length, -1])

    # Map the RNN outputs back to the flat per-node layout.
    X = nd.gather_nd(X, indices=rnn_to_graph)
    return X
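# Pattern sketch (assumption, not from the original code): mapping padded RNN
# outputs back to a flat per-node layout, as the final gather_nd does above.
# The (batch, iw, step) coordinate layout of `rnn_to_graph` is inferred from
# the gather_nd shapes and is an assumption here.
rnn_out = nd.arange(2 * 1 * 4 * 3).reshape((2, 1, 4, 3))    # (batch, iw, length, features)
rnn_to_graph = nd.array([[0, 0, 1], [0, 0, 0], [1, 3, 2]])  # 3 nodes -> (batch, iw, step)
node_feats = nd.gather_nd(rnn_out, rnn_to_graph)            # shape (3, 3)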
# Sample a randomly jittered number of points per cloud and augment them.
# `setting`, `get_indices`, `get_xforms`, `augment`, `batch_size_train`,
# `point_num` and `ctx` are defined elsewhere in the training script.
label = batch.label[0]
labels_2d = nd.expand_dims(label, axis=-1)
pts_fts = batch.data[0]
bs = pts_fts.shape[0]
points2 = nd.slice(pts_fts, begin=(0, 0, 0), end=(None, None, 3))
# features2 = nd.slice(pts_fts, begin=(0, 0, 3), end=(None, None, None))

# Jitter the number of sampled points around setting.sample_num.
offset = int(random.gauss(0, setting.sample_num // 8))
offset = max(offset, -setting.sample_num // 4)
offset = min(offset, setting.sample_num // 4)
sample_num_train = setting.sample_num + offset

indices = get_indices(batch_size_train, sample_num_train, point_num)
indices_nd = nd.array(indices, dtype=np.int32)
points_sampled = nd.gather_nd(points2, indices=nd.transpose(indices_nd, (2, 0, 1)))
# features_sampled = nd.gather_nd(features2, indices=nd.transpose(indices_nd, (2, 0, 1)))

# Random rotation / scaling followed by jitter augmentation.
xforms_np, rotations_np = get_xforms(batch_size_train,
                                     rotation_range=setting.rotation_range,
                                     order=setting.order)
points_xformed = nd.batch_dot(points_sampled, nd.array(xforms_np),
                              name='points_xformed')
points_augmented = augment(points_sampled, nd.array(xforms_np), setting.jitter)
features_augmented = None

var = mx.sym.var('data', shape=(batch_size_train // len(ctx), sample_num_train,