    def __init__(self,
                 dataset,
                 batch_size=None,
                 shuffle=False,
                 sampler=None,
                 last_batch=None,
                 collate_fn=_batchify,
                 batch_sampler=None):
        self._dataset = dataset
        self.collate_fn = collate_fn
        if batch_sampler is None:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless " \
                                 "batch_sampler is specified")
            if sampler is None:
                if shuffle:
                    sampler = _sampler.RandomSampler(len(dataset))
                else:
                    sampler = _sampler.SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError(
                    "shuffle must not be specified if sampler is specified")

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')
        elif batch_size is not None or shuffle or sampler is not None or \
                last_batch is not None:
            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
                             "not be specified if batch_sampler is specified.")

        self._batch_sampler = batch_sampler
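
A minimal usage sketch of how the arguments above interact, assuming a DataLoader class built around this __init__; the class name and the toy dataset are illustrative, while the validation behaviour is the one implemented above.

# Any object with __len__ and __getitem__ can serve as the dataset.
dataset = list(range(10))

# batch_size + shuffle=True: a RandomSampler is created internally and wrapped
# in a BatchSampler whose last_batch mode defaults to 'keep'.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# An explicit batch_sampler replaces all of the above; combining it with
# batch_size, shuffle, sampler or last_batch raises ValueError.
batch_sampler = _sampler.BatchSampler(_sampler.SequentialSampler(len(dataset)), 4, 'keep')
loader = DataLoader(dataset, batch_sampler=batch_sampler)
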
Example #2
def run_training(train_config):
    import os
    import json
    import logging
    import random
    import time

    import gym
    import numpy as np
    import torch

    import algorithm
    import environment
    import model
    import policy
    import sampler

    logger = logging.getLogger(__name__)
    #############################
    # SETUP
    episode_results_path = os.path.join(
        train_config['odir'],
        'episode_results_run_%d.npy' % train_config['run'])

    make_directory(train_config['odir'])
    with open(
            os.path.join(train_config['odir'],
                         'train_config_run_%d.json' % train_config['run']),
            'w') as fp:
        json.dump(train_config, fp, sort_keys=True, indent=4)

    log_level = logging.DEBUG if train_config['debug'] else logging.INFO
    if not train_config['console']:
        logging.basicConfig(
            filename=os.path.join(train_config['odir'],
                                  'log_run_%d.log' % train_config['run']),
            level=log_level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        )
    else:
        logging.basicConfig(level=log_level)

    ###############################
    # MAKE NET AND POLICY
    env = gym.make(train_config['env'])
    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.shape[0]
    if train_config['seed'] is not None:
        random.seed(train_config['seed'])
        ts = random.randint(1, int(1e8))
        torch.manual_seed(ts)

    log_str = '\r\n###################################################\r\n' + \
              '\tEnvironment: %s\r\n' % train_config['env'] + \
              '\tAlgorithm: %s\r\n' % train_config['alg'] + \
              '\tPolicy Class: %s\r\n' % train_config['policy'] + \
              '\tNetwork Types: (%s,%s)\r\n' % (train_config['ac_types']['actor'],train_config['ac_types']['critic']) + \
              '\tNetwork Params: %s \r\n' % str(train_config['ac_kwargs']) + \
              '\tN, Total Updates, Save Interval: (%d,%d,%d) \r\n' % (train_config['N'],train_config['num_updates'],train_config['save_interval']) + \
              '###################################################'
    logger.info(log_str)

    actor_net = getattr(model, train_config['ac_types']['actor'])(
        num_inputs, num_actions, **train_config['ac_kwargs'])
    critic_net = getattr(model, train_config['ac_types']['critic'])(
        num_inputs, **train_config['ac_kwargs'])
    plc = None
    try:
        plc_class = getattr(policy, train_config['policy'])
        plc = plc_class(actor_net)
    except AttributeError:
        raise RuntimeError('Policy "%s" not found' % train_config['policy'])

    ###############################
    # CREATE ENVIRONMENT AND RUN
    algo = None
    try:
        algo_class = getattr(algorithm, train_config['alg'])
        algo = algo_class(plc, critic_net, train_config)
    except AttributeError:
        raise RuntimeError('Algorithm "%s" not found' % train_config['alg'])

    smp = sampler.BatchSampler(plc, **train_config)
    episode_results = np.array([]).reshape((0, 6))
    cur_update = 0
    finished_episodes = 0
    smp.reset()
    samples_per_update = train_config['N'] * train_config['num_env']
    start = time.time()
    while cur_update < train_config['num_updates']:
        # sample() returns the training batch plus, for each episode finished while
        # sampling, its cumulative reward (crs), total reward (trs) and length (els)
        batch, crs, trs, els = smp.sample()
        algo.update(batch)

        # save episode results
        for i, (cr, tr, el) in enumerate(zip(crs, trs, els)):
            finished_episodes += 1
            total_samples = cur_update * samples_per_update
            # stores: total_updates, total_episodes, total_samples, current_episode_length, current_total_reward, current_cumulative_reward
            episode_results = np.concatenate(
                (episode_results,
                 np.array([
                     cur_update, finished_episodes, total_samples, el, tr, cr
                 ],
                          ndmin=2)),
                axis=0)
            np.save(episode_results_path, episode_results)
            logger.info(
                'Update Number: %06d, Finished Episode: %04d ---  Length: %.3f, TR: %.3f, CDR: %.3f'
                % (cur_update, finished_episodes, el, tr, cr))

        # checkpoint
        if cur_update % train_config['save_interval'] == 0:
            plc.save_model(
                os.path.join(
                    train_config['odir'], 'model_update_%06d_run_%d.pt' %
                    (cur_update, train_config['run'])))
        cur_update += 1

    end = time.time()
    print('total training time: %.2f s' % (end - start))
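
A sketch of the train_config dictionary that run_training reads. Every key below is referenced in the function above; the concrete values, the environment id and the 'hidden_size' kwarg are illustrative placeholders, and the project's BatchSampler may consume further keys via **train_config.

train_config = {
    'odir': 'results/example',                          # output directory
    'run': 0,                                           # run index used in file names
    'env': 'Pendulum-v1',                               # gym environment id (illustrative)
    'seed': 1234,                                       # None skips seeding
    'debug': False,                                     # DEBUG vs. INFO log level
    'console': True,                                    # log to console instead of a file
    'alg': 'A2C',                                       # class looked up on the algorithm module
    'policy': 'GaussianPolicy',                         # class looked up on the policy module
    'ac_types': {'actor': 'FFNet', 'critic': 'FFNet'},  # classes looked up on the model module
    'ac_kwargs': {'hidden_size': 64},                   # kwargs passed to both networks (illustrative)
    'N': 32,                                            # samples per environment per update
    'num_env': 8,                                       # number of parallel environments
    'num_updates': 1000,
    'save_interval': 100,
}
run_training(train_config)
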
Example #3
    def __init__(
            self,
            dataset,
            batch_size=None,
            shuffle=False,
            sampler=None,
            last_batch=None,
            batch_sampler=None,
            batchify_fn=None,
            inter_batchify_fn=None,  # for internal dataloader
            part_num=20,  # for part loader
            num_workers=0):
        self._dataset = dataset

        if batch_sampler is None:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless " \
                                 "batch_sampler is specified")
            if sampler is None:
                if shuffle:
                    sampler = _sampler.RandomSampler(len(dataset))
                else:
                    sampler = _sampler.SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError(
                    "shuffle must not be specified if sampler is specified")

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')
        elif batch_size is not None or shuffle or sampler is not None or \
                last_batch is not None:
            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
                             "not be specified if batch_sampler is specified.")

        self._batch_sampler = batch_sampler
        self._num_workers = num_workers
        if batchify_fn is None:
            # the default_mp_batchify_fn / default_batchify_fn fallbacks are
            # intentionally disabled; a batchify_fn must be passed in explicitly
            raise ValueError('batchify_fn must be specified')
        else:
            self.batchify_fn = batchify_fn

        self.inter_batchify_fn = inter_batchify_fn

        self.batch_size = batch_size
        self.shuffle = shuffle

        # number of batches per epoch (integer division drops any trailing partial batch)
        self.num = len(self._dataset) // self.batch_size

        # bookkeeping for the shared-memory cache
        self.cache_n = self._num_workers
        self.cache_i = 0
        self.cache_num = None
        self.init_cache_num()

        # probe one sample to infer the per-item shapes used to size the
        # shared-memory buffers below
        self.data, self.label = self._dataset[0]

        self.batch_data_shape = (self.batch_size, ) + self.data.shape
        self.batch_label_shape = (self.batch_size, )

        self.part_num = part_num  # number of batches kept in shared memory at a time

        self.init_data_shm((self.part_num, ) + self.batch_data_shape,
                           np.float32, None)
        self.init_label_shm((self.part_num, ) + self.batch_label_shape,
                            np.float32, None)

        self.init_qs()
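
A construction sketch for the loader above. PartDataLoader is a placeholder name for the class that defines this __init__, the dataset is assumed to yield (data, label) pairs as the shape probing above requires, and a batchify_fn has to be supplied because the constructor raises without one.

import numpy as np

def stack_batchify(samples):
    # Illustrative batchify_fn: stack the data arrays and collect the labels.
    data, labels = zip(*samples)
    return np.stack(data), np.asarray(labels)

loader = PartDataLoader(
    dataset,                      # any indexable dataset of (data, label) pairs
    batch_size=64,
    shuffle=True,
    batchify_fn=stack_batchify,   # required; omitting it raises
    part_num=20,                  # batches kept in shared memory at a time
    num_workers=4)
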
Example #4
    # MAKE NET AND POLICY
    critic_net = FFNet(in_size=2, out_size=1)
    actor_net = FFNet(in_size=2, out_size=2)
    plc = None
    if train_config['policy'] == 'angular':
        plc = policy.AngularPolicy(actor_net, train_config['sigma'])
    elif train_config['policy'] == 'gauss':
        plc = policy.GaussianPolicy(actor_net, train_config['sigma'])
    else:
        raise RuntimeError('Not a valid policy: %s' % train_config['policy'])

    ###############################
    # CREATE ENVIRONMENT AND RUN
    algo = A2C(plc, critic_net, train_config['lr'], train_config['gamma'])

    sampler = sampler.BatchSampler(plc, **train_config)
    cumulative_rewards = np.array([]).reshape((0, 3))
    cur_update = 0
    finished_episodes = 0
    sampler.reset()

    while cur_update < train_config['num_updates']:
        batch, terminal = sampler.sample()
        algo.update(batch, terminal)
        cr = sampler.cumulative_reward

        # save cumulative rewards
        for i, t in enumerate(terminal):
            if t:
                finished_episodes += 1
                cumulative_rewards = np.concatenate(