def backward_impl(self, inputs, outputs, prop_down, accum):
    """Backward of the sum-pooling data-grad function.

    Args:
        inputs (list): ``[inputs_fwd_graph] + [inputs_bwd_graph]`` or
            ``[inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]``.
            Here ``inputs[1]`` holds ``dy`` (the incoming grad of the
            forward function) — presumably; layout inferred from the
            original comment, confirm against the caller.
        outputs (list): ``outputs[0]`` holds ``dx0``, the output of the
            data-grad function.
        prop_down (list of bool): Whether to propagate to each input.
        accum (list of bool): Whether to accumulate into each input grad.
    """
    # Hoist the repeated attribute-chain lookup of the forward function's
    # recorded arguments.
    args = self.forward_func.info.args
    kernel = args["kernel"]
    stride = args["stride"]
    ignore_border = args["ignore_border"]
    pad = args["pad"]
    channel_last = args["channel_last"]
    # TODO: support BHWC (channel_last) layout.
    assert not channel_last, "`channel_last = False` is only supported now."

    # Sum pooling is linear in its input, so dx0 does not depend on x0;
    # nothing is propagated to inputs[0] (prop_down[0] is a no-op here),
    # and the x0 / dx0 / g_x0 bindings of the original were unused.
    g_dy = inputs[1].grad
    g_dx0 = outputs[0].grad

    # Computation: the adjoint of the sum-pooling data-grad map is
    # sum_pooling itself.
    if prop_down[1]:
        g_dy_ = F.sum_pooling(g_dx0, kernel, stride,
                              ignore_border, pad, channel_last)
        if accum[1]:
            g_dy += g_dy_
        else:
            g_dy.copy_from(g_dy_)
def sum_pooling_data_grad_backward(inputs, kernel, stride=None,
                                   ignore_border=True, pad=None,
                                   channel_last=False):
    """Backward function of the sum-pooling data-grad.

    Args:
        inputs (list of nn.Variable): Incoming grads/inputs to/of the
            forward function; ``inputs[0]`` is the grad wrt the output of
            the data-grad function.
        kernel, stride, ignore_border, pad, channel_last: The corresponding
            ``sum_pooling`` function arguments.

    Returns:
        nn.Variable: The gradient wrt the input of the data-grad function
        (the original docstring incorrectly stated a list was returned).
    """
    gdx = inputs[0]
    # The data-grad of sum pooling is linear; its adjoint is sum_pooling.
    gdy = F.sum_pooling(gdx, kernel, stride, ignore_border, pad, channel_last)
    return gdy