def testMatMulBroadcast(self):
    """Run two chained matmuls, the first fed by a broadcast vector.

    A [1024] float16 placeholder is broadcast to [1024, 1024] and multiplied
    by a zero-initialised resource variable (op name "mm1"); that product is
    the right-hand operand of a second matmul (op name "mm2"). After a single
    run with all-zero feeds, the report is asked to check the tile-memory
    figures and the list of compute sets against fixed expectations.
    """
    dim = 1024

    with self.session() as session:
        with ops.device("/device:IPU:0"):
            # Keep the creation order of the graph ops unchanged: placeholder,
            # broadcast, placeholder, variable, then the two matmuls.
            vec_in = array_ops.placeholder(np.float16, shape=[dim])
            vec_bcast = gen_array_ops.broadcast_to(vec_in, shape=[dim, dim])
            mat_in = array_ops.placeholder(np.float16, shape=[dim, dim])

            with variable_scope.variable_scope("vs", use_resource=True):
                zero_init = init_ops.constant_initializer(0.0)
                weights = variable_scope.get_variable(
                    "x",
                    dtype=np.float16,
                    shape=[dim, dim],
                    initializer=zero_init,
                )

            first_product = math_ops.matmul(vec_bcast, weights, name="mm1")
            second_product = math_ops.matmul(
                mat_in, first_product, name="mm2")

        report = ReportJSON(self, session)
        tu.move_variable_initialization_to_cpu()
        session.run(variables.global_variables_initializer())

        # Reset after initialisation and before the run under test.
        report.reset()

        zero_feed = {
            placeholder: np.zeros(placeholder.shape)
            for placeholder in (vec_in, mat_in)
        }
        session.run(second_product, zero_feed)

        report.parse_log()
        report.assert_total_tile_memory(112509300)
        report.assert_max_tile_memory(100438)
        report.assert_all_compute_sets_and_list([
            '__seed*',
            'host-exchange-local-copy-',
            'mm1/dot*',
            'Copy_',
        ])
def testNormCacheConstants(self):
    """Apply two inference batch norms that share broadcast scale/offset.

    The model broadcasts a single-element input to 65536 elements and uses
    that one tensor as both scale and offset for two `fused_batch_norm` calls
    (named "a" and "b") that also share the moments computed from `x`. The
    two normalised outputs are summed. With all-ones inputs the result is
    expected to be 2 everywhere, and the report is then asked to check the
    tile-memory figures and the list of compute sets.
    """
    with self.session() as session:

        def model(x, y, z):
            # One broadcast tensor deliberately serves as scale AND offset.
            scale = offset = gen_array_ops.broadcast_to(z, shape=[65536])
            b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')

            def inference_norm(inputs, op_name):
                # Both norms use identical statistics and parameters; only
                # the input tensor and the op name differ.
                return nn.fused_batch_norm(
                    inputs,
                    scale,
                    offset,
                    b_mean,
                    b_var,
                    1e-3,
                    is_training=False,
                    name=op_name,
                )

            norm_a = inference_norm(x, "a")
            norm_b = inference_norm(y, "b")
            return norm_a[0] + norm_b[0]

        with ops.device('cpu'):
            x_in = array_ops.placeholder(
                np.float16, [1, 1, 1, 65536], name="x")
            y_in = array_ops.placeholder(
                np.float16, [1, 1, 1, 65536], name="y")
            z_in = array_ops.placeholder(np.float16, shape=[1])

        with ops.device("/device:IPU:0"):
            compiled = ipu_compiler.compile(model, inputs=[x_in, y_in, z_in])

        report = ReportJSON(self, session)
        tu.move_variable_initialization_to_cpu()
        session.run(variables.global_variables_initializer())

        # Reset after initialisation and before the run under test.
        report.reset()

        feed = {
            x_in: np.ones(x_in.shape),
            y_in: np.ones(y_in.shape),
            z_in: [1.0],
        }
        outputs = session.run(compiled, feed)

        summed = outputs[0]
        self.assertAllClose(summed, np.full(summed.shape, 2))

        report.parse_log()
        report.assert_total_tile_memory(1634674)
        report.assert_max_tile_memory(1551)

        # Would fail if there were two batch norms in the graph
        expected_compute_sets = [
            '__seed*',
            'host-exchange-local-copy',
            'Copy_',
            'moments/SquaredDifference/multiply',
            'a/batch-norm-inference',
            'add/add*/Add',
        ]
        report.assert_all_compute_sets_and_list(expected_compute_sets)