Example #1
0
    def broadcast_params(self, optim_result):
        """
        Apply Broadcast operations in the sequential order of parameter groups.

        Parameters are bucketed by the device index stored in
        ``self.param_rank[i]``. Each device's bucket is passed through
        ``P.Broadcast(root)`` with that device index as ``root``. The results
        are assigned back with ``F.assign`` through ref keys built from
        ``self.param_names``. Control dependencies are then chained so that
        the broadcast groups are ordered after ``optim_result`` and after one
        another.

        Args:
            optim_result: Presumably the output of the preceding optimizer
                update step (verify against caller). It is used only as the
                source of the first control dependency.

        Returns:
             bool, the status flag. NOTE(review): the returned value is
             whatever the chained ``F.control_depend`` / ``F.depend`` calls
             produce; confirm that it is really a bool.
        """
        # One initially empty tuple per device index, for that device's
        # parameters and for the matching ref keys.
        param_group = []
        key_group = []
        for _ in range(self.dev_num):
            param_group.append(F.make_tuple())
            key_group.append(F.make_tuple())
        # Bucket parameter i, and a ref key made from its name, under the
        # device index self.param_rank[i].
        for i in range(self.param_length):
            param_group[self.param_rank[i]] = param_group[
                self.param_rank[i]] + (self.parameters[i], )
            key = P.MakeRefKey(self.param_names[i])()
            key_group[
                self.param_rank[i]] = key_group[self.param_rank[i]] + (key, )
        # Broadcast each bucket with its device index as root, then assign
        # every received value back to its parameter via the ref key.
        new_param_group = []
        for root in range(self.dev_num):
            ops = P.Broadcast(root)
            next_params = ops(param_group[root])
            new_param_group.append(next_params)
            for i in range(F.tuple_len(next_params)):
                F.assign(key_group[root][i], next_params[i])
        # Order the first broadcast group after optim_result, then each group
        # after the previous one. F.depend folds each link into `status`.
        # NOTE(review): new_param_group[k][0] assumes every device index owns
        # at least one parameter. A device absent from param_rank keeps an
        # empty bucket, and indexing [0] on it would likely fail. Confirm
        # that param_rank covers all dev_num devices.
        status = F.control_depend(optim_result, new_param_group[0][0])
        for i in range(self.dev_num - 1):
            status = F.depend(
                F.control_depend(new_param_group[i],
                                 new_param_group[i + 1][0]), status)

        return status
Example #2
0
 def construct(self, *inputs):
     """Run the wrapped network once and reduce-sum each of its outputs.

     Args:
         *inputs: Positional inputs forwarded unchanged to ``self.network``.

     Returns:
         When ``self.output_num == 1``: ``self.reduce_sum`` applied to the
         network result with axis argument ``None``.
         Otherwise: a tuple holding, for each index in
         ``range(self.output_num)``, ``self.reduce_sum`` applied to that
         element of the network result.
     """
     # Run the network exactly once. The previous version called
     # self.network(*inputs) inside the loop. That repeated the whole forward
     # pass for every output index and, if the network is non-deterministic
     # (e.g. has dropout), mixed outputs taken from different passes.
     outputs = self.network(*inputs)
     if self.output_num == 1:
         return self.reduce_sum(outputs, None)
     ret = F.make_tuple()
     for index in range(self.output_num):
         predict_reduce = self.reduce_sum(outputs[index], None)
         ret = ret + F.make_tuple(predict_reduce)
     return ret
Example #3
0
 def construct6(self, x1, x2, x3, x4, x5, x6):
     """Run the network on six inputs and scale each output by ``self.one``.

     Every output is multiplied by ``self.one`` after it is cast, via
     ``self.cast``, to that output's own dtype as reported by ``self.dtype``.

     Args:
         x1, x2, x3, x4, x5, x6: Positional inputs forwarded to ``self.network``.

     Returns:
         The scaled network result itself when ``self.num_output == 1`` and
         ``self.output_is_tuple == 0``. Otherwise, a tuple of the first
         ``self.num_output`` elements of the result, each scaled.
     """
     outputs = self.network(x1, x2, x3, x4, x5, x6)
     single_tensor = self.num_output == 1 and self.output_is_tuple == 0
     if single_tensor:
         return outputs * self.cast(self.one, self.dtype(outputs))

     scaled = F.make_tuple()
     for idx in range(self.num_output):
         item = outputs[idx]
         factor = self.cast(self.one, self.dtype(item))
         scaled = scaled + F.make_tuple(item * factor)
     return scaled