Code example #1
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda:0")
        flags.actor_device = torch.device("cuda:1")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

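    # Build the list of environment server addresses the actors will connect to,
    # one address per actor, derived from flags.pipes_basename and a pipe id.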
    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    model = model.to(device=flags.learner_device)

    actor_model = Net(num_actions=flags.num_actions, use_lstm=flags.use_lstm)
    actor_model.to(device=flags.actor_device)

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return (
            1
            - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps)
            / flags.total_steps
        )

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(
            checkpointpath, map_location=flags.learner_device
        )
        model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        )
        for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(flags, inference_batcher, actor_model),
        )
        for i in range(flags.num_inference_threads)
    ]

    actorpool_thread.start()
    for t in learner_threads + inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) / (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(
                    f"{key} = {format_value(value)}" for key, value in stats.items()
                ),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actorpool_thread.join()

    for t in learner_threads + inference_threads:
        t.join()
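
The `lr_lambda` above anneals the learning rate linearly to zero as the global step count (`epoch * unroll_length * batch_size`, capped at `total_steps`) grows. A minimal standalone sketch of that schedule, using hypothetical values in place of the corresponding flags:

import torch

# Hypothetical stand-ins for flags.unroll_length, flags.batch_size, flags.total_steps.
unroll_length, batch_size, total_steps = 80, 32, 50_000

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0006)

def lr_lambda(epoch):
    # Fraction of total steps still remaining; the same linear decay as in the example.
    return 1 - min(epoch * unroll_length * batch_size, total_steps) / total_steps

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for _ in range(5):
    optimizer.step()   # one learner update
    scheduler.step()   # learning rate shrinks linearly toward zero
    print(scheduler.get_last_lr())
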
Code example #2
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    # prepare for logging and saving models
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))
    if flags.save_model_every_nsteps > 0:
        os.makedirs(checkpointpath.replace("model.tar", "intermediate"),
                    exist_ok=True)

    # get a list and determine the number of environments
    environments = flags.env.split(",")
    flags.num_tasks = len(environments)

    # set the number of buffers
    if flags.num_buffers is None:
        flags.num_buffers = max(2 * flags.num_actors * flags.num_tasks,
                                flags.batch_size)
    if flags.num_actors * flags.num_tasks >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    # set the device to do the training on
    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    # set the environments
    if flags.env == "six":
        flags.env = "AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4," \
                    "NameThisGameNoFrameskip-v4,PongNoFrameskip-v4,SpaceInvadersNoFrameskip-v4"
    elif flags.env == "three":
        flags.env = "AirRaidNoFrameskip-v4,CarnivalNoFrameskip-v4,DemonAttackNoFrameskip-v4"

    # set the right agent class
    if flags.agent_type.lower() in [
            "aaa", "attention_augmented", "attention_augmented_agent"
    ]:
        Net = AttentionAugmentedAgent
        logging.info("Using the Attention-Augmented Agent architecture.")
    elif flags.agent_type.lower() in ["rn", "res", "resnet", "res_net"]:
        Net = ResNet
        logging.info("Using the ResNet architecture (monobeast version).")
    else:
        Net = AtariNet
        logging.warning(
            "No valid agent type specified. Using the default agent.")

    # create a dummy environment, mostly to get the observation and action spaces from
    gym_env = create_env(environments[0],
                         frame_height=flags.frame_height,
                         frame_width=flags.frame_width,
                         gray_scale=(flags.aaa_input_format == "gray_stack"))
    observation_space_shape = gym_env.observation_space.shape
    action_space_n = gym_env.action_space.n
    full_action_space = False
    for environment in environments[1:]:
        gym_env = create_env(environment)
        if gym_env.action_space.n != action_space_n:
            logging.warning(
                "Action spaces don't match, using full action space.")
            full_action_space = True
            action_space_n = 18
            break

    # create the model and the buffers to pass around data between actors and learner
    model = Net(observation_space_shape,
                action_space_n,
                use_lstm=flags.use_lstm,
                num_tasks=flags.num_tasks,
                use_popart=flags.use_popart,
                reward_clipping=flags.reward_clipping,
                rgb_last=(flags.aaa_input_format == "rgb_last"))
    buffers = create_buffers(flags, observation_space_shape, model.num_actions)

    # Move the model's parameters into shared memory (like the buffers) so the
    # actor processes can read the latest weights without copying the whole model.
    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    # create stuff to keep track of the actor processes
    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    # create and start actor threads (the same number for each environment)
    for i, environment in enumerate(environments):
        for j in range(flags.num_actors):
            actor = ctx.Process(
                target=act,
                args=(
                    flags,
                    environment,
                    i,
                    full_action_space,
                    i * flags.num_actors + j,
                    free_queue,
                    full_queue,
                    model,
                    buffers,
                    initial_agent_state_buffers,
                ),
            )
            actor.start()
            actor_processes.append(actor)

    learner_model = Net(observation_space_shape,
                        action_space_n,
                        use_lstm=flags.use_lstm,
                        num_tasks=flags.num_tasks,
                        use_popart=flags.use_popart,
                        reward_clipping=flags.reward_clipping,
                        rgb_last=(flags.aaa_input_format == "rgb_last")).to(
                            device=flags.device)

    # the hyperparameters in the paper are found/adjusted using population-based training,
    # which might be a bit too difficult for us to do; while the IMPALA paper also does
    # some experiments with this, it doesn't seem to be implemented here
    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.device)
        learner_model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        # stats = checkpoint_states["stats"]
        # logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    model.load_state_dict(learner_model.state_dict())

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
        "mu",
        "sigma",
    ] + ["{}_step".format(e) for e in environments]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        # step in particular needs to be from the outside scope, since all learner threads can update
        # it and all learners should stop once the total number of steps/frames has been processed
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            learn(flags,
                  model,
                  learner_model,
                  batch,
                  agent_state,
                  optimizer,
                  scheduler,
                  stats,
                  envs=environments)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update(
                    {k: stats[k]
                     for k in stat_keys if "_step" not in k})
                for e in stats["env_step"]:
                    to_log["{}_step".format(e)] = stats["env_step"][e]
                plogger.log(to_log)
                step += T * B  # so this counts the number of frames, not e.g. trajectories/rollouts

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    # populate the free queue with the indices of all the buffers at the start
    for m in range(flags.num_buffers):
        free_queue.put(m)

    # start as many learner threads as specified => could in principle do PBT
    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name="batch-and-learn-%d" % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def save_latest_model():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": learner_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def save_intermediate_model():
        save_model_path = checkpointpath.replace(
            "model.tar", "intermediate/model." +
            str(stats.get("step", 0)).zfill(9) + ".tar")
        logging.info("Saving model to %s", save_model_path)
        torch.save(
            {
                "model_state_dict": learner_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            save_model_path,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        last_savemodel_nsteps = 0
        while step < flags.total_steps:
            start_step = stats.get("step", 0)
            start_time = timer()
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timer() - last_checkpoint_time > 10 * 60:
                # save every 10 min.
                save_latest_model()
                last_checkpoint_time = timer()

            if flags.save_model_every_nsteps > 0 and end_step >= last_savemodel_nsteps + flags.save_model_every_nsteps:
                # save model every save_model_every_nsteps steps
                save_intermediate_model()
                last_savemodel_nsteps = end_step

            sps = (end_step - start_step) / (timer() - start_time)
            if stats.get("episode_returns", None):
                mean_return = ("Return per episode: %.1f. " %
                               stats["mean_episode_return"])
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                end_step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        gradient_tracker.print_total()
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)  # send quit signal to actors
        for actor in actor_processes:
            actor.join(timeout=10)
        gradient_tracker.print_total()

    save_latest_model()
    plogger.close()
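
Example #2 additionally writes periodic snapshots next to the main checkpoint by rewriting the "model.tar" path and zero-padding the step count so the files sort in training order. A small sketch of that naming scheme, with a hypothetical save directory and step value:

import os

# Hypothetical stand-ins for the <savedir>/<xpid>/model.tar path and the current step.
checkpointpath = "demo_logs/my-experiment/model.tar"
step = 1_250_000

os.makedirs(checkpointpath.replace("model.tar", "intermediate"), exist_ok=True)
save_model_path = checkpointpath.replace(
    "model.tar", "intermediate/model." + str(step).zfill(9) + ".tar")
print(save_model_path)  # demo_logs/my-experiment/intermediate/model.001250000.tar
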
Code example #3
File: monobeast.py Project: nicoladainese96/SC2-RL
def train(flags, game_params):  # pylint: disable=too-many-branches, too-many-statements
    """
    1. Init actor model and create_buffers()
    2. Starts 'num_actors' act() functions
    3. Init learner model and optimizer, loads the former on the GPU
    4. Launches 'num_learner_threads' threads executing batch_and_learn()
    5. train finishes when all batch_and_learn threads finish, i.e. when steps >= flags.total_steps
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))
    print("checkpointpath: ", checkpointpath)
    if flags.num_buffers is None:  # Set sensible default for num_buffers. IMPORTANT!!
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    env = init_game(game_params['env'], flags.map_name)

    model = IMPALA_AC(env=env, device='cpu', **game_params['HPs'])
    observation_shape = (game_params['HPs']['spatial_dict']['in_channels'],
                         *model.screen_res)
    player_shape = game_params['HPs']['spatial_dict']['in_player']
    num_actions = len(model.action_names)
    buffers = create_buffers(flags, observation_shape, player_shape,
                             num_actions, model.max_num_spatial_args,
                             model.max_num_categorical_args)

    model.share_memory()  # see if this works out of the box for my A2C

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                game_params,
                i,
                free_queue,
                full_queue,
                model,  # with share memory
                buffers),
        )
        actor.start()
        actor_processes.append(actor)

    # Only the learner model is moved onto the GPU; the shared actor model stays on the CPU.
    if not flags.disable_cuda and torch.cuda.is_available():
        learner_model = IMPALA_AC(env=env, device='cuda',
                                  **game_params['HPs']).to(device=flags.device)
    else:
        learner_model = IMPALA_AC(env=env, device='cpu',
                                  **game_params['HPs']).to(device=flags.device)

    if flags.optim == "Adam":
        optimizer = torch.optim.Adam(learner_model.parameters(),
                                     lr=flags.learning_rate)
    else:
        optimizer = torch.optim.RMSprop(
            learner_model.parameters(),
            lr=flags.learning_rate,
            momentum=flags.momentum,
            eps=flags.epsilon,
            alpha=flags.alpha,
        )

    def lr_lambda(epoch):
        """
        Linear schedule from 1 to 0, used only for RMSprop.
        Multiply by the batch size B (or not) depending on how steps are counted:
        epoch = number of optimizer steps
        total_steps = optimizer steps * time steps * batch size
                    or optimizer steps * time steps
        """
        # Use epoch * T * B instead if steps are counted in units of B.
        return 1 - min(epoch * T, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                timings,
            )
            stats = learn(flags, model, learner_model, batch, optimizer,
                          scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T  # * B if counting frames; here we just count the parallel steps

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())
    # end batch_and_learn

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name="batch-and-learn-%d" % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        if flags.optim == "Adam":
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "flags": vars(flags),
                },
                checkpointpath,  # only one checkpoint at the time is saved
            )
        else:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "flags": vars(flags),
                },
                checkpointpath,  # only one checkpoint at the time is saved
            )

    # end checkpoint

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)  # steps per second
            if stats.get("episode_returns", None):
                mean_return = ("Return per episode: %.1f. " %
                               stats["mean_episode_return"])
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    plogger.close()
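
Examples #2 and #3 (and the monobeast variants below) hand buffers back and forth through two queues of indices: an actor takes a free index, writes its rollout into that buffer, and puts the index on the full queue; the learner drains full indices, trains on them, and recycles them, with `None` as the shutdown signal. A thread-based sketch of that protocol (the real code uses `multiprocessing` queues and shared-memory tensors):

import queue
import threading

num_buffers = 4
buffers = [{"filled_by": None} for _ in range(num_buffers)]
free_queue, full_queue = queue.SimpleQueue(), queue.SimpleQueue()

def actor(actor_id):
    while True:
        index = free_queue.get()
        if index is None:                        # shutdown signal, as in the `finally` blocks above
            return
        buffers[index]["filled_by"] = actor_id   # stand-in for writing a rollout
        full_queue.put(index)

def learner(num_batches):
    for _ in range(num_batches):
        index = full_queue.get()                 # stand-in for get_batch()
        print("learning from buffer", index, buffers[index])
        free_queue.put(index)                    # recycle the buffer

for m in range(num_buffers):
    free_queue.put(m)

actor_threads = [threading.Thread(target=actor, args=(i,)) for i in range(2)]
for t in actor_threads:
    t.start()

learner(num_batches=8)

for _ in actor_threads:
    free_queue.put(None)
for t in actor_threads:
    t.join()
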
Code example #4
File: polybeast.py Project: ln-e/spiralpp
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = actorpool.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    d_queue = Queue(maxsize=flags.max_learner_queue_size // flags.batch_size)
    image_queue = Queue(maxsize=flags.max_learner_queue_size)

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = actorpool.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

    config = dict(
        episode_length=flags.episode_length,
        canvas_width=flags.canvas_width,
        grid_width=grid_width,
        brush_sizes=flags.brush_sizes,
    )

    if flags.dataset == "celeba" or flags.dataset == "celeba-hq":
        use_color = True
    else:
        use_color = False

    if flags.env_type == "fluid":
        env_name = "Fluid"
        config["shaders_basedir"] = SHADERS_BASEDIR
    elif flags.env_type == "libmypaint":
        env_name = "Libmypaint"
        config.update(
            dict(
                brush_type=flags.brush_type,
                use_color=use_color,
                use_pressure=flags.use_pressure,
                use_alpha=False,
                background="white",
                brushes_basedir=BRUSHES_BASEDIR,
            ))

    if flags.use_compound:
        env_name += "-v1"
    else:
        env_name += "-v0"

    env = env_wrapper.make_raw(env_name, config)
    if frame_width != flags.canvas_width:
        env = env_wrapper.WarpFrame(env, height=frame_width, width=frame_width)
    env = env_wrapper.wrap_pytorch(env)

    obs_shape = env.observation_space.shape
    if flags.condition:
        c, h, w = obs_shape
        c *= 2
        obs_shape = (c, h, w)

    action_shape = env.action_space.nvec.tolist()
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        model = models.Condition(model)
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
        order=order,
    )
    if flags.condition:
        actor_model = models.Condition(actor_model)
    actor_model.to(device=flags.actor_device)

    D = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D = models.Conditional(D)
    D.to(device=flags.learner_device)

    D_eval = models.Discriminator(obs_shape, flags.power_iters)
    if flags.condition:
        D_eval = models.Conditional(D_eval)
    D_eval = D_eval.to(device=flags.learner_device)

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    D_scheduler = torch.optim.lr_scheduler.LambdaLR(D_optimizer, lr_lambda)

    C, H, W = obs_shape
    if flags.condition:
        C //= 2
    # The ActorPool that will run `flags.num_actors` many loops.
    actors = actorpool.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_action=actor_model.initial_action(),
        initial_agent_state=actor_model.initial_state(),
        image=torch.zeros(1, 1, C, H, W),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    c, h, w = obs_shape
    tsfm = transforms.Compose(
        [transforms.Resize((h, w)),
         transforms.ToTensor()])

    dataset = flags.dataset

    if dataset == "mnist":
        dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
    elif dataset == "omniglot":
        dataset = Omniglot(root="./",
                           background=True,
                           transform=tsfm,
                           download=True)
    elif dataset == "celeba":
        dataset = CelebA(root="./",
                         split="train",
                         target_type=None,
                         transform=tsfm,
                         download=True)
    elif dataset == "celeba-hq":
        dataset = datasets.CelebAHQ(root="./",
                                    split="train",
                                    transform=tsfm,
                                    download=True)
    else:
        raise NotImplementedError

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True)

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["D_scheduler_state_dict"])
        D_scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                d_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
                image_queue,
            ),
        ) for i in range(flags.num_inference_threads)
    ]

    d_learner = [
        threading.Thread(
            target=learn_D,
            name="d_learner-thread-%i" % i,
            args=(
                flags,
                d_queue,
                D,
                D_eval,
                D_optimizer,
                D_scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]
    for thread in d_learner:
        thread.daemon = True

    dataloader_thread = threading.Thread(target=data_loader,
                                         args=(
                                             flags,
                                             dataloader,
                                             image_queue,
                                         ))
    dataloader_thread.daemon = True

    actorpool_thread.start()

    threads = learner_threads + inference_threads
    daemons = d_learner + [dataloader_thread]

    for t in threads + daemons:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "D_scheduler_state_dict": D_scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    actorpool_thread.join()

    for t in threads:
        t.join()
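
Example #4 pulls reference images through a standard torchvision pipeline: resize to the canvas resolution, convert to a tensor, and iterate with a `DataLoader`. A minimal sketch of that input side, using MNIST and a hypothetical 64x64 canvas:

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

h, w = 64, 64  # hypothetical canvas resolution (the example derives these from obs_shape)
tsfm = transforms.Compose([transforms.Resize((h, w)), transforms.ToTensor()])

dataset = MNIST(root="./", train=True, transform=tsfm, download=True)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True,
                        drop_last=True, pin_memory=True)

images, labels = next(iter(dataloader))
print(images.shape)  # torch.Size([1, 1, 64, 64]): one resized grayscale image per batch
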
Code example #5
File: monobeast.py Project: guydav/torchbeast
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")

    xp_args = flags.__dict__.copy()
    if 'device' in xp_args and isinstance(xp_args['device'], torch.device):
        xp_args['device'] = str(xp_args['device'])

    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=xp_args, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    env = create_env(flags)

    model = Net(env.observation_space.shape, env.action_space.n, flags.use_lstm)
    buffers = create_buffers(flags, env.observation_space.shape, model.num_actions)

    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    learner_model = Net(
        env.observation_space.shape, env.action_space.n, flags.use_lstm
    ).to(device=flags.device)

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    # Load state from a checkpoint, if possible.
    resume_checkpoint_path = checkpointpath
    if 'resume_checkpoint_path' in flags:
        resume_checkpoint_path = flags.resume_checkpoint_path

    if os.path.exists(resume_checkpoint_path):
        checkpoint_states = torch.load(
            resume_checkpoint_path, map_location=flags.device
        )
        learner_model.load_state_dict(checkpoint_states["model_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    model.load_state_dict(learner_model.state_dict())

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        steps = flags.total_steps
        if 'train_steps' in flags:
            steps = flags.train_steps

        while step < steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(
                flags, model, learner_model, batch, agent_state, optimizer, scheduler
            )
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)

                step += T * B

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i,)
        )
        thread.start()
        threads.append(thread)

    def checkpoint(path=None):
        if flags.disable_checkpoint:
            return

        if path is None:
            path = checkpointpath

        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            path,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        steps = flags.total_steps
        if 'train_steps' in flags:
            steps = flags.train_steps

        while step < steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)
            if stats.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats["mean_episode_return"]
                )
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            print_step = step
            if 'current_step' in flags:
                print_step += flags.current_step
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                print_step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    if wandb.run is not None and wandb.run.dir is not None:
        save_path = os.path.join(wandb.run.dir, f'model-{step}.tar')
        checkpoint(save_path)
        flags.resume_checkpoint_path = save_path

    plogger.close()

    return step
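
All of these `checkpoint()` helpers save the model, optimizer, and scheduler state dicts in one `.tar` file, and resuming is the mirror image via `torch.load` and `load_state_dict`. A self-contained sketch of that round trip, with a hypothetical local file name:

import os
import torch

model = torch.nn.Linear(8, 4)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0006)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)

checkpointpath = "model.tar"  # hypothetical; the examples use <savedir>/<xpid>/model.tar

torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "stats": {"step": 0},
    },
    checkpointpath,
)

if os.path.exists(checkpointpath):  # resume, mirroring the examples
    states = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(states["model_state_dict"])
    optimizer.load_state_dict(states["optimizer_state_dict"])
    scheduler.load_state_dict(states["scheduler_state_dict"])
    print("resumed at step", states["stats"]["step"])
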
Code example #6
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    load_checkpoint = None
    if flags.loaddir is not None:
        loadpath = os.path.join(os.path.expanduser(flags.loaddir), 'model.tar')
        logging.info("Continue training from {}".format(loadpath))
        load_checkpoint = torch.load(loadpath, map_location=flags.device)

    if flags.env.startswith('Coin'):
        obs_shape = (3, 64, 64)  # will deadlock if we try to create env here
        act_shape = 7  # so specify space shapes manually
    else:
        env = create_env(flags.env, flags)
        obs_shape = copy(env.observation_space.shape)
        act_shape = env.action_space.n
        del env

    model = Net(obs_shape, act_shape, flags.use_lstm)
    if load_checkpoint is not None:
        model.load_state_dict(load_checkpoint["model_state_dict"])
    buffers = create_buffers(flags, obs_shape, model.num_actions)
    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    learner_model = Net(obs_shape, act_shape,
                        flags.use_lstm).to(device=flags.device)
    if load_checkpoint is not None:
        learner_model.load_state_dict(load_checkpoint["model_state_dict"])

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )
    if load_checkpoint is not None:
        optimizer.load_state_dict(load_checkpoint["optimizer_state_dict"])

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    if load_checkpoint is not None:
        scheduler.load_state_dict(load_checkpoint["scheduler_state_dict"])

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(flags, model, learner_model, batch, agent_state,
                          optimizer, scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name="batch-and-learn-%d" % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        last_print_time = timer()
        episode_returns = []
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(0.5)
            if stats.get("episode_returns", None):
                episode_returns.extend(stats["episode_returns"])
            if timer() - last_print_time < 10.0: continue  # wait 10s to print

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)
            if len(episode_returns) > 0:
                mean_return_val = sum(episode_returns) / len(episode_returns)
                mean_return = ("Return per episode: %.1f" % mean_return_val)
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS Loss %f %s\nEpsd returns:%s\nStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(episode_returns),
                pprint.pformat(stats),
            )
            last_print_time = timer()
            episode_returns = []
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    plogger.close()
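
The main loop in each example is only a monitor: it sleeps, diffs the shared step counter, and reports steps per second (SPS) while the learner threads do the work. A stripped-down sketch of that measurement, with a fake counter advanced by a background thread standing in for `batch_and_learn`:

import threading
import time
import timeit

step = 0  # stands in for the shared counter that the learner threads advance
total_steps = 10_000

def fake_learner():
    global step
    while step < total_steps:
        time.sleep(0.01)
        step += 100   # T * B in the real code

threading.Thread(target=fake_learner, daemon=True).start()

timer = timeit.default_timer
while step < total_steps:
    start_step, start_time = step, timer()
    time.sleep(1)
    sps = (step - start_step) / (timer() - start_time)
    print("Steps %i @ %.1f SPS" % (step, sps))
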
Code example #7
def train(flags):
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.learner_device = torch.device("cuda")
        flags.actor_device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.learner_device = torch.device("cpu")
        flags.actor_device = torch.device("cpu")

    if flags.max_learner_queue_size is None:
        flags.max_learner_queue_size = flags.batch_size

    # The queue the learner threads will get their data from.
    # Setting `minimum_batch_size == maximum_batch_size`
    # makes the batch size static.
    learner_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.max_learner_queue_size,
    )

    # The queue where the actorpool stores the final rendered image pairs.
    # A separate thread will load them into the ReplayBuffer.
    # The batch size of the pairs will be dynamic.
    replay_queue = libtorchbeast.BatchingQueue(
        batch_dim=1,
        minimum_batch_size=flags.batch_size,
        maximum_batch_size=flags.batch_size,
        check_inputs=True,
        maximum_queue_size=flags.batch_size,
    )

    # The "batcher", a queue for the inference call. Will yield
    # "batch" objects with `get_inputs` and `set_outputs` methods.
    # The batch size of the tensors will be dynamic.
    inference_batcher = libtorchbeast.DynamicBatcher(
        batch_dim=1,
        minimum_batch_size=1,
        maximum_batch_size=512,
        timeout_ms=100,
        check_outputs=True,
    )

    addresses = []
    connections_per_server = 1
    pipe_id = 0
    while len(addresses) < flags.num_actors:
        for _ in range(connections_per_server):
            addresses.append(f"{flags.pipes_basename}.{pipe_id}")
            if len(addresses) == flags.num_actors:
                break
        pipe_id += 1

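    # Observations are treated as grayscale if the dataset itself is single-channel
    # or if color is disabled via flags.use_color.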
    dataset_is_gray = flags.dataset in ["mnist", "omniglot"]
    grayscale = not dataset_is_gray and not flags.use_color
    dataset_is_gray |= grayscale

    dataset = utils.create_dataset(flags.dataset, grayscale)

    env_name, config = utils.parse_flags(flags)
    env = utils.create_env(env_name, config, dataset_is_gray, dataset=None)

    if flags.condition:
        new_space = env.observation_space.spaces
        c, h, w = new_space["canvas"].shape
        new_space["canvas"] = spaces.Box(low=0,
                                         high=255,
                                         shape=(c * 2, h, w),
                                         dtype=np.uint8)
        env.observation_space = spaces.Dict(new_space)

    obs_shape = env.observation_space["canvas"].shape
    action_shape = env.action_space.nvec
    order = env.order
    env.close()

    model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    )
    model = model.to(device=flags.learner_device)

    actor_model = models.Net(
        obs_shape=obs_shape,
        order=order,
        action_shape=action_shape,
        grid_shape=(grid_width, grid_width),
    ).eval()
    actor_model.to(device=flags.actor_device)

    if flags.condition:
        D = models.ComplementDiscriminator(obs_shape)
    else:
        D = models.Discriminator(obs_shape)
    D.to(device=flags.learner_device)

    if flags.condition:
        D_eval = models.ComplementDiscriminator(obs_shape)
    else:
        D_eval = models.Discriminator(obs_shape)
    D_eval = D_eval.to(device=flags.learner_device).eval()

    optimizer = optim.Adam(model.parameters(), lr=flags.policy_learning_rate)
    D_optimizer = optim.Adam(D.parameters(),
                             lr=flags.discriminator_learning_rate,
                             betas=(0.5, 0.999))

    def lr_lambda(epoch):
        return (1 - min(epoch * flags.unroll_length * flags.batch_size,
                        flags.total_steps) / flags.total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
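    # lr_lambda anneals the learning rate linearly from its initial value to 0
    # over flags.total_steps; each scheduler "epoch" corresponds to one
    # optimizer step consuming unroll_length * batch_size environment steps.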

    # The ActorPool that will run `flags.num_actors` many loops.
    actors = libtorchbeast.ActorPool(
        unroll_length=flags.unroll_length,
        learner_queue=learner_queue,
        replay_queue=replay_queue,
        inference_batcher=inference_batcher,
        env_server_addresses=addresses,
        initial_agent_state=actor_model.initial_state(),
    )

    def run():
        try:
            actors.run()
            print("actors are running")
        except Exception as e:
            logging.error("Exception in actorpool thread!")
            traceback.print_exc()
            print()
            raise e

    actorpool_thread = threading.Thread(target=run, name="actorpool-thread")

    dataloader = DataLoader(
        dataset,
        batch_size=flags.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )

    stats = {}

    # Load state from a checkpoint, if possible.
    if os.path.exists(checkpointpath):
        checkpoint_states = torch.load(checkpointpath,
                                       map_location=flags.learner_device)
        model.load_state_dict(checkpoint_states["model_state_dict"])
        D.load_state_dict(checkpoint_states["D_state_dict"])
        optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"])
        D_optimizer.load_state_dict(
            checkpoint_states["D_optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"])
        stats = checkpoint_states["stats"]
        logging.info(f"Resuming preempted job, current stats:\n{stats}")

    # Initialize actor model like learner model.
    actor_model.load_state_dict(model.state_dict())
    D_eval.load_state_dict(D.state_dict())

    learner_threads = [
        threading.Thread(
            target=learn,
            name="learner-thread-%i" % i,
            args=(
                flags,
                learner_queue,
                model,
                actor_model,
                D_eval,
                optimizer,
                scheduler,
                stats,
                plogger,
            ),
        ) for i in range(flags.num_learner_threads)
    ]

    inference_threads = [
        threading.Thread(
            target=inference,
            name="inference-thread-%i" % i,
            args=(
                flags,
                inference_batcher,
                actor_model,
            ),
        ) for i in range(flags.num_inference_threads)
    ]

    d_learner = threading.Thread(
        target=learn_D,
        name="d_learner-thread",
        args=(
            flags,
            dataloader,
            replay_queue,
            D,
            D_eval,
            D_optimizer,
            stats,
            plogger,
        ),
    )

    actorpool_thread.start()

    threads = learner_threads + inference_threads

    for t in inference_threads:
        t.start()

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "D_state_dict": D.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "D_optimizer_state_dict": D_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "stats": stats,
                "flags": vars(flags),
            },
            checkpointpath,
        )

    def format_value(x):
        return f"{x:1.5}" if isinstance(x, float) else str(x)

    try:
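        # Warm-up: wait until the replay queue holds a full batch before
        # starting the discriminator learner; meanwhile drain the learner
        # queue so the actors do not block on a full queue.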
        while replay_queue.size() < flags.batch_size:
            if learner_queue.size() >= flags.batch_size:
                next(learner_queue)
            time.sleep(0.01)

        d_learner.start()

        for t in learner_threads:
            t.start()

        last_checkpoint_time = timeit.default_timer()
        while True:
            start_time = timeit.default_timer()
            start_step = stats.get("step", 0)
            if start_step >= flags.total_steps:
                break
            time.sleep(5)
            end_step = stats.get("step", 0)

            if timeit.default_timer() - last_checkpoint_time > 10 * 60:
                # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timeit.default_timer()

            logging.info(
                "Step %i @ %.1f SPS. Inference batcher size: %i."
                " Learner queue size: %i."
                " Other stats: (%s)",
                end_step,
                (end_step - start_step) /
                (timeit.default_timer() - start_time),
                inference_batcher.size(),
                learner_queue.size(),
                ", ".join(f"{key} = {format_value(value)}"
                          for key, value in stats.items()),
            )
    except KeyboardInterrupt:
        pass  # Close properly.
    else:
        logging.info("Learning finished after %i steps.", stats["step"])
        checkpoint()

    # Done with learning. Stop all the ongoing work.
    inference_batcher.close()
    learner_queue.close()

    replay_queue.close()

    actorpool_thread.join()

    d_learner.join()

    for t in threads:
        t.join()
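Both train() variants here anneal the learning rate linearly to zero with torch.optim.lr_scheduler.LambdaLR. The following self-contained sketch checks that schedule with illustrative values; the linear model, SGD optimizer, and the chosen unroll_length/batch_size/total_steps are placeholders, not part of the examples above.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total_steps = 1000
unroll_length, batch_size = 10, 5  # illustrative values


def lr_lambda(epoch):
    return 1 - min(epoch * unroll_length * batch_size, total_steps) / total_steps


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for _ in range(10):
    optimizer.step()
    scheduler.step()

# 10 scheduler steps consume 10 * 10 * 5 = 500 of 1000 total steps, so the
# multiplier is 0.5 and the learning rate has halved to ~0.05.
print(optimizer.param_groups[0]["lr"])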
Code Example #8
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    terms = flags.xpid.split("-")
    if len(terms) == 3:
        group = terms[0] + "-" + terms[1]
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be at least batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if (not flags.disable_cuda) and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda:" + str(torch.cuda.current_device()))
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    step, stats = 0, {}

    env = create_gymenv(flags)
    actor_flags = flags
    checkpoint = None
    if os.path.exists(checkpointpath):
        try:
            checkpoint = torch.load(checkpointpath, map_location=flags.device)
            step = checkpoint["step"]
        except Exception as e:
            logging.warning("Failed to load checkpoint from %s: %s",
                            checkpointpath, e)

    model = create_model(flags, env).to(device=flags.device)
    if checkpoint is not None:
        try:
            model.load_state_dict(checkpoint["model_state_dict"])
        except Exception as e:
            logging.warning("Failed to load model state: %s", e)

    if flags.agent in ["CNN"]:
        buffers = create_buffers(
            flags,
            env.observation_space.spaces["image"].shape,
            env.action_space.n,
            flags.unroll_length,
            flags.num_buffers,
            img_shape=env.observation_space.spaces["image"].shape)
        actor_buffers = create_buffers(
            flags,
            env.observation_space.spaces["image"].shape,
            env.action_space.n,
            0,
            flags.num_actors,
            img_shape=env.observation_space.spaces["image"].shape)
    elif flags.agent in ["NLM", "KBMLP", "GCN"]:
        buffers = create_buffers(
            flags,
            env.obs_shape,
            model.num_actions,
            flags.unroll_length,
            flags.num_buffers,
            img_shape=env.observation_space.spaces["image"].shape)
        actor_buffers = create_buffers(
            flags,
            env.obs_shape,
            model.num_actions,
            0,
            flags.num_actors,
            img_shape=env.observation_space.spaces["image"].shape)
    else:
        raise ValueError(f"Unsupported agent type: {flags.agent}")

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    actor_model_queues = [ctx.SimpleQueue() for _ in range(flags.num_actors)]
    actor_env_queues = [ctx.SimpleQueue() for _ in range(flags.num_actors)]

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(actor_flags, create_gymenv(flags), i, free_queue, full_queue,
                  buffers, actor_buffers, actor_model_queues,
                  actor_env_queues),
        )
        actor.start()
        actor_processes.append(actor)

    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    if flags.mode == "imitate":
        stat_keys = [
            "total_loss",
            "accuracy",
            "mean_episode_return",
        ]
    else:
        stat_keys = [
            "total_loss",
            "mean_episode_return",
            "pg_loss",
            "baseline_loss",
            "entropy_loss",
        ]
    logger.info("# Step\t%s", "\t".join(stat_keys))
    finish = False

    def batch_and_inference():
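        # This thread acts as a centralized inference server: each actor
        # process posts its index on actor_model_queues, the observations are
        # batched from actor_buffers, the shared model runs once, the outputs
        # are written back, and every actor is then unblocked via
        # actor_env_queues.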
        nonlocal finish
        while not all(finish):
            indexes = []
            for i in range(flags.num_actors):
                indexes.append(actor_model_queues[i].get())
            batch = get_inference_batch(flags, actor_buffers)
            with torch.no_grad():
                agent_output = model(batch)

            for index in indexes:
                for key in agent_output:
                    actor_buffers[key][index][0] = agent_output[key][0, index]
            for i in range(flags.num_actors):
                actor_env_queues[i].put(None)

    finish = [False for _ in range(flags.num_learner_threads)]

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats, finish
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                timings,
            )
            stats = learn(flags, model, batch, optimizer, scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())
        finish[i] = True

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    thread = threading.Thread(target=batch_and_inference,
                              name="batch-and-inference")
    thread.start()
    threads.append(thread)
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name="batch-and-learn-%d" % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
                "step": step
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)
            total_loss = stats.get("total_loss", float("inf"))
            if stats.get("episode_returns", None):
                mean_return = ("Return per episode: %.1f. " %
                               stats["mean_episode_return"])
            else:
                mean_return = ""
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for i in range(flags.num_actors):
            free_queue.put(None)
            actor_env_queues[i].put("exit")

        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    plogger.close()
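The checkpoint round-trip these examples rely on is a plain torch.save / torch.load of a dict of state_dicts. A minimal sketch, using a placeholder model and file name:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01)

path = "model.tar"  # placeholder path
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "step": 12345,
    },
    path,
)

# Restore on CPU; map_location can instead target the learner device.
states = torch.load(path, map_location=torch.device("cpu"))
model.load_state_dict(states["model_state_dict"])
optimizer.load_state_dict(states["optimizer_state_dict"])
step = states["step"]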