Example #1
def train_gpu_sync(args, env, model):
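    # NOTE (assumption): `tasks` and `results` are module-level multiprocessing
    # queues (a JoinableQueue and a results Queue) shared with the `Sampler`
    # workers; they are not defined inside this function.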
    # initiate subprocesses
    print('Creating %d processes' % args.nb_procs)
    workers = [
        Sampler(env, tasks, results, args.max_timesteps)
        for i in range(args.nb_procs)
    ]
    for w in workers:
        w.start()

    # logging variables
    dt = datetime.now().strftime("%Y.%m.%d_%H:%M:%S")
    writer = SummaryWriter(log_dir=os.path.join(args.artifact_path, 'runs/' +
                                                args.name + '_' + dt))
    save_dir = os.path.join(args.artifact_path,
                            'saves/' + args.name + '_' + dt)
    os.makedirs(save_dir, exist_ok=True)
    initialize_logger(save_dir)
    logging.info(model)

    sample_count = 0
    episode_count = 0
    save_counter = 0
    log_counter = 0

    running_length = 0
    running_reward = 0
    running_main_reward = 0

    memory = Memory()
    memories = [Memory() for _ in range(args.nb_procs)]
    rewbuffer_env = deque(maxlen=100)
    molbuffer_env = deque(maxlen=1000)
    # training loop
    i_episode = 0
    while i_episode < args.max_episodes:
        logging.info("\n\ncollecting rollouts")
        for i in range(args.nb_procs):
            tasks.put((i, None, True))
        tasks.join()
        # unpack results
        states = [None] * args.nb_procs
        done_idx = []
        notdone_idx, candidates, batch_idx = [], [], []
        for i in range(args.nb_procs):
            index, state, cands, done = results.get()

            states[index] = state

            notdone_idx.append(index)
            candidates.append(cands)
            batch_idx.extend([index] * len(cands))
        while True:
            # select actions for workers whose episodes are still running
            if len(notdone_idx) > 0:
                states_emb, candidates_emb, action_logprobs, actions = model.select_action(
                    mols_to_pyg_batch([
                        Chem.MolFromSmiles(states[idx]) for idx in notdone_idx
                    ],
                                      model.emb_3d,
                                      device=model.device),
                    mols_to_pyg_batch([
                        Chem.MolFromSmiles(item) for sublist in candidates
                        for item in sublist
                    ],
                                      model.emb_3d,
                                      device=model.device), batch_idx)
                if not isinstance(actions, list):
                    action_logprobs = [action_logprobs]
                    actions = [actions]
            else:
                if sample_count >= args.update_timesteps:
                    break

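            # dispatch each chosen candidate to its worker and record the
            # transition in that worker's memory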
            for i, idx in enumerate(notdone_idx):
                tasks.put((idx, candidates[i][actions[i]], False))
                cands = [
                    data for j, data in enumerate(candidates_emb)
                    if batch_idx[j] == idx
                ]

                memories[idx].states.append(states_emb[i])
                memories[idx].candidates.append(cands)
                memories[idx].states_next.append(cands[actions[i]])
                memories[idx].actions.append(actions[i])
                memories[idx].logprobs.append(action_logprobs[i])
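            # workers whose episode has ended are reset for a new episode, or
            # (presumably) given a placeholder task with index None once enough
            # samples have been collected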
            for idx in done_idx:
                if sample_count >= args.update_timesteps:
                    tasks.put((None, None, True))
                else:
                    tasks.put((idx, None, True))
            tasks.join()
            # unpack results
            states = [None] * args.nb_procs
            new_done_idx = []
            new_notdone_idx, candidates, batch_idx = [], [], []
            for i in range(args.nb_procs):
                index, state, cands, done = results.get()

                if index is not None:
                    states[index] = state
                if done:
                    new_done_idx.append(index)
                else:
                    new_notdone_idx.append(index)
                    candidates.append(cands)
                    batch_idx.extend([index] * len(cands))
            # get final rewards (for previously not done but now done)
            nowdone_idx = [idx for idx in notdone_idx if idx in new_done_idx]
            stillnotdone_idx = [
                idx for idx in notdone_idx if idx in new_notdone_idx
            ]
            if len(nowdone_idx) > 0:
                main_rewards = get_main_reward(
                    [Chem.MolFromSmiles(states[idx]) for idx in nowdone_idx],
                    reward_type=args.reward_type,
                    args=args)
                if not isinstance(main_rewards, list):
                    main_rewards = [main_rewards]

            for i, idx in enumerate(nowdone_idx):
                main_reward = main_rewards[i]

                i_episode += 1
                running_reward += main_reward
                running_main_reward += main_reward
                writer.add_scalar("EpMainRew", main_reward, i_episode - 1)
                rewbuffer_env.append(main_reward)
                molbuffer_env.append((states[idx], main_reward))
                writer.add_scalar("EpRewEnvMean", np.mean(rewbuffer_env),
                                  i_episode - 1)

                memories[idx].rewards.append(main_reward)
                memories[idx].terminals.append(True)
            for idx in stillnotdone_idx:
                running_reward += 0

                memories[idx].rewards.append(0)
                memories[idx].terminals.append(False)
            # get innovation rewards
            if (args.iota > 0
                    and i_episode > args.innovation_reward_episode_delay
                    and i_episode < args.innovation_reward_episode_cutoff):
                if len(notdone_idx) > 0:
                    inno_rewards = model.get_inno_reward(
                        mols_to_pyg_batch([
                            Chem.MolFromSmiles(states[idx])
                            for idx in notdone_idx
                        ],
                                          model.emb_3d,
                                          device=model.device))
                    if not isinstance(inno_rewards, list):
                        inno_rewards = [inno_rewards]

                for i, idx in enumerate(notdone_idx):
                    inno_reward = args.iota * inno_rewards[i]

                    running_reward += inno_reward

                    memories[idx].rewards[-1] += inno_reward

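            # bookkeeping: samples gathered this step and episodes finished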
            sample_count += len(notdone_idx)
            episode_count += len(nowdone_idx)
            running_length += len(notdone_idx)

            done_idx = new_done_idx
            notdone_idx = new_notdone_idx

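        # merge per-worker rollouts into a single memory for the policy update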
        for m in memories:
            memory.extend(m)
            m.clear()

        # update model
        logging.info("\nupdating model @ episode %d..." % i_episode)
        model.update(memory)
        memory.clear()

        save_counter += episode_count
        log_counter += episode_count

        # stop training if avg_reward > solved_reward
        if np.mean(rewbuffer_env) > args.solved_reward:
            logging.info("########## Solved! ##########")
            save_DGAPN(
                model,
                os.path.join(save_dir,
                             'DGAPN_continuous_solved_{}.pt'.format('test')))
            break

        # save every save_interval episodes
        if save_counter >= args.save_interval:
            save_DGAPN(
                model,
                os.path.join(save_dir, '{:05d}_dgapn.pt'.format(i_episode)))
            deque_to_csv(molbuffer_env, os.path.join(save_dir,
                                                     'mol_dgapn.csv'))
            save_counter = 0

        # save running model
        save_DGAPN(model, os.path.join(save_dir, 'running_dgapn.pt'))

        if log_counter >= args.log_interval:
            logging.info(
                'Episode {} \t Avg length: {} \t Avg reward: {:5.3f} \t Avg main reward: {:5.3f}'
                .format(i_episode, running_length / log_counter,
                        running_reward / log_counter,
                        running_main_reward / log_counter))

            running_length = 0
            running_reward = 0
            running_main_reward = 0
            log_counter = 0

        episode_count = 0
        sample_count = 0

    close_logger()
    writer.close()
    # Add a poison pill for each process
    for i in range(args.nb_procs):
        tasks.put(None)
    tasks.join()
Example #2
def train_cpu_async(args, env, model):
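    # NOTE (assumption): `tasks`, `results`, `episode_count` and `sample_count`
    # are module-level multiprocessing objects (queues and shared Values)
    # that the `Sampler` workers read from and write to.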
    # initiate subprocesses
    print('Creating %d processes' % args.nb_procs)
    workers = [
        Sampler(args, env, tasks, results, args.max_episodes,
                args.max_timesteps, args.update_timesteps)
        for i in range(args.nb_procs)
    ]
    for w in workers:
        w.start()

    # logging variables
    dt = datetime.now().strftime("%Y.%m.%d_%H:%M:%S")
    writer = SummaryWriter(log_dir=os.path.join(args.artifact_path, 'runs/' + args.name + '_' + dt))
    save_dir = os.path.join(args.artifact_path, 'saves/' + args.name + '_' + dt)
    os.makedirs(save_dir, exist_ok=True)
    initialize_logger(save_dir)
    logging.info(model)

    save_counter = 0
    log_counter = 0

    running_length = 0
    running_reward = 0
    running_main_reward = 0

    memory = Memory()
    log = Log()
    rewbuffer_env = deque(maxlen=100)
    molbuffer_env = deque(maxlen=1000)
    # training loop
    i_episode = 0
    while i_episode < args.max_episodes:
        logging.info("\n\ncollecting rollouts")
        model.to_device(torch.device("cpu"))
        # Enqueue jobs
        for i in range(args.nb_procs):
            tasks.put(Task(model.state_dict()))
        # Wait for all of the tasks to finish
        tasks.join()
        # Start unpacking results
        for i in range(args.nb_procs):
            result = results.get()
            m, l = result()
            memory.extend(m)
            log.extend(l)

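        # episode_count is presumably a shared Value incremented by the
        # workers while they collect rollouts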
        i_episode += episode_count.value
        model.to_device(args.device)

        # log results
        for i in reversed(range(episode_count.value)):
            running_length += log.ep_lengths[i]
            running_reward += log.ep_rewards[i]
            running_main_reward += log.ep_main_rewards[i]
            rewbuffer_env.append(log.ep_main_rewards[i])
            molbuffer_env.append(log.ep_mols[i])
            writer.add_scalar("EpMainRew", log.ep_main_rewards[i], i_episode - 1)
            writer.add_scalar("EpRewEnvMean", np.mean(rewbuffer_env), i_episode - 1)
        log.clear()

        # update model
        logging.info("\nupdating model @ episode %d..." % i_episode)
        model.update(memory)
        memory.clear()

        save_counter += episode_count.value
        log_counter += episode_count.value

        # stop training if avg_reward > solved_reward
        if np.mean(rewbuffer_env) > args.solved_reward:
            logging.info("########## Solved! ##########")
            save_DGAPN(model, os.path.join(save_dir, 'DGAPN_continuous_solved_{}.pt'.format('test')))
            break

        # save every save_interval episodes
        if save_counter >= args.save_interval:
            save_DGAPN(model, os.path.join(save_dir, '{:05d}_dgapn.pt'.format(i_episode)))
            deque_to_csv(molbuffer_env, os.path.join(save_dir, 'mol_dgapn.csv'))
            save_counter = 0

        # save running model
        save_DGAPN(model, os.path.join(save_dir, 'running_dgapn.pt'))

        if log_counter >= args.log_interval:
            logging.info('Episode {} \t Avg length: {} \t Avg reward: {:5.3f} \t Avg main reward: {:5.3f}'.format(
                i_episode, running_length/log_counter, running_reward/log_counter, running_main_reward/log_counter))

            running_reward = 0
            running_main_reward = 0
            running_length = 0
            log_counter = 0

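        # reset the shared counters before the next collection round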
        episode_count.value = 0
        sample_count.value = 0

    close_logger()
    writer.close()
    # Add a poison pill for each process
    for i in range(args.nb_procs):
        tasks.put(None)
    tasks.join()
Example #3
def train_serial(args, env, model):
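    # single-process variant: the environment is stepped inline instead of
    # delegating rollouts to Sampler subprocesses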
    # logging variables
    dt = datetime.now().strftime("%Y.%m.%d_%H:%M:%S")
    writer = SummaryWriter(log_dir=os.path.join(args.artifact_path, 'runs/' +
                                                args.name + '_' + dt))
    save_dir = os.path.join(args.artifact_path,
                            'saves/' + args.name + '_' + dt)
    os.makedirs(save_dir, exist_ok=True)
    initialize_logger(save_dir)
    logging.info(model)

    time_step = 0

    running_length = 0
    running_reward = 0
    running_main_reward = 0

    memory = Memory()
    rewbuffer_env = deque(maxlen=100)
    molbuffer_env = deque(maxlen=1000)
    # training loop
    for i_episode in range(1, args.max_episodes + 1):
        if time_step == 0:
            logging.info("\n\ncollecting rollouts")
        state, candidates, done = env.reset()

        for t in range(args.max_timesteps):
            time_step += 1
            # Running policy:
            state_emb, candidates_emb, action_logprob, action = model.select_action(
                mols_to_pyg_batch(state, model.emb_3d, device=model.device),
                mols_to_pyg_batch(candidates,
                                  model.emb_3d,
                                  device=model.device))
            memory.states.append(state_emb[0])
            memory.candidates.append(candidates_emb)
            memory.states_next.append(candidates_emb[action])
            memory.actions.append(action)
            memory.logprobs.append(action_logprob)

            state, candidates, done = env.step(action)

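            # the main reward is only granted on the final step of an episode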
            reward = 0
            if (t == (args.max_timesteps - 1)) or done:
                main_reward = get_main_reward(state,
                                              reward_type=args.reward_type,
                                              args=args)[0]
                reward = main_reward
                running_main_reward += main_reward
                done = True
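            # optional intrinsic ("innovation") reward inside the configured
            # episode window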
            if (args.iota > 0
                    and i_episode > args.innovation_reward_episode_delay
                    and i_episode < args.innovation_reward_episode_cutoff):
                inno_reward = model.get_inno_reward(
                    mols_to_pyg_batch(state, model.emb_3d,
                                      device=model.device))
                reward += inno_reward
            running_reward += reward

            # Saving rewards and terminals:
            memory.rewards.append(reward)
            memory.terminals.append(done)

            if done:
                break

        # update if it's time
        if time_step >= args.update_timesteps:
            logging.info("\nupdating model @ episode %d..." % i_episode)
            time_step = 0
            model.update(memory)
            memory.clear()

        writer.add_scalar("EpMainRew", main_reward, i_episode - 1)
        rewbuffer_env.append(main_reward)  # reward
        molbuffer_env.append((Chem.MolToSmiles(state), main_reward))
        running_length += (t + 1)

        # write to Tensorboard
        writer.add_scalar("EpRewEnvMean", np.mean(rewbuffer_env),
                          i_episode - 1)

        # stop training if avg_reward > solved_reward
        if np.mean(rewbuffer_env) > args.solved_reward:
            logging.info("########## Solved! ##########")
            save_DGAPN(
                model,
                os.path.join(save_dir,
                             'DGAPN_continuous_solved_{}.pt'.format('test')))
            break

        # save every save_interval episodes
        if (i_episode - 1) % args.save_interval == 0:
            save_DGAPN(
                model,
                os.path.join(save_dir, '{:05d}_dgapn.pt'.format(i_episode)))
            deque_to_csv(molbuffer_env, os.path.join(save_dir,
                                                     'mol_dgapn.csv'))

        # save running model
        save_DGAPN(model, os.path.join(save_dir, 'running_dgapn.pt'))

        # logging
        if i_episode % args.log_interval == 0:
            logging.info(
                'Episode {} \t Avg length: {} \t Avg reward: {:5.3f} \t Avg main reward: {:5.3f}'
                .format(i_episode, running_length / args.log_interval,
                        running_reward / args.log_interval,
                        running_main_reward / args.log_interval))

            running_length = 0
            running_reward = 0
            running_main_reward = 0

    close_logger()
    writer.close()