def add_diff_rot90(g_outmap):
    """Blend the generator output with grid-derived masks, then rotate it.

    Reads the ``"output"``, ``'grid_idx'`` and ``'z_rot90'`` entries of
    ``g_outmap``. Two masks are built from ``grid_idx`` with
    ``binary_mask``: a multiplicative one (0.5 for ``black`` and ``white``,
    1.0 for ``ignore``) and an additive one (0.5 for ``white``, zero
    otherwise). The result ``scale * output + offset`` is handed to
    ``rotate_by_multiple_of_90`` together with ``z_rot90``.
    """
    generated = g_outmap["output"]
    cell_ids = g_outmap['grid_idx']
    quarter_turns = g_outmap['z_rot90']

    # Multiplicative weights applied to the generator output.
    scale = binary_mask(cell_ids, black=0.5, ignore=1.0, white=0.5)
    # Additive term; only the `white` keyword gets a non-zero value.
    offset = binary_mask(cell_ids, black=0., ignore=0, white=0.5)

    return rotate_by_multiple_of_90(scale * generated + offset, quarter_turns)
def add_diff_rot90(g_outmap):
    """Blend the generator output with grid-derived masks, then rotate it.

    Reads the ``"output"``, ``'grid_idx'`` and ``'z_rot90'`` entries of
    ``g_outmap``. Two masks are built from ``grid_idx`` with
    ``binary_mask``: a multiplicative one (0.5 for ``black`` and ``white``,
    1.0 for ``ignore``) and an additive one (0.5 for ``white``, zero
    otherwise). The result ``scale * output + offset`` is handed to
    ``rotate_by_multiple_of_90`` together with ``z_rot90``.
    """
    generated = g_outmap["output"]
    cell_ids = g_outmap['grid_idx']
    quarter_turns = g_outmap['z_rot90']

    # Multiplicative weights applied to the generator output.
    scale = binary_mask(cell_ids, black=0.5, ignore=1.0, white=0.5)
    # Additive term; only the `white` keyword gets a non-zero value.
    offset = binary_mask(cell_ids, black=0., ignore=0, white=0.5)

    return rotate_by_multiple_of_90(scale * generated + offset, quarter_turns)
def reconstruct(g_outmap):
    """Return the generator output without its first axis-1 entry, rotated.

    Reads the ``"output"`` and ``'z_rot90'`` entries of ``g_outmap``,
    drops index 0 along axis 1 of the output (``output[:, 1:]``) and passes
    the remainder to ``rotate_by_multiple_of_90`` together with ``z_rot90``.

    NOTE: an earlier variant blended the gradient-disconnected first entry
    ``m = output[:, :1]`` with a ``grid_idx``-based mask ``alphas`` as
    ``clip(m + alphas * v, 0, 1)``. That blend was disabled, which left
    ``m`` and ``alphas`` computed but unused; they have been removed here.
    The returned value is unchanged.
    """
    g_out = g_outmap["output"]
    z_rot90 = g_outmap['z_rot90']
    # Everything after the first entry along axis 1
    # (presumably the channel axis -- confirm against the generator).
    v = g_out[:, 1:]
    return rotate_by_multiple_of_90(v, z_rot90)
def reconstruct(g_outmap):
    """Return the generator output without its first axis-1 entry, rotated.

    Reads the ``"output"`` and ``'z_rot90'`` entries of ``g_outmap``,
    drops index 0 along axis 1 of the output (``output[:, 1:]``) and passes
    the remainder to ``rotate_by_multiple_of_90`` together with ``z_rot90``.

    NOTE: an earlier variant blended the gradient-disconnected first entry
    ``m = output[:, :1]`` with a ``grid_idx``-based mask ``alphas`` as
    ``clip(m + alphas * v, 0, 1)``. That blend was disabled, which left
    ``m`` and ``alphas`` computed but unused; they have been removed here.
    The returned value is unchanged.
    """
    g_out = g_outmap["output"]
    z_rot90 = g_outmap['z_rot90']
    # Everything after the first entry along axis 1
    # (presumably the channel axis -- confirm against the generator).
    v = g_out[:, 1:]
    return rotate_by_multiple_of_90(v, z_rot90)
def test_util_rotate_by_multiple_of_90(batch):
    """Compare ``rotate_by_multiple_of_90`` with ``np.rot90`` per sample.

    Sample ``i`` of ``batch`` is rotated by ``i`` quarter turns and its
    first axis-1 entry is compared element-wise with
    ``np.rot90(batch[i, 0], k=i)``. For each sample a three-panel figure
    (input, result, reference) is written via ``plt_save_and_maybe_show``
    before any assertion runs.
    """
    n = len(batch)
    th_batch = K.variable(batch)
    # One rotation per sample instead of a fixed [0, 1, 2, 3]: np.rot90
    # takes k modulo 4, so cycling 0..3 agrees with the `k=i` reference
    # below for any batch size and is identical to the old value for n == 4.
    rots = K.variable(np.arange(n) % 4)
    rotated = rotate_by_multiple_of_90(th_batch, rots).eval()

    # Write all diagnostic plots first; the assertions follow afterwards.
    for i in range(n):
        plt.subplot(131)
        plt.imshow(batch[i, 0])
        plt.subplot(132)
        plt.imshow(rotated[i, 0])
        plt.subplot(133)
        plt.imshow(np.rot90(batch[i, 0], k=i))
        plt_save_and_maybe_show("utils/rotate_{}.png".format(i))

    assert rotated.shape == batch.shape
    for i in range(n):
        # The index is used as the assertion message to identify the sample.
        assert (rotated[i, 0] == np.rot90(batch[i, 0], k=i)).all(), i
def test_util_rotate_by_multiple_of_90_missing_rots(batch):
    """Check the output shape when some rotation counts never occur.

    Only the rotation values 0 and 2 are requested (1 and 3 are absent);
    the evaluated result must have the same shape as ``batch``.
    """
    quarter_turns = np.array([0, 0, 2, 2])
    result = rotate_by_multiple_of_90(
        K.variable(batch),
        K.variable(quarter_turns),
    ).eval()
    assert result.shape == batch.shape