import numpy as np
import torch
import torch.nn.functional as F
import polymath as pm


def define_graph(self, inp, weight, grad, inp_grad, weight_grad,
                 optimizer, optimizer_kwargs, stride=1, pad=0, dilation=1):
    # For each spatial dim, compute the smallest input size the forward conv
    # could have produced `grad`'s shape from, then derive the output padding
    # needed so conv_transpose reproduces the original input shape exactly.
    min_sizes = []
    k = len(grad.shape) - 2
    for d in range(k):
        min_sizes.append((grad.shape[d + 2] - 1) * stride - 2 * pad
                         + (weight.shape[-1] - 1) * dilation + 1)
    grad_input_padding = tuple(inp.shape[-k + d] - min_sizes[d] for d in range(k))
    # The derivation assumes a square kernel and equal output padding in both
    # spatial dims.
    assert grad_input_padding[0] == grad_input_padding[1]

    # Input gradient: transposed convolution of the output gradient with the weights.
    pm.conv_transpose(grad, weight, inp_grad, stride=stride, pad=pad,
                      out_pad=grad_input_padding[0])

    # Weight gradient: convolve the channel-transposed input with the
    # channel-transposed output gradient, with stride and dilation swapped.
    inp_transposed = pm.temp(name=f"transposed_{inp.name}",
                             shape=(inp.shape[1], inp.shape[0], inp.shape[2], inp.shape[3]))
    grad_transposed = pm.state(name=f"transposed_{grad.name}",
                               shape=(grad.shape[1], grad.shape[0], grad.shape[2], grad.shape[3]))
    wgt_grad_transposed = pm.temp(name=f"transposed_{weight.name}",
                                  shape=(weight.shape[1], weight.shape[0], weight.shape[2], weight.shape[3]))
    pm.tensor_transpose(inp, inp_transposed, perm=(1, 0, 2, 3))
    pm.tensor_transpose(grad, grad_transposed, perm=(1, 0, 2, 3))
    pm.conv(inp_transposed, grad_transposed, wgt_grad_transposed,
            stride=dilation, pad=pad, dilation=stride)
    pm.tensor_transpose(wgt_grad_transposed, weight_grad, perm=(1, 0, 2, 3))

    # Weight update
    OPTIMIZERS[optimizer](weight, weight_grad, **optimizer_kwargs)
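# The min_sizes / out_pad arithmetic above mirrors PyTorch's input-gradient
# shape derivation. A standalone sketch checking the same formula against
# torch.nn.functional (the shapes below are illustrative, not from this file):
def _check_grad_input_padding_sketch():
    inp = torch.randn(1, 3, 8, 8)
    wgt = torch.randn(8, 3, 3, 3)
    stride, pad, dilation = 2, 1, 1
    out = F.conv2d(inp, wgt, stride=stride, padding=pad, dilation=dilation)
    grad = torch.ones_like(out)
    # Smallest input height that could have produced `out` with these parameters.
    min_size = (out.shape[2] - 1) * stride - 2 * pad + (wgt.shape[-1] - 1) * dilation + 1
    out_pad = inp.shape[2] - min_size  # here: 8 - 7 = 1
    # conv_transpose with the derived output padding recovers the input's shape.
    inp_grad = F.conv_transpose2d(grad, wgt, stride=stride, padding=pad,
                                  dilation=dilation, output_padding=out_pad)
    assert inp_grad.shape == inp.shape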
def get_conv_transpose(x, w, bias=None, dilations=None, group=None, kernel_shape=None,
                       pads=None, auto_pad=None, output_padding=None, strides=None,
                       shape=None, name=None, out=None):
    if out is None:
        out = pm.output(shape=shape, name=name)

    # Resolve ONNX-style auto_pad into explicit [top, bottom, left, right] pads.
    if auto_pad:
        h_out = int(np.ceil(x.shape[-2] / strides[0]))
        w_out = int(np.ceil(x.shape[-1] / strides[1]))
        ph = max(0, (h_out - 1) * strides[0] + kernel_shape[0] - x.shape[-2])
        pw = max(0, (w_out - 1) * strides[1] + kernel_shape[1] - x.shape[-1])
        pads = [0, 0, 0, 0]
        if auto_pad == "SAME_LOWER":
            pads[0] = ph // 2
            pads[1] = ph - pads[0]
            pads[2] = pw // 2
            pads[3] = pw - pads[2]
        elif auto_pad == "SAME_UPPER":
            pads[1] = ph // 2
            pads[0] = ph - pads[1]
            pads[3] = pw // 2
            pads[2] = pw - pads[3]

    if bias is not None:
        pm.conv_transpose_bias(x, w, bias, out, int(strides[0]), int(pads[-2]),
                               out_pad=output_padding)
    else:
        pm.conv_transpose(x, w, out, int(strides[0]), int(pads[-2]),
                          out_pad=output_padding)
    return out
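# Usage sketch for get_conv_transpose, assuming `pm` is the PolyMath module.
# The shapes, names, and parameter values below are illustrative assumptions,
# not taken from the source:
def _conv_transpose_usage_sketch():
    x = pm.input(name="x", shape=(1, 3, 8, 8))
    w = pm.state(name="w", shape=(3, 8, 2, 2))
    # SAME_UPPER auto-padding with stride 2 resolves to zero pads here, so the
    # transposed conv doubles each spatial dim: output shape (1, 8, 16, 16).
    return get_conv_transpose(x, w, kernel_shape=(2, 2), auto_pad="SAME_UPPER",
                              output_padding=0, strides=(2, 2),
                              shape=(1, 8, 16, 16), name="y")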
def test_conv2d_transpose_shapes(inp_shape, wgt_shape, stride, pad):
    # Integer-valued float inputs keep the torch/PolyMath comparison exact;
    # torch's conv ops do not accept integer dtypes.
    inp = np.random.randint(-15, 15, np.prod(inp_shape)).reshape(inp_shape).astype(np.float64)
    wgt = np.random.randint(-15, 15, np.prod(wgt_shape)).reshape(wgt_shape).astype(np.float64)
    torch_res = F.conv_transpose2d(torch.from_numpy(inp), torch.from_numpy(wgt),
                                   stride=stride, padding=pad)

    info = {
        'data': inp,
        'w': wgt,
    }
    x = pm.input(name="data", shape=inp_shape)
    w = pm.state(name="w", shape=wgt_shape)
    out = pm.output(name="out")
    graph = pm.conv_transpose(x, w, out, stride, pad)
    tres = graph("out", info)

    np.testing.assert_allclose(tres, torch_res.numpy())
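# The test signature implies pytest parametrization over shapes, stride, and
# pad. A plausible sketch delegating to the test above; these specific cases
# are invented for illustration, not taken from the source:
import pytest

CONV_TRANSPOSE_CASES = [
    # (inp_shape, wgt_shape, stride, pad); weight layout is (C_in, C_out, kH, kW)
    ((1, 1, 3, 3), (1, 1, 2, 2), 1, 0),
    ((1, 1, 4, 4), (1, 2, 3, 3), 2, 0),
    ((1, 2, 5, 5), (2, 1, 2, 2), 2, 1),
]

@pytest.mark.parametrize("inp_shape, wgt_shape, stride, pad", CONV_TRANSPOSE_CASES)
def test_conv2d_transpose_example_cases(inp_shape, wgt_shape, stride, pad):
    test_conv2d_transpose_shapes(inp_shape, wgt_shape, stride, pad)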