def test_conv(n64_hw32_c32_3x3):
    """
    Compare ng.convolution fprop, bprop and update results with reference_conv.

    Arguments:
        n64_hw32_c32_3x3: mapping of keyword arguments used to build
            ``ConvParams`` (presumably supplied by a pytest fixture of the
            same name -- confirm against conftest).
    """
    cf = ConvParams(**n64_hw32_c32_3x3)

    # randomly initialize
    input_value = rng.uniform(-0.5, 0.5, cf.ax_i)
    filter_value = rng.uniform(-0.5, 0.5, cf.ax_f)
    error_value = rng.uniform(-0.5, 0.5, cf.ax_o)

    # Each placeholder is created exactly once; the graph below is built
    # from these three ops.
    inputs = ng.placeholder(cf.ax_i)
    filters = ng.placeholder(cf.ax_f)
    errors = ng.placeholder(cf.ax_o)

    output = ng.convolution(cf.conv_params, inputs, filters, axes=cf.ax_o)
    bprop_out = bprop_conv(errors, inputs, filters, output)
    updat_out = update_conv(errors, inputs, filters, output)

    with executor([output, bprop_out, updat_out], inputs, filters, errors) as conv_executor:
        result_ng, gradI_ng, gradF_ng = conv_executor(input_value, filter_value, error_value)

    # Compute reference with NumPy
    result_np, gradI_np, gradF_np = reference_conv(cf.dimI, cf.dimF, cf.dimO,
                                                   cf.conv_params,
                                                   input_value, filter_value, error_value)

    # Compare fprop
    assert np.allclose(result_ng, result_np, rtol=0, atol=0.5)

    # Compare bprop
    assert np.allclose(gradI_ng, gradI_np, rtol=0, atol=0.5)

    # Compare update
    assert np.allclose(gradF_ng, gradF_np, rtol=0, atol=2)
def fuse_conv_and_bias_callback_bprop(self, op, label_map_op_list):
    """
    Re-create each matched bprop_conv op so it points at the replaced fprop.

    For every ``(label_map, matched_op)`` pair a new ``bprop_conv`` is built
    from the matched op's arguments. If ``self.op_replacement_dict`` has an
    entry for the matched op's ``fprop``, the new op is bound to that entry
    and substituted for the matched op.

    Arguments:
        op: not read by this method; only the entries of
            ``label_map_op_list`` are processed.
        label_map_op_list: iterable of ``(label_map, op)`` pairs.

    Note: processing stops at the first pair whose lookup/replacement raises
    ``KeyError``; later pairs are not visited.
    """
    for _label_map, matched_op in label_map_op_list:
        old_fprop = matched_op.fprop
        rebuilt_bprop = bprop_conv(matched_op.args[0],
                                   old_fprop.args[0],
                                   matched_op.args[1],
                                   old_fprop)
        try:
            rebuilt_bprop.fprop = self.op_replacement_dict[old_fprop]
            self.replace_op(matched_op, rebuilt_bprop)
        except KeyError:
            # No replacement available: give up on the remaining pairs too.
            return
def visit(self, op, delta, filters):
    """
    Rebuild ``op`` on top of the replacement for its fprop, when one exists.

    Arguments:
        op: the op being visited; its ``fprop`` attribute is looked up via
            ``self.get_replacement``.
        delta: first argument forwarded to ``bprop_conv``.
        filters: third argument forwarded to ``bprop_conv``.

    If ``self.get_replacement`` returns ``None`` the op is left untouched.
    """
    updated_fprop = self.get_replacement(op.fprop)
    if updated_fprop is None:
        # fprop was not updated in this pass, so there is nothing to do.
        return

    # fprop was updated in this pass: swap in a version of this op that is
    # built against the replaced fprop.
    rebuilt = bprop_conv(delta, self.op_arg(updated_fprop, 0), filters, updated_fprop)
    self.replace_op(op, rebuilt)
def test_conv(transformer_factory):
    """
    Compare ng.convolution fprop, bprop and update results with reference_conv.

    TODO: make this more interesting

    Arguments:
        transformer_factory: not read in the body (presumably a pytest
            fixture requested for its setup side effects -- confirm).
    """
    N, C, K = 64, 32, 32
    D, H, W = 1, 32, 32
    T, R, S = 1, 3, 3

    pad_d, pad_h, pad_w = 0, 0, 0
    str_d, str_h, str_w = 1, 1, 1
    dil_d, dil_h, dil_w = 1, 1, 1

    M = output_dim(D, T, pad_d, str_d)
    P = output_dim(H, R, pad_h, str_h)
    Q = output_dim(W, S, pad_w, str_w)

    # Padding, stride and dilation settings handed to ng.convolution.
    conv_params = dict(pad_d=pad_d, pad_h=pad_h, pad_w=pad_w,
                       str_d=str_d, str_h=str_h, str_w=str_w,
                       dil_d=dil_d, dil_h=dil_h, dil_w=dil_w)

    ax_i = ng.make_axes([
        ng.make_axis(name='C'),
        ng.make_axis(name='D'),
        ng.make_axis(name='H'),
        ng.make_axis(name='W'),
        ax.N
    ])

    ax_f = ng.make_axes([
        ng.make_axis(name='C'),
        ng.make_axis(name='D'),
        ng.make_axis(name='H'),
        ng.make_axis(name='W'),
        ng.make_axis(name='K'),
    ])

    ax_o = ng.make_axes([
        ng.make_axis(name='C'),
        ng.make_axis(name='D'),
        ng.make_axis(name='H'),
        ng.make_axis(name='W'),
        ax.N
    ])

    ax_i.set_shape((C, D, H, W, N))
    ax_f.set_shape((C, T, R, S, K))
    # Only the first four output axes are sized here; the trailing ax.N is
    # the same axis object used in ax_i (presumably already sized by the
    # ax_i.set_shape call above -- confirm).
    ax_o[:-1].set_shape((K, M, P, Q))

    # randomly initialize
    input_value = rng.uniform(-0.5, 0.5, ax_i)
    filter_value = rng.uniform(-0.5, 0.5, ax_f)
    error_value = rng.uniform(-0.5, 0.5, ax_o)

    assert input_value.shape == ax_i.lengths
    assert filter_value.shape == ax_f.lengths

    # Each placeholder is created exactly once; the graph below is built
    # from these three ops.
    inputs = ng.placeholder(ax_i)
    filters = ng.placeholder(ax_f)
    errors = ng.placeholder(ax_o)

    output = ng.convolution(conv_params, inputs, filters, axes=ax_o)
    bprop_out = bprop_conv(errors, inputs, filters, output)
    updat_out = update_conv(errors, inputs, filters, output)

    with executor([output, bprop_out, updat_out], inputs, filters, errors) as conv_executor:
        result_ng, gradI_ng, gradF_ng = conv_executor(input_value, filter_value, error_value)

    # Compute reference with NumPy
    result_np, gradI_np, gradF_np = reference_conv(C, N, K, D, H, W, T, R, S, M, P, Q,
                                                   pad_d, pad_h, pad_w,
                                                   str_d, str_h, str_w,
                                                   input_value, filter_value, error_value)

    # Compare fprop
    assert np.allclose(result_ng, result_np, rtol=0, atol=0.5)

    # Compare bprop
    assert np.allclose(gradI_ng, gradI_np, rtol=0, atol=0.5)

    # Compare update
    assert np.allclose(gradF_ng, gradF_np, rtol=0, atol=2)
def op_from_args(self, op, args):
    """
    Build a new ``bprop_conv`` from ``args`` that reuses the fprop of ``op``.

    Arguments:
        op: existing op whose ``fprop`` (and that fprop's first argument)
            are carried over to the new op.
        args: indexable sequence; ``args[0]`` and ``args[1]`` become the
            first and third arguments of ``bprop_conv``.

    Returns:
        The result of the ``bprop_conv`` call.
    """
    first_arg = args[0]
    fprop = op.fprop
    fprop_first_arg = fprop.args[0]
    third_arg = args[1]
    return bprop_conv(first_arg, fprop_first_arg, third_arg, fprop)