if __name__ == '__main__':

    device = "/gpu:0" if USE_GPU else "/cpu:0"
    network_scope = TASK_TYPE
    list_of_tasks = TASK_LIST
    # Snapshot the scene names as a list: on Python 3 ``dict.keys()`` is a
    # live, non-indexable view, while Python 2 already returns a list.
    scene_scopes = list(list_of_tasks)
    # Shared training state; presumably read/updated by code further down
    # the script (not visible here).
    global_t = 0
    stop_requested = False

    # One (scene, task) pair for every task listed under each scene.
    branches = []
    for scene in scene_scopes:
        for task in list_of_tasks[scene]:
            branches.append((scene, task))

    NUM_TASKS = len(branches)
    # Check the thread budget before creating directories or building the
    # network, so a bad config fails fast. Unlike ``assert``, this check is
    # not stripped when Python runs with ``-O``.
    if PARALLEL_SIZE < NUM_TASKS:
        raise ValueError(
            "Not enough threads for multitasking: at least {} threads needed."
            .format(NUM_TASKS))

    # os.makedirs (unlike os.mkdir) also creates any missing parent
    # directories of the checkpoint paths.
    for checkpoint_dir in (CHECKPOINT_DIR, CHECKPOINT_DIR_beta):
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    initial_learning_rate = log_uniform(INITIAL_ALPHA_LOW, INITIAL_ALPHA_HIGH,
                                        INITIAL_ALPHA_LOG_RATE)

    global_network = ActorCriticFFNetwork(action_size=ACTION_SIZE,
                                          device=device,
                                          network_scope=network_scope,
                                          scene_scopes=scene_scopes)
# Example #2
from constants import TASK_LIST

if __name__ == '__main__':

  device = "/gpu:0" if USE_GPU else "/cpu:0"
  network_scope = TASK_TYPE
  list_of_tasks = TASK_LIST
  # Snapshot the scene names as a list: on Python 3 ``dict.keys()`` is a
  # live, non-indexable view, while Python 2 already returns a list.
  scene_scopes = list(list_of_tasks)
  # Shared training state; presumably read/updated by code further down
  # the script (not visible here).
  global_t = 0
  stop_requested = False

  # One (scene, task) pair for every task listed under each scene.
  branches = []
  for scene in scene_scopes:
    for task in list_of_tasks[scene]:
      branches.append((scene, task))

  NUM_TASKS = len(branches)
  # Check the thread budget before creating directories or building the
  # network, so a bad config fails fast. Unlike ``assert``, this check is
  # not stripped when Python runs with ``-O``.
  if PARALLEL_SIZE < NUM_TASKS:
    raise ValueError(
      "Not enough threads for multitasking: at least {} threads needed."
      .format(NUM_TASKS))

  # os.makedirs (unlike os.mkdir) also creates any missing parent
  # directories of the checkpoint path.
  if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

  initial_learning_rate = log_uniform(INITIAL_ALPHA_LOW,
                                      INITIAL_ALPHA_HIGH,
                                      INITIAL_ALPHA_LOG_RATE)

  global_network = ActorCriticFFNetwork(action_size=ACTION_SIZE,
                                        device=device,
                                        network_scope=network_scope,
                                        scene_scopes=scene_scopes)