Example #1
def set_tf_output(output_tensors):
  # Batch-norm statistics registered for the current model scope; each
  # entry carries the target variable at index 0 and its moment at index 2.
  res = get_collection(BATCHNORM_TENSORS, cur_model_scope())
  if res is None:
    return
  assert len(res) == len(output_tensors)
  update_ops = []
  for i in range(len(res)):
    # Fold the freshly computed statistic into the stored variable.
    update_op = xdl.ps_apply_moving_average_op(
      var_name=res[i][0].name, value=output_tensors[i], moment=res[i][2])
    update_ops.append(update_op)
  add_to_collection(UPDATE_OPS, update_ops)
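
set_tf_output only records the update ops; nothing here runs them. A minimal sketch of the consuming side, assuming the execute/get_collection helpers that appear in the other examples; the nested-list layout of the collection is an assumption, based on add_to_collection receiving a whole list per call:

# Sketch (assumption): fetch and run the collected moving-average updates.
collected = get_collection(UPDATE_OPS)
if collected is not None:
    for ops in collected:   # assumed: one list of ops per set_tf_output call
        for op in ops:
            execute(op)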
Example #2
  def test_all(self):
    # Register and initialize a 4-element float variable of ones.
    var = xdl.Variable(name="w",
                       dtype=DataType.float,
                       shape=[4],
                       initializer=xdl.Ones())
    execute(xdl.variable_registers())
    execute(xdl.global_initializers())
    # One moving-average step: w = moment * w + (1 - moment) * value.
    op = xdl.ps_apply_moving_average_op(var_name="w",
                                        moment=0.8,
                                        value=np.array([1, 2, 3, 3],
                                                       dtype=np.float32))
    execute(op)
    ret = execute(var.value)
    # 0.8 * [1, 1, 1, 1] + 0.2 * [1, 2, 3, 3] == [1., 1.2, 1.4, 1.4]
    self.assertTrue((ret == np.array([1., 1.2, 1.4, 1.4],
                                     dtype=np.float32)).all())
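
The asserted values pin down what the op computes: the stored variable decays toward the new value as new = moment * old + (1 - moment) * value. A plain NumPy check of that formula against the numbers used in the test:

import numpy as np

moment = 0.8
old = np.ones(4, dtype=np.float32)                # "w" starts as ones
value = np.array([1, 2, 3, 3], dtype=np.float32)
new = moment * old + (1 - moment) * value
print(new)  # [1.  1.2 1.4 1.4] -- the result the test asserts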
Example #3
        def _wrapper(*inputs, **kwargs):
            add_to_collection(BACKEND_DEVICE_TYPE, device_type.lower())
            # Replace every input with an MXNet placeholder and keep the
            # placeholder -> xdl tensor mapping in sym_input_dict.
            sym_input_dict = {}
            placeholders = []
            for x in inputs:
                placeholder = recursive_make_placeholder(x, sym_input_dict)
                placeholders.append(placeholder)

            # Gear inputs go through their own placeholders so that their
            # gradients can be routed back to them after the backward pass.
            gear_input_num = 0
            if 'gear_inputs' in kwargs:
                gear_inputs = kwargs['gear_inputs']
                gear_placeholder = recursive_make_placeholder(
                    gear_inputs, sym_input_dict, True)
                kwargs['gear_inputs'] = gear_placeholder
                gear_input_num = len(flatten(gear_inputs))

            model_outputs = model_func(*placeholders, **kwargs)
            if len(model_outputs) == 0:
                raise Exception('model_func must return loss')
            symbol_list = list(model_outputs)
            # Batch-norm statistics recorded during model construction as
            # (var_name, symbol, moment) tuples; their symbols are appended to
            # the graph outputs behind BlockGrad so no gradient flows into them.
            bn_statistic = get_collection(MXNET_BN_STATISTIC)
            bn_var_names = []
            bn_syms = []
            moments = []
            if bn_statistic is not None and len(bn_statistic) > 0:
                bn_var_names.extend([x[0] for x in bn_statistic])
                bn_syms.extend([x[1] for x in bn_statistic])
                moments.extend([x[2] for x in bn_statistic])

            symbol_list.extend([mx.sym.BlockGrad(x) for x in bn_syms])
            symbol = mx.sym.Group(symbol_list)
            # Bind once on CPU just to discover how many outputs and gradient
            # arrays the graph produces.
            executor = symbol.simple_bind(ctx=mx.cpu())
            add_variable_inputs(symbol,
                                sym_input_dict,
                                is_training=is_training)
            # Feed tensors in the exact order MXNet lists the graph arguments,
            # then append whichever auxiliary states we can supply.
            sym_names = symbol.list_arguments()
            xdl_inputs = []
            for sym in sym_names:
                xdl_inputs.append(sym_input_dict[sym])

            for aux in symbol.list_auxiliary_states():
                if aux in sym_input_dict:
                    xdl_inputs.append(sym_input_dict[aux])
                    sym_names.append(aux)

            target_size = len(executor.outputs)
            gradient_size = len(executor.grad_arrays)
            # Run the serialized MXNet graph through xdl's backend op, on CPU
            # or GPU depending on the requested device type.
            if device_type.lower() == 'cpu':
                outputs, gradients = xdl.mxnet_backend_op(
                    inputs=xdl_inputs,
                    var_name_str=','.join(sym_names),
                    device_type=device_type.lower(),
                    graph_def=serialize_graph(symbol),
                    target_size=target_size,
                    gradient_size=gradient_size if is_training else 0,
                    is_training=is_training,
                    init_grad=init_grad if init_grad is not None else np.array(
                        [], dtype=np.float32),
                    has_init_grad=init_grad is not None)
            else:
                with xdl.device('GPU'):
                    outputs, gradients = xdl.mxnet_backend_op(
                        inputs=xdl_inputs,
                        var_name_str=','.join(sym_names),
                        device_type=device_type.lower(),
                        graph_def=serialize_graph(symbol),
                        target_size=target_size,
                        gradient_size=gradient_size if is_training else 0,
                        is_training=is_training,
                        init_grad=init_grad if init_grad is not None else
                        np.array([], dtype=np.float32),
                        has_init_grad=init_grad is not None)

            bn_var_num = len(bn_var_names)
            if bn_var_num > 0:
                # Split off the trailing batch-norm outputs and turn each one
                # into a moving-average update of the stored statistic.
                bn_outputs = outputs[-bn_var_num:]
                outputs = outputs[:-bn_var_num]
                # Materialize the zip: it is stored in a collection and also
                # iterated below, which would exhaust a bare iterator.
                bn_update_infos = list(zip(bn_var_names, bn_outputs, moments))
                add_to_collection(BN_STATISTIC, bn_update_infos)
                update_ops = []
                for n, v, m in bn_update_infos:
                    update_op = xdl.ps_apply_moving_average_op(var_name=n,
                                                               value=v,
                                                               moment=m)
                    update_ops.append(update_op)
                add_to_collection(UPDATE_OPS, update_ops)

            if is_training:
                sym_names_ = []
                gradients_ = []
                if gear_input_num > 0:
                    # Separate gear-input gradients from ordinary variable
                    # gradients: the former are attached to the gear inputs,
                    # the latter are registered under the current model scope.
                    global _GEAR_INPUTS
                    gear_grads = [None] * gear_input_num
                    for i in range(len(sym_names)):
                        if sym_names[i] not in _GEAR_INPUTS:
                            gradients_.append(gradients[i])
                            sym_names_.append(sym_names[i])
                        else:
                            index = _GEAR_INPUTS.index(sym_names[i])
                            gear_grads[index] = gradients[i]
                    for i in range(len(gear_inputs)):
                        set_gear_gradient(gear_inputs[i], gear_grads[i])
                    add_to_collection(GEAR_GRAD, gear_grads, cur_model_scope())
                    set_gradients(sym_names_, gradients_, cur_model_scope())
                else:
                    set_gradients(sym_names, gradients, cur_model_scope())
            return outputs
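
All three examples follow the same pattern: statistics produced by a forward pass are folded into persistent variables through ps_apply_moving_average_op. A standalone NumPy sketch of that running-statistics loop (the names below are illustrative, not XDL API):

import numpy as np

# Illustrative only: a batch-norm running mean updated the same way,
# running = moment * running + (1 - moment) * batch_stat.
moment = 0.8
running_mean = np.zeros(3, dtype=np.float32)
for _ in range(10):
    batch_mean = np.array([0.5, -1.0, 2.0], dtype=np.float32)  # fake batch stat
    running_mean = moment * running_mean + (1 - moment) * batch_mean
print(running_mean)  # converges toward [0.5, -1.0, 2.0]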