Ejemplo n.º 1
0
    def backward_impl(self, inputs, outputs, prop_down, accum):
        """Propagate gradients through this sum-pooling backward function.

        Only the gradient with respect to ``inputs[1]`` (``dy``) is produced.
        It is ``F.sum_pooling`` applied to ``outputs[0].grad``, using the
        pooling arguments of ``self.forward_func``.

        Args:
            inputs: Variables laid out as described in the comment below.
                ``inputs[1].grad`` receives the result.
            outputs: ``outputs[0].grad`` is the incoming gradient that is
                pooled.
            prop_down: Per-input flags. Only ``prop_down[1]`` is consulted.
            accum: Per-input flags. ``accum[1]`` selects accumulating into
                ``inputs[1].grad`` (True) or overwriting it (False).

        Raises:
            AssertionError: If the forward function was configured with
                ``channel_last=True``.
        """
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        # Pooling arguments of the forward function; reused unchanged below.
        args = self.forward_func.info.args
        kernel = args["kernel"]
        stride = args["stride"]
        ignore_border = args["ignore_border"]
        pad = args["pad"]
        channel_last = args["channel_last"]

        # TODO: BHWC
        assert not channel_last, "`channel_last = False` is only supported now."

        # NOTE(review): prop_down[0] (gradient w.r.t. inputs[0]) is never
        # written here. That is presumably intended because the computation
        # below does not involve inputs[0]. Confirm that the framework does
        # not expect an explicit zero fill when accum[0] is False.
        if not prop_down[1]:
            return

        # Incoming gradient of this function's output.
        g_dx0 = outputs[0].grad
        g_dy_ = F.sum_pooling(g_dx0, kernel, stride, ignore_border, pad,
                              channel_last)

        # Bind the gradient buffer to a local so that `+=` updates it in
        # place rather than rebinding an attribute.
        g_dy = inputs[1].grad
        if accum[1]:
            g_dy += g_dy_
        else:
            g_dy.copy_from(g_dy_)
Ejemplo n.º 2
0
def sum_pooling_data_grad_backward(inputs,
                                   kernel,
                                   stride=None,
                                   ignore_border=True,
                                   pad=None,
                                   channel_last=False):
    """Backward of the sum-pooling data-grad function.

    The incoming gradient ``inputs[0]`` is sum-pooled with the same pooling
    arguments as the corresponding function. Presumably this works because
    sum pooling is linear, so the backward of its data gradient is the
    pooling itself; confirm against the forward definition.

    Args:
      inputs (list of nn.Variable): Incoming gradients. Only ``inputs[0]``
        is used.
      kernel: Pooling kernel of the corresponding function.
      stride: Pooling stride. Forwarded unchanged, so ``None`` defers to
        ``F.sum_pooling``'s own default.
      ignore_border (bool): Forwarded unchanged to ``F.sum_pooling``.
      pad: Padding. Forwarded unchanged, so ``None`` defers to
        ``F.sum_pooling``'s own default.
      channel_last (bool): Forwarded unchanged to ``F.sum_pooling``.

    Returns:
      The result of ``F.sum_pooling`` applied to ``inputs[0]``. It is
      returned directly, not wrapped in a list.
    """
    gdx = inputs[0]
    # Arguments are passed positionally in F.sum_pooling's parameter order.
    gdy = F.sum_pooling(gdx, kernel, stride, ignore_border, pad, channel_last)
    return gdy