def broadcast_params(self, optim_result):
    """
    Apply Broadcast operations in the sequential order of parameter groups.

    Parameters are bucketed per device using ``self.param_rank[i]`` as the
    owning device index (``0 .. self.dev_num - 1``). For each device ``root``,
    that device's parameter tuple is sent through ``P.Broadcast(root)``. Each
    broadcast result is then written back through the ref key built from the
    parameter's name (``P.MakeRefKey``). Finally a control-dependency chain is
    built: the first broadcast output of device 0 is tied to ``optim_result``,
    and each device's broadcast is tied to the next device's first output.
    This presumably forces the broadcasts to run after the optimizer update and
    in device order; TODO confirm against MindSpore's ``control_depend``
    semantics.

    Args:
        optim_result: Value passed as the first argument of the initial
            ``F.control_depend`` call (presumably the optimizer-step result).

    Returns:
        bool, the status flag. This is the value produced by the final
        ``F.depend`` / ``F.control_depend`` chain.
    """
    # One (initially empty) tuple of parameters and of ref keys per device.
    param_group = []
    key_group = []
    for _ in range(self.dev_num):
        param_group.append(F.make_tuple())
        key_group.append(F.make_tuple())
    # Append every parameter and its ref key to its owning device's tuples.
    # Both tuples grow in the same order, so position j in
    # param_group[r] corresponds to position j in key_group[r].
    for i in range(self.param_length):
        param_group[self.param_rank[i]] = param_group[
            self.param_rank[i]] + (self.parameters[i], )
        key = P.MakeRefKey(self.param_names[i])()
        key_group[
            self.param_rank[i]] = key_group[self.param_rank[i]] + (key, )
    new_param_group = []
    for root in range(self.dev_num):
        # Broadcast the parameters owned by `root`.
        ops = P.Broadcast(root)
        next_params = ops(param_group[root])
        new_param_group.append(next_params)
        # Write each broadcast value back through the matching ref key.
        for i in range(F.tuple_len(next_params)):
            F.assign(key_group[root][i], next_params[i])
    # NOTE(review): new_param_group[...][0] assumes every device owns at
    # least one parameter; an empty group would fail to index here.
    status = F.control_depend(optim_result, new_param_group[0][0])
    # Chain device i's broadcast before device i+1's first output, folding
    # each dependency into `status`.
    for i in range(self.dev_num - 1):
        status = F.depend(
            F.control_depend(new_param_group[i], new_param_group[i + 1][0]),
            status)
    return status
def construct(self, *inputs):
    """Run the wrapped network once and reduce-sum each of its outputs.

    Every output is passed through ``self.reduce_sum(output, None)``. Given
    the ``None`` axis argument, this is presumably a ReduceSum over all axes;
    confirm where ``self.reduce_sum`` is constructed.

    Args:
        *inputs: Positional inputs forwarded unchanged to ``self.network``.

    Returns:
        When ``self.output_num == 1``, the single reduced value. Otherwise a
        tuple with the reduced value of outputs ``0 .. output_num - 1`` in
        index order.
    """
    # Evaluate the network once and index into its result instead of
    # re-running it for every output index. Re-running multiplies the cost
    # and, for networks with stochastic or stateful ops, would mix outputs
    # taken from different forward passes.
    outputs = self.network(*inputs)
    if self.output_num == 1:
        return self.reduce_sum(outputs, None)
    ret = F.make_tuple()
    for index in range(self.output_num):
        predict_reduce = self.reduce_sum(outputs[index], None)
        # Tuples are immutable, so the result grows by concatenation.
        ret = ret + F.make_tuple(predict_reduce)
    return ret
def construct6(self, x1, x2, x3, x4, x5, x6):
    """Run the six-input network and scale each output by ``self.one``.

    ``self.one`` is first cast to each output's own dtype (via ``self.cast``
    and ``self.dtype``) and then multiplied in.

    Args:
        x1, x2, x3, x4, x5, x6: Inputs forwarded positionally to
            ``self.network``.

    Returns:
        The single scaled output when ``self.num_output == 1`` and
        ``self.output_is_tuple == 0``. Otherwise a tuple of the scaled
        outputs at indices ``0 .. num_output - 1``.
    """
    outputs = self.network(x1, x2, x3, x4, x5, x6)
    single_tensor = self.num_output == 1 and self.output_is_tuple == 0
    if single_tensor:
        return outputs * self.cast(self.one, self.dtype(outputs))
    result = F.make_tuple()
    for idx in range(self.num_output):
        item = outputs[idx]
        scaled = item * self.cast(self.one, self.dtype(item))
        result = result + F.make_tuple(scaled)
    return result