Beispiel #1
0
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 ob_placeholder=None,
                 trainable=True,
                 args=None):
        '''
            @brief:
                Set up a baseline network. The parent constructor is asked
                to build a single-output model right away, after which the
                training placeholders and the loss are built here.

            @input:
                @session, @name_scope, @input_size, @ob_placeholder,
                @trainable:
                    forwarded unchanged to the parent constructor.

                @args:
                    must not be None even though that is the default:
                    "args.ppo_clip" is read here and the whole object is
                    forwarded to the parent constructor.
        '''
        # attributes assigned before the parent constructor runs
        self._base_dir = init_path.get_abs_base_dir()
        self._use_ppo = True
        self._ppo_clip = args.ppo_clip
        self._output_size = 1

        # the baseline always uses one output, defines no std, and lets the
        # parent build the network immediately
        parent_kwargs = dict(
            session=session,
            name_scope=name_scope,
            input_size=input_size,
            output_size=self._output_size,
            ob_placeholder=ob_placeholder,
            trainable=trainable,
            build_network_now=True,
            define_std=False,
            is_baseline=True,
            args=args,
        )
        super(tf_baseline_network, self).__init__(**parent_kwargs)

        # training-side pieces: placeholders first, then the loss
        self._build_train_placeholders()
        self._build_loss()
Beispiel #2
0
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 output_size,
                 weight_init_methods='orthogonal',
                 ob_placeholder=None,
                 trainable=True,
                 build_network_now=True,
                 is_baseline=False,
                 placeholder_list=None,
                 args=None
                 ):
        '''
            @brief:
                Initialise the parent "policy_network" without building it,
                copy the graph-network options from @args onto the
                instance, and optionally build the model.

            @input: the same as the ones defined in "policy_network"

                @weight_init_methods:
                    stored as "self._init_method"; it is not forwarded to
                    the parent constructor.

                @is_baseline:
                    stored as "self._is_baseline" AFTER the parent
                    constructor, which always receives "is_baseline=False".

                @placeholder_list:
                    stored as "self._placeholder_list".

                @build_network_now:
                    when true, "_build_model" is called at the end; if
                    "args.shared_network" is also true, the baseline
                    placeholders and the baseline loss are built as well.

                @args:
                    must not be None even though that is the default: its
                    attributes are read on the first line of the body.
        '''
        # these two are assigned before the parent constructor is invoked
        self._shared_network = args.shared_network
        self._node_update_method = args.node_update_method

        # the parent never builds the network here (build_network_now=False)
        # and is always told is_baseline=False; the caller's values are
        # applied further below
        policy_network.__init__(
            self,
            session,
            name_scope,
            input_size,
            output_size,
            ob_placeholder=ob_placeholder,
            trainable=trainable,
            build_network_now=False,
            define_std=True,
            is_baseline=False,
            args=args
        )

        self._base_dir = init_path.get_abs_base_dir()
        self._root_connection_option = args.root_connection_option
        self._num_prop_steps = args.gnn_num_prop_steps
        self._init_method = weight_init_methods
        self._gnn_node_option = args.gnn_node_option
        self._gnn_output_option = args.gnn_output_option
        self._gnn_embedding_option = args.gnn_embedding_option
        # replaces whatever the parent constructor stored for is_baseline
        self._is_baseline = is_baseline
        self._placeholder_list = placeholder_list

        # parse the network shape and do validation check
        self._network_shape = args.network_shape
        self._hidden_dim = args.gnn_node_hidden_dim
        self._input_feat_dim = args.gnn_input_feat_dim
        self._input_obs = ob_placeholder
        self._seed = args.seed
        self._npr = np.random.RandomState(args.seed)

        # the input feature size and the node hidden size must be equal
        assert self._input_feat_dim == self._hidden_dim
        logger.info('Network shape is {}'.format(self._network_shape))

        if build_network_now:
            self._build_model()
            if self._shared_network:
                # build the baseline loss and placeholder
                self._build_baseline_train_placeholders()
                self._build_baseline_loss()
Beispiel #3
0
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 output_size,
                 ob_placeholder=None,
                 trainable=True,
                 build_network_now=True,
                 define_std=True,
                 is_baseline=False,
                 args=None):
        '''
            @brief:
                Record the constructor arguments on the instance and, when
                requested, build the model immediately.

            @input:
                @ob_placeholder:
                    if this placeholder is not given, we will make one in
                    this class.

                @trainable:
                    If it is set to true, then the policy weights will be
                    trained. It is useful when the class is a subnet which
                    is not trainable

                @build_network_now:
                    when true, "tf.set_random_seed(args.seed)" and
                    "self._build_model()" are run under the default graph
                    at the end of the constructor.

                @args:
                    must not be None even though that is the default:
                    "task_name", "network_shape" and "seed" are read from
                    it.
        '''
        # session and variable scope
        self._session, self._name_scope = session, name_scope

        # network dimensions and role
        self._input_size, self._output_size = input_size, output_size
        self._base_dir = init_path.get_abs_base_dir()
        self._is_baseline = is_baseline

        # observation input and training switches
        self._input, self._trainable = ob_placeholder, trainable
        self._define_std = define_std

        # settings taken from the parsed arguments
        self._task_name = args.task_name
        self._network_shape = args.network_shape
        self._npr = np.random.RandomState(args.seed)
        self.args = args

        if not build_network_now:
            return

        # seed tensorflow before any part of the model is created
        default_graph = tf.get_default_graph()
        with default_graph.as_default():
            tf.set_random_seed(args.seed)
            self._build_model()
Beispiel #4
0
    def __init__(self,
                 session,
                 name_scope,
                 input_size,
                 placeholder_list,
                 weight_init_methods='orthogonal',
                 ob_placeholder=None,
                 trainable=True,
                 build_network_now=True,
                 args=None):
        '''
            @brief:
                Baseline variant of the GGNN. The parent is initialised
                with a single output and "is_baseline=True"; afterwards the
                training placeholders and the loss are built.

            @input:
                @session, @name_scope, @input_size, @placeholder_list,
                @weight_init_methods, @ob_placeholder, @trainable,
                @build_network_now:
                    forwarded unchanged to "GGNN.__init__".

                @args:
                    must not be None even though that is the default:
                    "args.root_connection_option" is read here.
        '''
        # 'Rn' and 'Rb' are both mapped to 'Ra' in a local copy of the
        # option; "args" itself is left untouched and the copy is only used
        # by the check below
        normalized_option = (
            args.root_connection_option
            .replace('Rn', 'Ra')
            .replace('Rb', 'Ra')
        )
        assert any(tag in normalized_option for tag in ('Rb', 'Ra')), \
            logger.error(
                'Root connection option {} invalid for baseline'.format(
                    normalized_option
                )
            )
        self._base_dir = init_path.get_abs_base_dir()

        # one output, flagged as the baseline
        GGNN.__init__(
            self,
            session=session,
            name_scope=name_scope,
            input_size=input_size,
            output_size=1,
            weight_init_methods=weight_init_methods,
            ob_placeholder=ob_placeholder,
            trainable=trainable,
            build_network_now=build_network_now,
            is_baseline=True,
            placeholder_list=placeholder_list,
            args=args,
        )

        # training-side pieces: placeholders first, then the loss
        self._build_train_placeholders()
        self._build_loss()
import argparse
import os
import num2words
from environments import centipede_generator
from environments import snake_generator
from environments import reacher_generator
from tool import init_path

# Maps a task family name to the list of integer variants of that family.
# NOTE(review): presumably each integer is handed to the matching generator
# and to save_xml_files as "xml_number" -- confirm against the caller.
TASK_DICT = {
    'Centipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14] + [20, 30, 40, 50],
    'CpCentipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14],
    'Reacher': [0, 1, 2, 3, 4, 5, 6, 7],
    'Snake': list(range(3, 10)) + [10, 20, 40],
}
# Directory the xml files are written into:
# <absolute base dir>/environments/assets
OUTPUT_BASE_DIR = os.path.join(init_path.get_abs_base_dir(), 'environments',
                               'assets')


def save_xml_files(model_names, xml_number, xml_contents):
    '''
        @brief:
            Write @xml_contents to the file
            "<OUTPUT_BASE_DIR>/<model_names><Number>.xml", where <Number>
            is @xml_number spelled out by num2words with its first
            character upper-cased (e.g. 'three' becomes 'Three').

        @input:
            @model_names: string prefix of the file name
            @xml_number: the number that is spelled out in the file name
            @xml_contents: the text written to the file; the file is
                opened in 'w' mode, so an existing file is overwritten
    '''
    # get the xml path ready
    number_str = num2words.num2words(xml_number)
    xml_names = model_names + number_str[0].upper() + number_str[1:] + '.xml'
    xml_file_path = os.path.join(OUTPUT_BASE_DIR, xml_names)

    # save the xml file; the context manager closes the handle even when
    # the write raises, which the bare open()/close() pair did not
    with open(xml_file_path, 'w') as xml_file:
        xml_file.write(xml_contents)

Beispiel #6
0
#!/usr/bin/env python2
# -----------------------------------------------------------------------------
#   @brief:
#       In this function, we change the data from [batch_size, ob_dim] into
#       [batch_size * num_node, hidden_dim]
#   @author:
#       Tingwu Wang, Jul. 13th, 2017
# -----------------------------------------------------------------------------

import numpy as np
from tool import init_path
from util import logger
from six.moves import xrange

# Absolute base directory as reported by init_path.get_abs_base_dir();
# evaluated once, when this module is imported.
_ABS_BASE_PATH = init_path.get_abs_base_dir()


def construct_graph_input_feeddict(node_info,
                                   obs_n,
                                   receive_idx,
                                   send_idx,
                                   node_type_idx,
                                   inverse_node_type_idx,
                                   output_type_idx,
                                   inverse_output_type_idx,
                                   last_batch_size,
                                   request_data=['ob', 'idx']):
    '''
        @brief:
            @obs_n: the observation in the [batch_size * node, hidden_size]
import argparse
import tool.init_path as init_path
import os
import num2words
import environments.centipede_generator as centipede_generator
import environments.snake_generator as snake_generator
import environments.reacher_generator as reacher_generator

# Maps a task family name to the list of integer variants of that family.
# NOTE(review): presumably each integer is handed to the matching generator
# and to save_xml_files as "xml_number" -- confirm against the caller.
TASK_DICT = {
    'Centipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14] + [20, 30, 40, 50],
    'CpCentipede': [3, 5, 7] + [4, 6, 8, 10, 12, 14],
    'Reacher': [0, 1, 2, 3, 4, 5, 6, 7],
    'Snake': list(range(3, 10)) + [10, 20, 40],
}
# Directory the xml files are written into:
# <absolute base dir>/environments/assets
OUTPUT_BASE_DIR = os.path.join(init_path.get_abs_base_dir(),
                               'environments', 'assets')


def save_xml_files(model_names, xml_number, xml_contents):
    '''
        @brief:
            Write @xml_contents to the file
            "<OUTPUT_BASE_DIR>/<model_names><Number>.xml", where <Number>
            is @xml_number spelled out by num2words with its first
            character upper-cased (e.g. 'three' becomes 'Three').

        @input:
            @model_names: string prefix of the file name
            @xml_number: the number that is spelled out in the file name
            @xml_contents: the text written to the file; the file is
                opened in 'w' mode, so an existing file is overwritten
    '''
    # get the xml path ready
    number_str = num2words.num2words(xml_number)
    xml_names = model_names + number_str[0].upper() + number_str[1:] + '.xml'
    xml_file_path = os.path.join(OUTPUT_BASE_DIR, xml_names)

    # save the xml file; the context manager closes the handle even when
    # the write raises, which the bare open()/close() pair did not
    with open(xml_file_path, 'w') as xml_file:
        xml_file.write(xml_contents)