Example #1
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 ob_placeholder=None,
                 trainable=True,
                 args=None):

        self._base_dir = init_path.get_abs_base_dir()
        self._use_ppo = True
        self._ppo_clip = args.ppo_clip
        self._output_size = 1

        super(tf_baseline_network,
              self).__init__(session=session,
                             name_scope=name_scope,
                             input_size=input_size,
                             output_size=self._output_size,
                             ob_placeholder=ob_placeholder,
                             trainable=trainable,
                             build_network_now=True,
                             define_std=False,
                             is_baseline=True,
                             args=args)

        self._build_train_placeholders()

        self._build_loss()
Example #2
    def __init__(self, args, rand_seed, monitor, width=480, height=480):
        self.width = width
        self.height = height
        self.task = args.task
        self.args = args
        assert 'fish3d' in self.task
        self.is_evo = 'evo' in self.task
        if 'easyswim' in self.task:
            self.target_angle = args.fish_target_angle

        from dm_control import suite
        self.env = suite.load(
            domain_name='fish', task_name='swim',
            task_kwargs={'random': rand_seed}
        )
        self._base_path = init_path.get_abs_base_dir()

        self.load_xml(os.path.join(self._base_path, 'env', 'assets/fish3d.xml'))
        self.set_get_observation()  # overwrite the original get_ob function
        self.set_get_reward()  # overwrite the original reward function
        self._JOINTS = ['tail1',
                        'tail_twist',
                        'tail2',
                        'finright_roll',
                        'finright_pitch',
                        'finleft_roll',
                        'finleft_pitch']

        # save the video
        self._monitor = monitor
        self._current_episode = 0
        if self._monitor:
            self.init_save(args)
Example #3
def main():
    parser = base_config.get_base_config()
    parser = ecco_config.get_ecco_config(parser)
    parser = dqn_transfer_config.get_dqn_transfer_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='ecco_ecco' + args.task,
                                time_str=args.exp_id)

    print('DQN_TRANSFER_MAIN.PY is Deprecated, do not use')
    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from trainer import dqn_transfer_trainer
    from runners import dqn_transfer_task_sampler
    from runners.workers import dqn_transfer_worker
    from policy import ecco_pretrain
    from policy import dqn_base, a2c_base
    from policy import ecco_transfer

    base_model = {'dqn': dqn_base, 'a2c': a2c_base}[args.base_policy]

    models = {
        'final': ecco_pretrain.model,
        'transfer': ecco_transfer.model,
        'base': base_model.model
    }

    pretrain_weights = None

    train(dqn_transfer_trainer.trainer, dqn_transfer_task_sampler,
          dqn_transfer_worker, models, args, pretrain_weights)
Example #4
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 output_size,
                 weight_init_methods='orthogonal',
                 ob_placeholder=None,
                 trainable=True,
                 build_network_now=True,
                 is_baseline=False,
                 placeholder_list=None,
                 args=None
                 ):
        '''
            @input: the same as the ones defined in "policy_network"
        '''
        self._shared_network = args.shared_network
        self._node_update_method = args.node_update_method

        policy_network.__init__(
            self,
            session,
            name_scope,
            input_size,
            output_size,
            ob_placeholder=ob_placeholder,
            trainable=trainable,
            build_network_now=False,
            define_std=True,
            is_baseline=False,
            args=args
        )

        self._base_dir = init_path.get_abs_base_dir()
        self._root_connection_option = args.root_connection_option
        self._num_prop_steps = args.gnn_num_prop_steps
        self._init_method = weight_init_methods
        self._gnn_node_option = args.gnn_node_option
        self._gnn_output_option = args.gnn_output_option
        self._gnn_embedding_option = args.gnn_embedding_option
        self._is_baseline = is_baseline
        self._placeholder_list = placeholder_list

        # parse the network shape and do validation check
        self._network_shape = args.network_shape
        self._hidden_dim = args.gnn_node_hidden_dim
        self._input_feat_dim = args.gnn_input_feat_dim
        self._input_obs = ob_placeholder
        self._seed = args.seed
        self._npr = np.random.RandomState(args.seed)

        assert self._input_feat_dim == self._hidden_dim
        logger.info('Network shape is {}'.format(self._network_shape))

        if build_network_now:
            self._build_model()
            if self._shared_network:
                # build the baseline loss and placeholder
                self._build_baseline_train_placeholders()
                self._build_baseline_loss()
Example #5
def pretrain(worker_trainer, models, args=None):
    logger.info('Pretraining starts at {}'.format(
        init_path.get_abs_base_dir()))

    worker_trainer_agent = make_joint_worker_trainer(worker_trainer.trainer,
                                                     models, args)

    weights = worker_trainer_agent.run()

    return weights
Example #6
    def __init__(self, env_name, rand_seed, maximum_length, misc_info):
        super(env, self).__init__(env_name, rand_seed, maximum_length,
                                  misc_info)
        self._base_path = init_path.get_abs_base_dir()
        self._env.env.penalty_for_step = 0.
        self.n_boxes = 3

        if 'easy' in self._env_name:
            self.n_boxes = 1

        self._last_reward = 0
        self._episode_reward = 0
Example #7
    def __init__(self, args, task_q, result_q,
                 name_scope='pruning_agent_policy'):
        super(pruning_agent, self).__init__(args, -1, -1, task_q, result_q,
                                            name_scope=name_scope)
        self.args = args
        self.base_path = init_path.get_abs_base_dir()
        self.task_q = task_q
        self.result_q = result_q
        self._name_scope = name_scope
        self.current_iteration = 0
        self._seed = 1234

        self.data_dict = {}  # species id: reward
Example #8
    def __init__(self,
                 args,
                 network_type,
                 task_queue,
                 result_queue,
                 name_scope='trainer'):
        # the base agent
        super(trainer, self).__init__(args=args,
                                      network_type=network_type,
                                      task_queue=task_queue,
                                      result_queue=result_queue,
                                      name_scope=name_scope)

        self._base_path = init_path.get_abs_base_dir()
Example #9
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 output_size,
                 adj_matrix,
                 node_attr,
                 weight_init_methods='orthogonal',
                 is_rollout_agent=True,
                 args=None):

        self._node_update_method = args.node_update_method
        self.adj_matrix = adj_matrix
        self.node_attr = node_attr

        policy_network.__init__(self,
                                session,
                                name_scope,
                                input_size,
                                output_size,
                                ob_placeholder=None,
                                trainable=True,
                                build_network_now=False,
                                define_std=True,
                                args=args)

        self._base_dir = init_path.get_abs_base_dir()
        self._root_connection_option = args.root_connection_option
        self._num_prop_steps = args.gnn_num_prop_steps
        self._init_method = weight_init_methods
        self._gnn_node_option = args.gnn_node_option
        self._gnn_output_option = args.gnn_output_option
        self._gnn_embedding_option = args.gnn_embedding_option

        self.is_rollout_agent = is_rollout_agent

        self._nstep = 1 if self.is_rollout_agent else self._num_prop_steps

        # parse the network shape and do validation check
        self._network_shape = args.network_shape
        self._hidden_dim = args.gnn_node_hidden_dim
        self._input_feat_dim = args.gnn_input_feat_dim
        self._seed = args.seed
        self._npr = np.random.RandomState(args.seed)

        assert self._input_feat_dim == self._hidden_dim

        self._build_model()
Example #10
    def __init__(self, models, args, scope='trainer', environment_cache=None):
        self.args = args
        self._name_scope = scope
        self._network_type = models

        # the base agent
        self._base_path = init_path.get_abs_base_dir()

        # used to save the checkpoint files
        self._npr = np.random.RandomState(args.seed)
        self.data_dict = {}
        self._environments_cache = environment_cache
        self.current_iteration = 0
        self.weights = None
        if environment_cache is None:
            self._environments_cache = []
Example #11
def train(trainer, sampler, worker, dynamics, policy, reward, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, network_type, args)
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    totalsteps = 0
    current_iteration = 0

    while True:
        timer_dict['** Program Total Time **'] = time.time()

        rollout_data = sampler_agent._rollout_with_workers()

        timer_dict['Generate Rollout'] = time.time()

        # step 2: train the weights for dynamics and policy network
        training_info = {}
        trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
            'data': rollout_data['data'],
            'training_info': training_info
        }))
        trainer_tasks.join()
        training_return = trainer_results.get()
        # update the step counter used by the termination check below
        totalsteps = training_return['totalsteps']
        timer_dict['Train Weights'] = time.time()

        # step 4: update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        if totalsteps > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
Example #12
def pretrain(worker_trainer, models, args=None,
        environments_cache=None):
    logger.info('Pretraining starts at {}'.format(
        init_path.get_abs_base_dir()))

    single_threaded_agent = make_single_threaded_agent(
        worker_trainer.trainer, models, args
    )

    if environments_cache is not None:
        single_threaded_agent.set_environments(
            environments_cache
        )
    
    weights, environments = single_threaded_agent.run()

    return weights, environments
Example #13
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 output_size,
                 ob_placeholder=None,
                 trainable=True,
                 build_network_now=True,
                 define_std=True,
                 is_baseline=False,
                 args=None
                 ):
        '''
            @input:
                @ob_placeholder:
                    if this placeholder is not given, we will make one in this
                    class.

                @trainable:
                    If it is set to true, then the policy weights will be
                    trained. It is useful when the class is a subnet which
                    is not trainable

        '''
        self._session = session
        self._name_scope = name_scope

        self._input_size = input_size
        self._output_size = output_size
        self._base_dir = init_path.get_abs_base_dir()
        self._is_baseline = is_baseline

        self._input = ob_placeholder
        self._trainable = trainable

        self._define_std = define_std

        self._task_name = args.task_name
        self._network_shape = args.network_shape
        self._npr = np.random.RandomState(args.seed)
        self.args = args

        if build_network_now:
            with tf.get_default_graph().as_default():
                tf.set_random_seed(args.seed)
                self._build_model()
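
The docstring above notes that @ob_placeholder is optional: the caller can hand in an existing observation placeholder (useful when this network is a non-trainable subnet sharing its input with another network), or pass None and let the class build one internally. A minimal TF1-style sketch of the two call patterns; policy_network, session and args come from the surrounding repository, so the constructor calls are shown commented out as assumptions rather than a documented API:

import tensorflow as tf  # TF1 graph API, as used in the snippet above

input_size, output_size = 10, 3  # illustrative sizes only

# option 1: supply an externally created placeholder shared by several subnets
external_ob = tf.placeholder(tf.float32, shape=[None, input_size], name='ob_input')
# net_shared = policy_network(session, 'policy_shared', input_size, output_size,
#                             ob_placeholder=external_ob, trainable=False, args=args)

# option 2: pass ob_placeholder=None and let the class create its own placeholder
# net_own = policy_network(session, 'policy_own', input_size, output_size,
#                          ob_placeholder=None, trainable=True, args=args)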
Example #14
    def __init__(self, args, worker_proto, network_proto):
        self.args = args
        self._npr = np.random.RandomState(args.seed + 23333)
        self._observation_size, self._action_size, \
            self._action_distribution = \
            env_register.io_information(self.args.task)
        self._worker_type = worker_proto
        self._network_type = network_proto

        # init the multiprocess actors
        self._task_queue = multiprocessing.JoinableQueue()
        self._result_queue = multiprocessing.Queue()
        self._init_workers()
        self._build_env()
        self._base_path = init_path.get_abs_base_dir()

        self._current_iteration = 0
Example #15
    def __init__(self, args, rand_seed, monitor, width=480, height=480):
        self.width = width
        self.height = height
        task_name = dm_control_util.get_env_names(args.task)
        from dm_control import suite
        self.env = suite.load(domain_name=task_name[0],
                              task_name=task_name[1],
                              task_kwargs={'random': rand_seed})
        self._base_path = init_path.get_abs_base_dir()
        self.NUM_EPISODE_RECORED = NUM_EPISODE_RECORED
        self._is_dirname = True

        # save the video
        self._monitor = monitor
        self._current_episode = 0
        if self._monitor:
            self.init_save(args)
Example #16
    def __init__(self, models, args, scope='trainer'):
        self.args = args
        self._name_scope = scope
        self._network_type = models

        # the base agent
        self._base_path = init_path.get_abs_base_dir()

        # used to save the checkpoint files
        self.timesteps_so_far = 0
        self._npr = np.random.RandomState(args.seed)
        self.env = None
        self._current_env_idx = 0
        self._is_done = 0
        self._reset_flag = 0
        self._episodes = 0
        self.data_dict = {}
        self._environments_cache = []
Example #17
def main():
    parser = base_config.get_base_config()
    parser = ecco_config.get_ecco_config(parser)
    args = base_config.make_parser(parser)

    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='ecco_ecco' + args.task,
                                time_str=args.exp_id)

    print('Training starts at {}'.format(init_path.get_abs_base_dir()))
    from trainer import ecco_trainer
    from runners import task_sampler
    from runners.workers import base_worker
    from policy import ecco_pretrain

    train(ecco_trainer.trainer, task_sampler, base_worker, ecco_pretrain.model,
          args)
Example #18
    def __init__(self,
                 args,
                 network_type,
                 task_queue,
                 result_queue,
                 name_scope='trainer'):
        multiprocessing.Process.__init__(self)
        self.args = args
        self._name_scope = name_scope

        # the base agent
        self._base_path = init_path.get_abs_base_dir()

        # used to save the checkpoint files
        self._iteration = 0
        self._best_reward = -np.inf
        self._timesteps_so_far = 0
        self._npr = np.random.RandomState(args.seed)
        self._task_queue = task_queue
        self._result_queue = result_queue
        self._network_type = network_type
Example #19
    def __init__(self, load_path, elimination_rate, maximum_num_species, args):
        self.num_total_species = 0
        self.current_generation = 1
        self.elimination_rate = elimination_rate
        self.maximum_num_species = maximum_num_species
        self.args = args

        self.species = []
        self.base_path = os.path.join(
            init_path.get_abs_base_dir(), 'evolution_data',
            self.args.task + '_' + self.args.time_id
        )
        self.video_base_path = os.path.join(self.base_path, 'species_video/')
        if not os.path.exists(self.base_path):
            os.makedirs(self.base_path)
            os.makedirs(os.path.join(self.base_path, 'species_topology'))
            os.makedirs(os.path.join(self.base_path, 'species_data'))
            os.makedirs(self.video_base_path)

        if load_path is not None:
            self.load(load_path)
        self.gene_tree = {}
Example #20
def get_candidates(args):
    # base_path for the base, candidate_list for the topology data
    # case one: plot all species, or one of the species
    #   XX/species_topology
    # case two: plot the top ranked_species
    #   XX/species_data
    # case three: plot the top ranked_species's video
    #   XX/species_video

    if args.file_name.endswith('.npy'):
        candidate_list = [args.file_name]
    else:
        candidate_list = glob.glob(os.path.join(args.file_name, '*.npy'))
        candidate_list = [
            i_candidate for i_candidate in candidate_list
            if 'rank_info' not in i_candidate
        ]

    if 'species_topology' in args.file_name:
        species_topology_list = candidate_list

    elif 'species_data' in args.file_name:
        species_topology_list = candidate_list
    else:
        assert 'species_video' in args.file_name
        species_topology_list = [
            os.path.join(
                os.path.dirname(i_candidate).replace('species_video',
                                                     'species_topology'),
                os.path.basename(i_candidate).split('_')[1] + '.npy')
            for i_candidate in candidate_list
        ]

    task = os.path.abspath(candidate_list[0]).split(
        init_path.get_abs_base_dir())[1].split('/')[2].split('_')[0]
    task = task.replace('/', '')

    return candidate_list, species_topology_list, task
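
A hedged usage sketch for get_candidates covering the third case from the comments above (a species_video directory). The argparse field and the directory layout are assumptions inferred from the path handling in the function, not a documented interface:

from argparse import Namespace

# hypothetical path under <base_dir>/evolution_data/<task>_<time_id>/
args = Namespace(file_name='evolution_data/fish_2020_01_01/species_video')
candidate_list, species_topology_list, task = get_candidates(args)
# candidate_list:         the *.npy files found under species_video ('rank_info' excluded)
# species_topology_list:  the matching *.npy files under species_topology
# task:                   parsed from the path, e.g. 'fish'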
Example #21
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 placeholder_list,
                 weight_init_methods='orthogonal',
                 ob_placeholder=None,
                 trainable=True,
                 build_network_now=True,
                 args=None):

        root_connection_option = args.root_connection_option
        root_connection_option = root_connection_option.replace('Rn', 'Ra')
        root_connection_option = root_connection_option.replace('Rb', 'Ra')
        assert 'Rb' in root_connection_option or \
            'Ra' in root_connection_option, \
            logger.error(
                'Root connection option {} invalid for baseline'.format(
                    root_connection_option
                )
            )
        self._base_dir = init_path.get_abs_base_dir()

        GGNN.__init__(self,
                      session=session,
                      name_scope=name_scope,
                      input_size=input_size,
                      output_size=1,
                      weight_init_methods=weight_init_methods,
                      ob_placeholder=ob_placeholder,
                      trainable=trainable,
                      build_network_now=build_network_now,
                      is_baseline=True,
                      placeholder_list=placeholder_list,
                      args=args)

        self._build_train_placeholders()
        self._build_loss()
Example #22
    def __init__(self,
                 args,
                 session,
                 name_scope,
                 initial_node_info,
                 seed=1234,
                 bayesian_op=False):
        '''
            @input: the same as the ones defined in "policy_network"
        '''
        self.bayesian_op = bayesian_op
        self.args = args
        self._num_prop_steps = self.args.gnn_num_prop_steps
        self._initial_node_info = self._node_info = initial_node_info
        self._seed = seed
        self._input_feat_dim = self._hidden_dim = self.args.gnn_node_hidden_dim
        self._network_shape = self.args.network_shape
        self._name_scope = name_scope

        self._init_method = 'orthogonal'
        self._node_update_method = 'GRU'

        self._build_model()
        self._base_path = init_path.get_abs_base_dir()
Example #23
    def __init__(self, env_name, rand_seed, maximum_length, misc_info):
        super(env, self).__init__(env_name, rand_seed, maximum_length,
                                  misc_info)
        self._base_path = init_path.get_abs_base_dir()
Example #24
#!/usr/bin/env python2
# -----------------------------------------------------------------------------
#   @brief:
#       In this function, we change the data from [batch_size, ob_dim] into
#       [batch_size * num_node, hidden_dim]
#   @author:
#       Tingwu Wang, Jul. 13th, 2017
# -----------------------------------------------------------------------------

import numpy as np
import init_path
from util import logger
from six.moves import xrange

_ABS_BASE_PATH = init_path.get_abs_base_dir()


def construct_graph_input_feeddict(node_info,
                                   obs_n,
                                   receive_idx,
                                   send_idx,
                                   node_type_idx,
                                   inverse_node_type_idx,
                                   output_type_idx,
                                   inverse_output_type_idx,
                                   last_batch_size,
                                   request_data=['ob', 'idx']):
    '''
        @brief:
            @obs_n: the observation in the [batch_size * node, hidden_size]
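
The header comment describes the core transformation: per-node features are pulled out of the flat observation of shape [batch_size, ob_dim] and stacked along the batch axis, giving [batch_size * num_node, feature_dim] so every graph node becomes its own row. A minimal numpy sketch of that reshaping idea (the real construct_graph_input_feeddict additionally builds the receive/send/node-type index feed dicts, which are omitted here; the sizes are illustrative assumptions):

import numpy as np

batch_size, num_node, feat_dim = 4, 7, 6                 # illustrative sizes only
obs = np.random.randn(batch_size, num_node * feat_dim)   # [batch_size, ob_dim]

# group the flat observation by node, then stack the nodes along the batch axis
graph_obs = obs.reshape(batch_size, num_node, feat_dim)
graph_obs = graph_obs.reshape(batch_size * num_node, feat_dim)

print(graph_obs.shape)  # (28, 6), i.e. [batch_size * num_node, feature_dim]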
Example #25
import argparse
import init_path
import os
import num2words
import centipede_generator
import snake_generator
import reacher_generator

TASK_DICT = {
    'Centipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14] + [20, 30, 40, 50],
    'CpCentipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14],
    'Reacher': [0, 1, 2, 3, 4, 5, 6, 7],
    'Snake': range(3, 10) + [10, 20, 40],
}
OUTPUT_BASE_DIR = os.path.join(init_path.get_abs_base_dir(), 'environments',
                               'assets')


def save_xml_files(model_names, xml_number, xml_contents):
    # get the xml path ready
    number_str = num2words.num2words(xml_number)
    xml_names = model_names + number_str[0].upper() + number_str[1:] + '.xml'
    xml_file_path = os.path.join(OUTPUT_BASE_DIR, xml_names)

    # save the xml file
    with open(xml_file_path, 'w') as f:
        f.write(xml_contents)
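
A short usage sketch for save_xml_files as defined above: the filename is the model name plus the capitalized num2words spelling of the number. The xml_contents string below is a placeholder assumption; real contents come from the generator modules imported at the top:

# hypothetical call; '<mujoco/>' stands in for real generator output
save_xml_files('Centipede', 3, '<mujoco/>')
# -> writes OUTPUT_BASE_DIR/CentipedeThree.xml

# the naming rule in isolation:
import num2words
number_str = num2words.num2words(3)  # 'three'
print('Centipede' + number_str[0].upper() + number_str[1:] + '.xml')  # CentipedeThree.xml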

Example #26
# -----------------------------------------------------------------------------
#   @brief:
#       util function for assigning running_mean and embedding for the agent
#       Note these functions should be called by the evolutionary agent!
#       written by Tingwu Wang
# -----------------------------------------------------------------------------

import init_path
import numpy as np
from env import model_gen
from env import model_perturb
from lxml import etree
from copy import deepcopy
from graph_util import graph_data_util

PATH = init_path.get_abs_base_dir()
WEIGHT_NAMES = ['policy_logstd']  # 'policy_0/w:0', 'policy_output'
INHERIT_LIST = [
    'policy_weights', 'baseline_weights', 'running_mean_info', 'var_name',
    'node_info', 'observation_size', 'action_size', 'lr'
]
GENERATED_LIST = \
    ['xml_str', 'adj_matrix', 'node_attr', 'debug_info', 'PrtID',
     'SpcID', 'node_info']
INVALID_LIST = [
    'stats',
    'start_time',
    'rollout_time',
    'agent_id',
    'env_name',
    'rank_info',
Example #27
import argparse
import init_path
import os
import num2words
import environments.centipede_generator
import environments.snake_generator
import environments.reacher_generator

TASK_DICT = {
    'Centipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14] + [20, 30, 40, 50],
    'CpCentipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14],
    'Reacher': [0, 1, 2, 3, 4, 5, 6, 7],
    # 'Snake': range(3, 10) + [10, 20, 40],
}
OUTPUT_BASE_DIR = os.path.join(init_path.get_abs_base_dir(),
                               'environments', 'assets')


def save_xml_files(model_names, xml_number, xml_contents):
    # get the xml path ready
    number_str = num2words.num2words(xml_number)
    xml_names = model_names + number_str[0].upper() + number_str[1:] + '.xml'
    xml_file_path = os.path.join(OUTPUT_BASE_DIR, xml_names)

    # save the xml file
    with open(xml_file_path, 'w') as f:
        f.write(xml_contents)

Example #28
def train(trainer, sampler, worker, models,
          args=None, pretrain_dict=None,
          environments_cache=None):

    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    
    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, models, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, models, args)

    if pretrain_dict is not None:
        pretrain_weights, environments_cache = \
            pretrain_dict['pretrain_fnc'](
                pretrain_dict['pretrain_thread'], models, args,
                environments_cache
            )

    else:
        pretrain_weights = environments_cache = None

    if pretrain_weights is not None:
        # sanity check: the pretrained weights should differ from the randomly
        # initialized weights
        for key in pretrain_weights['base']:
            try:
                assert not np.array_equal(pretrain_weights['base'][key],
                                          init_weights['base'][key])
            except AssertionError:
                print(key, pretrain_weights['base'][key],
                      init_weights['base'][key])

    init_weights = init_weights \
        if pretrain_weights is None else pretrain_weights

    trainer_tasks.put(
        (parallel_util.TRAINER_SET_WEIGHTS,
         init_weights)
    )
    trainer_tasks.join()

    sampler_agent.set_weights(init_weights)
    if environments_cache is not None:
        sampler_agent.set_environments(environments_cache)

        trainer_tasks.put(
            (parallel_util.TRAINER_SET_ENVIRONMENTS,
            environments_cache)
        )

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
        timer_dict['** Program Total Time **'] = time.time()

        training_info = {}
        rollout_info = {}

        training_info['train_model'] = 'final'
        rollout_info['rollout_model'] = 'final'
            
        if args.freeze_actor_final:
            training_info['train_net'] = 'manager'

        elif args.decoupled_managers:
            if (current_iteration % \
                (args.manager_updates + args.actor_updates)) \
                < args.manager_updates:
                training_info['train_net'] = 'manager'

            else:
                training_info['train_net'] = 'actor'

        else:
            training_info['train_net'] = None
            
        rollout_data = \
            sampler_agent._rollout_with_workers(rollout_info)

        timer_dict['Generate Rollout'] = time.time()

        trainer_tasks.put(
            (parallel_util.TRAIN_SIGNAL,
             {'data': rollout_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # step 4: update the weights
        weights = training_return['network_weights']
        for key in weights['base']:
            assert np.array_equal(weights['base'][key],
                init_weights['base'][key])
        sampler_agent.set_weights(weights)
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        if training_return['totalsteps'] > args.max_timesteps:
            trainer_tasks.put(
                (parallel_util.SAVE_SIGNAL, {'net': 'final'})
            )

        #if totalsteps > args.max_timesteps:
        if training_return['totalsteps'] > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
Example #29
from evolution import species_bank as sbank
from evolution import agent_bank as abank
from evolution import evo_analysis as analysis
from agent import pruning_agent
import multiprocessing

os.environ['DISABLE_MUJOCO_RENDERING'] = '1'  # disable headless rendering
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['MUJOCO_GL'] = 'osmesa'

if __name__ == '__main__':
    # read the configs and set the logging path
    args = get_config(evolution=True)
    args.use_nervenet = 1

    base_path = init_path.get_abs_base_dir()
    if args.write_log:
        logger.set_file_handler(path=args.output_dir,
                                prefix='mujoco_' + args.task,
                                time_str=args.time_id)

    if args.viz:
        visdom_util.visdom_initialize(args)
        if not args.mute_info:
            analysis.print_info(args)

    agent_bank = abank.agent_bank(args.num_threads + 1, args)
    species_bank = sbank.species_bank(args.speciesbank_path,
                                      args.elimination_rate,
                                      args.maximum_num_species, args)
Example #30
def train(trainer, sampler, worker, network_type, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, network_type, args)
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data
        rollout_data = \
            sampler_agent._rollout_with_workers()

        timer_dict['Generate Rollout'] = time.time()

        # step 2: train the weights for dynamics and policy network
        training_info = {}

        if args.pretrain_vae and current_iteration < args.pretrain_iterations:
            training_info['train_net'] = 'vae'

        elif args.decoupled_managers:
            if (current_iteration % \
                (args.manager_updates + args.actor_updates)) \
                < args.manager_updates:
                training_info['train_net'] = 'manager'

            else:
                training_info['train_net'] = 'actor'

        trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
            'data': rollout_data['data'],
            'training_info': training_info
        }))
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # step 4: update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        #if totalsteps > args.max_timesteps:
        if training_return['totalsteps'] > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))