def decode_program(program):
  """Decode program tokens.

  Truncates `program` just after its first EOS token and decodes the result
  with `dsl.decode_program`.

  Args:
    program: token ids; presumably a 1-D numpy array (it is compared
      elementwise with `eos_token` and cast with `.astype`).

  Returns:
    A `(program_object, program_string)` tuple, where `program_string` is
    `program_object.to_string()`. If decoding or stringification raises,
    returns `(None, '')`.
  """
  # Keep everything up to and including the first EOS token.
  # NOTE(review): when no EOS is present, np.argmax of an all-False mask is 0,
  # so only the first token is kept; np.argmax also raises ValueError on an
  # empty array (outside the try below). Confirm callers never hit these.
  program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
  try:
    p = dsl.decode_program(program, id_token_table)
    return p, p.to_string()
  except Exception:  # pylint: disable=broad-except
    # Program does not compile. Catch Exception rather than using a bare
    # `except:` so KeyboardInterrupt / SystemExit still propagate.
    return None, ''
def decode_program(program):
  """Decode program tokens.

  Truncates `program` just after its first EOS token, removes every BOS
  token, and decodes the remaining ids with `dsl.decode_program`.

  Args:
    program: token ids; presumably a 1-D numpy array.

  Returns:
    The decoded program, or None if decoding raises (program does not
    compile).
  """
  # Keep everything up to and including the first EOS token.
  # NOTE(review): with no EOS present, np.argmax returns 0 and only the first
  # token survives; confirm inputs always contain EOS.
  program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
  # Drop BOS tokens wherever they occur.
  program = program[program != bos_token]
  try:
    return dsl.decode_program(program.tolist(), id_token_table)
  except Exception:  # pylint: disable=broad-except
    # Program does not compile. Catch Exception rather than using a bare
    # `except:` so KeyboardInterrupt / SystemExit still propagate.
    return None
def decode_program(program):
  """Decode a sequence of partial programs as one program.

  Each partial program is truncated just before its first EOS token, the
  pieces are concatenated in order, and a single EOS token is appended before
  decoding with `dsl.decode_program`.

  Args:
    program: iterable of partial programs; presumably each is a 1-D numpy
      array of token ids.

  Returns:
    The decoded program, or None if decoding raises (program does not
    compile).
  """
  # Token ids of each partial program, excluding its EOS and anything after.
  # NOTE(review): a partial with no EOS contributes nothing (np.argmax of an
  # all-False mask is 0); confirm that is intended.
  pieces = [p[:np.argmax(p == eos_token)].astype(np.int32) for p in program]
  # Append exactly one terminating EOS. Give it an explicit int32 dtype so the
  # result is an integer array even when every piece is empty (concatenating
  # an empty Python list would otherwise yield a float64 array).
  pieces.append(np.array([eos_token], dtype=np.int32))
  full_program = np.concatenate(pieces, axis=0)
  try:
    return dsl.decode_program(full_program, id_token_table)
  except Exception:  # pylint: disable=broad-except
    # Program does not compile. Catch Exception rather than using a bare
    # `except:` so KeyboardInterrupt / SystemExit still propagate.
    return None
def test_decode(self):
  """Round-trips a program through `encode` and `dsl.decode_program`.

  Checks that both token tables have the same size, that the encoding ends
  with the EOS token id, and that the decoded program produces the expected
  strings on sample inputs.
  """
  id_to_token, token_to_id = tokens.build_token_tables()
  self.assertEqual(len(token_to_id), len(id_to_token))

  # As the expected outputs below show: commas replace spaces within the span
  # of proper-case words 1..4, followed by '.', followed by the last
  # proper-case word.
  replace_spaces = dsl.Replace(' ', ',')
  word_span = dsl.GetSpan(dsl.Type.PROP_CASE, 1, dsl.Boundary.START,
                          dsl.Type.PROP_CASE, 4, dsl.Boundary.END)
  comma_joined = dsl.Compose(replace_spaces, word_span)
  separator = dsl.ConstStr('.')
  last_word = dsl.GetToken(dsl.Type.PROP_CASE, -1)
  program = dsl.Concat(comma_joined, separator, last_word)

  encoded = program.encode(token_to_id)
  self.assertEqual(encoded[-1], token_to_id[dsl.EOS])

  decoded = dsl.decode_program(encoded, id_to_token)
  cases = (
      ('Jacob Ethan James Alexander Michael',
       'Jacob,Ethan,James,Alexander.Michael'),
      ('Earth Fire Wind Water Pluto Sun', 'Earth,Fire,Wind,Water.Sun'),
  )
  for text, want in cases:
    self.assertEqual(decoded(text), want)
def decode_program(program):
  """Decode program tokens according to `FLAGS.dataset_type`.

  The tokens are first truncated just after the first EOS token.

  Args:
    program: token ids; presumably a 1-D numpy array.

  Returns:
    For 'robust_fill': the program returned by `robust_fill_dsl.decode_program`
      (a Concat program object, per the original comment), or None if
      decoding raises.
    For 'scan': a single string of the decoded tokens joined by spaces, with
      BOS and EOS tokens removed.

  Raises:
    ValueError: if `FLAGS.dataset_type` is neither 'robust_fill' nor 'scan'.
  """
  # Keep everything up to and including the first EOS token.
  # NOTE(review): with no EOS present, np.argmax returns 0 and only the first
  # token survives; confirm inputs always contain EOS.
  program = program[:np.argmax(program == eos_id) + 1].astype(np.int32)

  if FLAGS.dataset_type == 'robust_fill':
    # Returns either a Concat program object, or None.
    program = program[program != bos_id].tolist()
    try:
      return robust_fill_dsl.decode_program(program, program_id_token_table)
    except Exception:  # pylint: disable=broad-except
      # Program does not compile. Catch Exception rather than using a bare
      # `except:` so KeyboardInterrupt / SystemExit still propagate.
      return None
  elif FLAGS.dataset_type == 'scan':
    # Returns a string. Build the mask with numpy, like the rest of this
    # function; jnp.logical_and would create a JAX array only to index with.
    keep = np.logical_and(program != bos_id, program != eos_id)
    program = program[keep].tolist()
    return ' '.join(scan_vocab.decode(program, program_id_token_table))
  else:
    raise ValueError('Unhandled dataset_type: {}'.format(FLAGS.dataset_type))