Пример #1
0
    def reconnect(self):
        """Reconnects to the server and reloads configuration,
        modules and self.msg_parser.

        Sends QUIT, shuts down the modules of the current parser, closes the
        socket, rebuilds all per-connection state and connects again. The
        state includes the module and channel controllers, which hold
        references to the parser and the channel list.
        """
        # NOTE(review): send_data only queues into self.obuffer while
        # self.registered is false, so in that state this QUIT is not sent
        # before close() below.
        self.send_data("QUIT :Reconnecting")
        # Mirror disconnect(): shut down the old parser's modules before the
        # parser is replaced. connect_to_irc() calls modules.startup on the
        # new one.
        modules.shutdown(self.msg_parser)
        self.close()

        self.config_reader = GeneralConfig()
        self.configuration = self.config_reader.read()
        self.msg_parser = parser.parser()
        self.registered = False
        self.first_msg = True
        self.channel_list = []
        # The controllers were built around the *old* parser and channel list.
        # Rebuild them the same way __init__ does so they see the new objects.
        self.module_controller = bot_module_functions.module_controller(self.msg_parser)
        self.ircchannel_controller = bot_module_functions.ircchannel_controller(
            self.channel_list, self.configuration, self.config_reader)
        self.ircchannel_controller.populate_ircchannels()

        self.connect_to_irc()
Пример #2
0
    def __init__(self, env, name):
        """Create one Actor and one Critic per agent in a shared TF session.

        Args:
            env: multi-agent environment. ``env.n`` actor/critic pairs are
                built, and ``env.action_space[i].n`` is recorded for each
                agent (assumes discrete action spaces -- TODO confirm).
            name: variable-scope name for this group. It is also passed to
                every Actor and Critic.
        """
        # == Initialize ==
        self.name = name
        self.sess = tf.Session()
        self.record = None
        self.env = env

        self.actor = []         # one Actor per agent
        self.critic = []        # one Critic per agent
        self.actions_dims = []  # per-agent action sizes (action split for gradient apply)

        config = GeneralConfig()
        self.update_every = config.update_every
        self.train_freq = config.train_freq

        n_agents = self.env.n

        # == Construct networks: each agent gets its own sub-scope ==
        with tf.variable_scope(self.name):
            for agent_id in range(n_agents):
                agent_scope = name + "_{}".format(agent_id)
                with tf.variable_scope(agent_scope):
                    self.actor.append(Actor(env, self.sess, name, agent_id, config))
                    self.critic.append(Critic(env, self.sess, name, agent_id, config))
                    self.actions_dims.append(self.env.action_space[agent_id].n)

            # == Reward summaries: all placeholders first, then the scalar ops ==
            self.reward = [
                tf.placeholder(tf.float32, shape=None, name="Agent_{}_reward".format(idx))
                for idx in range(n_agents)
            ]
            self.reward_op = [
                tf.summary.scalar('Agent_{}_R_Mean'.format(idx), reward_ph)
                for idx, reward_ph in enumerate(self.reward)
            ]

        self.sess.run(tf.global_variables_initializer())
        self.train_writer = tf.summary.FileWriter(
            config.log_dir, graph=tf.get_default_graph())

        # Call update() once on every network, actor before critic for each
        # agent. Presumably this syncs target networks -- TODO confirm.
        for actor, critic in zip(self.actor, self.critic):
            actor.update()
            critic.update()
Пример #3
0
    def __init__(self):
        """Create the bot state, refusing to run with elevated privileges.

        Raises:
            RuntimeError: if getuid() reports uid 0. If getuid() raises
                AttributeError, raised instead when IsUserAnAdmin() is nonzero.
        """
        try:
            is_root = getuid() == 0
        except AttributeError:
            # getuid() is not usable here (presumably a non-POSIX platform),
            # so fall back to the administrator check.
            if IsUserAnAdmin() != 0:
                raise RuntimeError('Running as administrator which is not safe nor required!')
        else:
            if is_root:
                raise RuntimeError('Running as root which is not safe nor required!')

        # Core state. The controllers below are built from the parser, the
        # channel list and the configuration created here, so order matters.
        reader = GeneralConfig()
        self.config_reader = reader
        self.configuration = reader.read()
        self.msg_parser = parser.parser()
        self.registered = False
        self.first_msg = True
        self.channel_list = []
        self.module_controller = bot_module_functions.module_controller(self.msg_parser)
        self.ircchannel_controller = bot_module_functions.ircchannel_controller(
            self.channel_list, self.configuration, self.config_reader)

        # Start-up: setup_logger reads self.configuration['verbose'].
        self.setup_logger()
        self.ircchannel_controller.populate_ircchannels()
Пример #4
0
 def setUp(self):
     """Give each test its own default-constructed GeneralConfig."""
     fresh_config = GeneralConfig()
     self.config = fresh_config
Пример #5
0
class TestGeneralConfig(TestBase):
    """
    A group of unittests to ensure correct functionality of GeneralConfig
    """
    def setUp(self):
        self.config = GeneralConfig()

    def test___init__(self):
        # Keyword arguments are applied on top of ``defaults`` and win on
        # conflicting keys.
        defaults = {
            'test': '1',
            'override_me': 'wrong_value',
        }
        arg_dict = {
            'arg_dict': 'some_value',
            'override_me': 'expected_value',
        }
        general_config = GeneralConfig(defaults=defaults, **arg_dict)

        self.assertEqual(general_config.test, '1')
        self.assertEqual(general_config.arg_dict, 'some_value')
        self.assertEqual(general_config.override_me, 'expected_value')

    def test___getattr__(self):
        # assertEqual, not assertTrue: assertTrue(x, 1) treats 1 as the
        # failure message and passes for any truthy value.
        self.config.test = 1
        self.assertEqual(self.config.test, 1)

    def test___setattr__(self):
        self.config.test = 1
        self.assertEqual(self.config.test, 1)

    def test___setattr___Not_Set_Sentinel(self):
        # The general config class should refuse
        # to set attributes to an instance of NotSetSentinel
        self.config.test = NotSetSentinel()
        self.assertIsNone(self.config.test)

    def test_update(self):
        # The function should perform similarly to a dict's update method.
        # (assertEqual: the assertEquals alias was removed in Python 3.12.)
        test_dict = {
            'cool_attribute': 'cool_value',
            'not_set': NotSetSentinel(),
        }
        self.config.update(**test_dict)
        self.assertEqual(self.config.cool_attribute, 'cool_value')
        self.assertIsNone(self.config.not_set)

    def test___getitem__(self):
        self.config.test = 1
        self.assertEqual(self.config['test'], 1)

    def test___setitem__(self):
        self.config['test'] = 1
        self.assertEqual(self.config['test'], 1)

    def test___delitem__(self):
        self.config['test'] = 1
        self.assertEqual(self.config['test'], 1)
        del self.config['test']
        with self.assertRaises(KeyError):
            self.config['test']

    def test___contains__(self):
        self.config['test'] = 1
        self.assertIn('test', self.config)
        self.assertNotIn('not_here', self.config)
Пример #6
0
# Root directory of the project.
# NOTE(review): os.path.abspath resolves "../../" against the current working
# directory, not this file's location, so the result depends on where the
# script is launched from.
ROOT_DIR = os.path.abspath("../../")
FROM_DATASET = False  # not read in this section; presumably a flag used elsewhere -- TODO confirm
# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
# (the path entry above must exist before the mrcnn imports below run)
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log

from bullet import BulletConfig, BulletDataset

from config import GeneralConfig

# Read the path settings once from the project config and expose them as
# module-level constants.
gconf = GeneralConfig()
DATASET_DIR, MODEL_DIR, OLD_MODEL_PATH, COCO_MODEL_PATH, WEIGHTS_FILE_PATH, EXP_DIR  \
        = gconf.DATASET_DIR, gconf.MODEL_DIR, gconf.OLD_MODEL_PATH, gconf.COCO_MODEL_PATH, gconf.WEIGHTS_FILE_PATH, gconf.EXP_DIR

# NOTE(review): ipdb is a third-party debugging helper, and this import fails
# when it is not installed. Consider removing it if set_trace is never called.
from ipdb import set_trace

# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

# Device to load the neural network on.
# Useful if you're training a model on the same
# machine, in which case use CPU and leave the
# GPU for training.
# DEVICE = "/cpu:0"  # /cpu:0 or /gpu:0
# Ask CUDA to enumerate GPUs in PCI bus order instead of its default
# fastest-first order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
Пример #7
0
        '-l',
        '--load',
        type=int,
        help="Indicates the step you wanna start, file must exist")

    args = parser.parse_args()
    start = 0

    scenario = scenarios.load(args.scenario).Scenario()
    world = scenario.make_world()

    env = MultiAgentEnv(world, scenario.reset_world, scenario.reward,
                        scenario.observation)
    policies = MultiAgent(env, "Multi-Agent")

    REPLAY_BUFFER = ReplayBuffer(env, GeneralConfig())

    if args.load is not None:
        start = args.load

        if args.dir is None:
            print("[!] Please indicate a path for model storing")
            exit(1)

        policies.load(args.dir, start)

    for i in range(start, start + args.n_round):
        play(i, env, policies)

        if (i + 1) % args.every == 0:
            policies.save(args.dir, i)
Пример #8
0
        # === Emulator ===
        for i in range(_config.STEP):
            action = agent.pick_action(obs)
            obs_next, reward, done, _ = env.step(action)

            # agent will store the newest experience into replay buffer, and training with mini-batch and off-policy
            agent.perceive(obs, action, reward, done)

            if done:
                break

            obs = obs_next

        # == train ==
        agent.train(episode)

        if (episode + 1) % agent.save_every == 0:
            agent.save(step=episode)

        # == test ==
        print("\n[*] === Enter TEST module ===")
        test(env, _config.STEP, agent)

    agent.record()


if __name__ == "__main__":
    # Script entry point: build the default configuration and pass it to main().
    config = GeneralConfig()
    main(config)
Пример #9
0
class bot(asynchat.async_chat):

    def __init__(self):
        """Create the bot state, refusing to run with elevated privileges.

        The socket itself is not created here; see connect_to_irc().

        Raises:
            RuntimeError: if getuid() reports uid 0. If getuid() raises
                AttributeError, raised instead when IsUserAnAdmin() is nonzero.
        """
        try:
            is_root = getuid() == 0
        except AttributeError:
            # getuid() is not usable here (presumably a non-POSIX platform),
            # so fall back to the administrator check.
            if IsUserAnAdmin() != 0:
                raise RuntimeError('Running as administrator which is not safe nor required!')
        else:
            if is_root:
                raise RuntimeError('Running as root which is not safe nor required!')

        # Core state. The controllers below are built from the parser, the
        # channel list and the configuration created here, so order matters.
        reader = GeneralConfig()
        self.config_reader = reader
        self.configuration = reader.read()
        self.msg_parser = parser.parser()
        self.registered = False
        self.first_msg = True
        self.channel_list = []
        self.module_controller = bot_module_functions.module_controller(self.msg_parser)
        self.ircchannel_controller = bot_module_functions.ircchannel_controller(
            self.channel_list, self.configuration, self.config_reader)

        # Start-up: setup_logger reads self.configuration['verbose'].
        self.setup_logger()
        self.ircchannel_controller.populate_ircchannels()

###
# ASYNCHAT / ASYNCORE OVERRIDES
###

    def collect_incoming_data(self, data):
        """Called when data is received from the server.
        Modifies the input buffer with the received data.

        """

        self.ibuffer = self.ibuffer + data

    def found_terminator(self):
        """Called when the terminator symbol is found.
        Sends input buffer to the self.msg_parser and calls send_data if
        a response is given by the self.msg_parser.

        """

        logging.debug('Received data: ' + self.ibuffer)
        if not self.registered and self.first_msg:
            self.register()

        elif self.ibuffer.startswith("PING"):
            #Received a ping, respond instantly
            self.push("PONG :" + self.ibuffer.split(':')[1] + "\r\n")
            logging.info('Responded to PING: ' +
                         self.ibuffer.split(':')[1])

        else:
            #Received something else
            response = self.msg_parser.input_line(self.ibuffer)
            for tuple_ in response:
                func = None
                if not tuple_ is None:
                    if hasattr(self.module_controller, tuple_[0]):
                        func = getattr(self.module_controller, tuple_[0])
                    elif hasattr(self.ircchannel_controller, tuple_[0]):
                        func = getattr(self.ircchannel_controller, tuple_[0])
                    elif hasattr(self, tuple_[0]):
                        func = getattr(self, tuple_[0])
                    else:
                        logging.warn('No attribute %s found' % tuple_[0])
                        continue

                if func and tuple_[1] == '':
                    func()
                elif func:
                    func(tuple_[1])
                else:
                    logging.info('Response from modules was: %r' % tuple_)

        self.ibuffer = ''

    def handle_close(self):
        """Called when the socket is closed.

        Calls ``modules.shutdown`` for the current parser, closes the socket
        and terminates the process by raising SystemExit.
        """
        import sys  # local import keeps this fix self-contained

        logging.info('Saving channels')
        modules.shutdown(self.msg_parser)
        self.close()
        # sys.exit() rather than the site-provided exit(), which is meant for
        # the interactive interpreter and is missing under `python -S`.
        sys.exit()

###
# CORE METHODS
###

    def connect_to_irc(self):
        """Connects to the server set in the configuration file.

        Re-runs ``async_chat.__init__``, which is what lets reconnect() reuse
        this method. It then resets both buffers, opens a TCP socket to
        ``configuration['server']`` / ``configuration['port']`` and finally
        calls ``modules.startup`` for the current parser.
        """
        asynchat.async_chat.__init__(self)
        logging.debug('Asynchat initialized')
        self.set_terminator("\r\n")
        # Fresh buffers: incoming text accumulates in ibuffer. Outgoing lines
        # wait in obuffer while the bot is not yet registered.
        self.ibuffer, self.obuffer = '', deque()

        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        logging.debug('Socket initialized')
        conf = self.configuration
        self.connect((conf['server'], conf['port']))
        logging.info('Connected to: ' + conf['server'])
        modules.startup(self.msg_parser)

    def setup_logger(self):
        """Sets up our logger instance, varying if we want verbose
        or logfiles.

        """
        logging.basicConfig(
            level=logging.DEBUG,
            format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
            datefmt='%Y-%m-%d : %H:%M',
            filename='./framework.log',
            filemode='w')
        console = logging.StreamHandler()
        formatter = logging.Formatter(
            '%(name)-12s: %(levelname)-8s %(message)s')
        console.setFormatter(formatter)

        if self.configuration['verbose']:
            console.setLevel(logging.INFO)

        else:
            console.setLevel(logging.WARN)

        logging.getLogger('').addHandler(console)

    def register(self):
        """Registers the bot to the server"""

        logging.info('Registering ...')
        self.set_nick(self.configuration['nick'])
        self.push("USER " + self.configuration['user'] +
                  " 0 * :" + self.configuration['real'] + '\r\n')
        self.first_msg = False
        for chan in self.channel_list:
            self.send_data("JOIN :" + chan)

    def set_registered(self, value):
        """Sets bot to registered.
        Normally called from Exceptions.py

        Keyword arguments:
            value -- boolean; self.registered value

        """

        self.registered = value
        self.empty_queue()

    def set_nick(self, nick):
        """Sends a 'NICK' message to the server

        Keyword arguments:
            nick -- string; our new nick

        """

        logging.info('Setting nick to: ' + nick)
        self.config_reader.write('irc', 'nickname', nick)
        self.push("NICK " + nick + '\r\n')

    def send_data(self, data):
        """Pushes the data upstream

        Keyword arguments:
            data -- data to send

        """

        if not self.registered:
            self.obuffer.append(data + '\r\n')
            logging.info('Queued for sending: ' + data)
        else:
            self.push(data + '\r\n')
            logging.info("Sent to the server: " + data)

    def empty_queue(self):
        """Pushes all the data saved in obuffer to the server
        in chronological order.

        """

        if self.registered:
            while len(self.obuffer) != 0:
                _data = self.obuffer.popleft()
                self.push(_data)
                logging.info('Popped from queue: ' + _data.strip('\r\n'))
        else:
            logging.warn('Tried emptying queue whilst not registered')

    def reconnect(self):
        """Reconnects to the server and reloads configuration,
        modules and self.msg_parser.

        Sends QUIT, shuts down the modules of the current parser, closes the
        socket, rebuilds all per-connection state and connects again. The
        state includes the module and channel controllers, which hold
        references to the parser and the channel list.
        """
        # NOTE(review): send_data only queues into self.obuffer while
        # self.registered is false, so in that state this QUIT is not sent
        # before close() below.
        self.send_data("QUIT :Reconnecting")
        # Mirror disconnect(): shut down the old parser's modules before the
        # parser is replaced. connect_to_irc() calls modules.startup on the
        # new one.
        modules.shutdown(self.msg_parser)
        self.close()

        self.config_reader = GeneralConfig()
        self.configuration = self.config_reader.read()
        self.msg_parser = parser.parser()
        self.registered = False
        self.first_msg = True
        self.channel_list = []
        # The controllers were built around the *old* parser and channel list.
        # Rebuild them the same way __init__ does so they see the new objects.
        self.module_controller = bot_module_functions.module_controller(self.msg_parser)
        self.ircchannel_controller = bot_module_functions.ircchannel_controller(
            self.channel_list, self.configuration, self.config_reader)
        self.ircchannel_controller.populate_ircchannels()

        self.connect_to_irc()

    def disconnect(self):
        """Disconnects from the IRC server.

        Sends QUIT through send_data, calls ``modules.shutdown`` for the
        current parser and closes the socket.
        """

        # NOTE(review): send_data only queues into self.obuffer while
        # self.registered is false, so in that state this QUIT is not sent
        # before close() below.
        self.send_data("QUIT :Disconnecting")
        modules.shutdown(self.msg_parser)
        self.close()