def main(args):
    """
    Generate SC2 replays for a trained agent.

    Reads the training arguments saved in the log directory, rebuilds the
    SC2 environment, network and agent from them, then runs a
    ``ReplayGenerator`` until it returns or is interrupted (Ctrl+C). The
    environment is always closed on exit, including when building the
    network or agent fails.

    :param args: script options accessed by attribute (not item lookup);
        reads ``log_id_dir``, ``seed``, ``epoch``, ``render`` and ``gpu_id``.
    :return: None
    """
    print_ascii_logo()
    print('Saving replays... Press Ctrl+C to stop.')

    log_dir_helper = LogDirHelper(args.log_id_dir)
    with open(log_dir_helper.args_file_path(), 'r') as args_file:
        train_args = DotDict(json.load(args_file))

    engine = env_registry.lookup_engine(train_args.env)
    assert engine == 'AdeptSC2Env', "replay_gen_sc2.py is only for SC2."

    # construct env; replays are written into the selected epoch's directory
    env = SubProcEnvManager.from_args(
        train_args,
        seed=args.seed,
        nb_env=1,
        registry=env_registry,
        sc2_replay_dir=log_dir_helper.epoch_path_at_epoch(args.epoch),
        sc2_render=args.render,
    )
    # Everything after the env exists runs under try/finally so the env is
    # closed even if network/agent construction raises, not only after run().
    try:
        output_space = agent_registry.lookup_output_space(
            train_args.agent, env.action_space
        )
        # The network must match the one used in training, so both the
        # branch condition and the registry lookup read train_args.
        if train_args.custom_network:
            network = net_registry.lookup_custom_net(
                train_args.custom_network
            ).from_args(
                train_args, env.observation_space, output_space, net_registry
            )
        else:
            network = ModularNetwork.from_args(
                train_args, env.observation_space, output_space, net_registry
            )
        # NOTE(review): no checkpoint weights are loaded into `network` in
        # this function; confirm they are restored elsewhere (agent or
        # ReplayGenerator), otherwise replays show an untrained policy.

        # create an agent (add act_eval method)
        device = torch.device("cuda:{}".format(args.gpu_id) if (
            torch.cuda.is_available() and args.gpu_id >= 0) else "cpu")
        torch.backends.cudnn.benchmark = True
        agent = agent_registry.lookup_agent(train_args.agent).from_args(
            train_args,
            network,
            device,
            env_registry.lookup_reward_normalizer(train_args.env),
            env.gpu_preprocessor,
            env_registry.lookup_policy(env.engine)(env.action_space),
            nb_env=1,
        )

        # create a rendering container
        # TODO: could terminate after a configurable number of replays instead of running indefinitely
        renderer = ReplayGenerator(agent, device, env)
        renderer.run()
    finally:
        env.close()
def test_forward(self):
    """A forward pass over 1D-4D zero-filled sources must not raise.

    Builds a ModularNetwork from the test case's fixture pieces and feeds it
    a batch of 32 observations keyed 'source_1d' .. 'source_4d'.
    """
    import torch

    BATCH = 32
    obs = {
        'source_1d': torch.zeros((BATCH, 16)),
        'source_2d': torch.zeros((BATCH, 16, 8 * 8)),
        'source_3d': torch.zeros((BATCH, 16, 8, 8)),
        'source_4d': torch.zeros((BATCH, 16, 8, 8, 8)),
    }
    try:
        net = ModularNetwork(
            self.source_nets,
            self.body,
            self.heads,
            self.output_space,
            dummy_gpu_preprocessor,
        )
        # Unpacking also checks that forward returns a two-element result.
        _outputs, _ = net.forward(obs, {})
    except Exception as exc:
        # Narrow catch (no bare except) and report what actually went wrong;
        # the original exception stays attached as __context__.
        self.fail('Unexpected exception: {!r}'.format(exc))
def test_heads_match_out_shapes(self):
    """Construction must raise AssertionError when head and output shapes disagree.

    The single 2D stub (used as source, body and head) is built with shape
    (32, 32) while the output space requests (32, 64).
    """
    identity = Identity2D((32, 32), 'stub_2d')
    self.assertRaises(
        AssertionError,
        ModularNetwork,
        {'source': identity},
        identity,
        {'2': identity},
        {'output': (32, 64)},
        dummy_gpu_preprocessor,
    )
def test_valid_structure(self):
    """The test case's fixture pieces must build a ModularNetwork without raising."""
    try:
        ModularNetwork(
            self.source_nets,
            self.body,
            self.heads,
            self.output_space,
            dummy_gpu_preprocessor,
        )
    except Exception as exc:
        # Narrow catch (no bare except) and surface the real error message.
        self.fail("Unexpected exception: {!r}".format(exc))
def test_output_has_a_head(self):
    """Construction must raise AssertionError when an output has no matching head.

    The requested output is 3D, (32, 32, 32), but the only head provided is
    registered under key "2".
    """
    identity = Identity2D((32, 32), "stub_2d")
    self.assertRaises(
        AssertionError,
        ModularNetwork,
        {"source": identity},
        identity,
        {"2": identity},
        {"output": (32, 32, 32)},
        dummy_gpu_preprocessor,
    )
def test_body_matches_heads(self):
    """Construction must raise AssertionError when the body and head shapes differ.

    Source and body use a (32, 32) stub; the head uses a (32, 64) stub.
    """
    narrow = Identity2D((32, 32), 'stub_32')
    wide = Identity2D((32, 64), 'stub_64')
    self.assertRaises(
        AssertionError,
        ModularNetwork,
        {'source': narrow},
        narrow,
        {'2': wide},
        {'output': (32, 64)},
        dummy_gpu_preprocessor,
    )
def test_heads_not_higher_dim_than_body(self):
    """Construction must raise AssertionError for a 2D head on a 1D body.

    Source and body are an Identity1D stub; the head is an Identity2D stub.
    """
    flat = Identity1D((32, ), 'stub_1d')
    grid = Identity2D((32, 32), 'stub_2d')
    self.assertRaises(
        AssertionError,
        ModularNetwork,
        {'source': flat},
        flat,
        {'2': grid},
        {'output': (32, 32)},
        dummy_gpu_preprocessor,
    )
def test_source_nets_match_body(self):
    """Construction must raise AssertionError when source nets and body shapes differ.

    The source uses a (32, 32) stub while the body and head use a (32, 64) stub.
    """
    narrow = Identity2D((32, 32), "stub_32")
    wide = Identity2D((32, 64), "stub_64")
    self.assertRaises(
        AssertionError,
        ModularNetwork,
        {"source": narrow},
        wide,
        {"2": wide},
        {"output": (32, 64)},
        dummy_gpu_preprocessor,
    )