Example #1
 def __init__(self, agent, env=None):
     self.lock = th.Lock()
     self.agent = agent
     # one and only
     self.para_list = self.get_parameter_list()
     self.farmer = farmer_class(self.para_list)
     self.ep_num = 0
     self.total_steps = 0
     self.history_reward = []
     self.ep_value = {}
     self.value_init()
     self.relative_time = 0
     self.average_len_of_episode = self.agent.args.max_pathlength
     self.num_rollouts = int(self.agent.args.timesteps_per_batch /
                             self.average_len_of_episode)
     self.rollout_count = 0
     self.rollout_paths = []
     self.iteration = 0
     self.log_scalar_name_list = [
         'reward', 'kl_div', 'entropy', 'surrogate_loss', 'value_loss'
     ]
     self.log_scalar_type_list = [
         tf.float32, tf.float32, tf.float32, tf.float32, tf.float32
     ]
     self.logger = Logger(self.agent.session,
                          self.agent.args.log_path + 'train',
                          self.log_scalar_name_list,
                          self.log_scalar_type_list)
     self.write_log = self.logger.create_scalar_log_method()
     self.start_time = time.time()
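
Example #1's Logger takes a TF session, a log path, and parallel lists of scalar names and dtypes; this is project-specific code. A minimal sketch of such a TensorFlow 1.x scalar logger, assuming only the interface visible above (Logger(...) plus create_scalar_log_method()):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # placeholder/FileWriter are TF1-style APIs


class Logger:
    """Hypothetical reconstruction; not the project's actual Logger."""

    def __init__(self, session, log_dir, scalar_names, scalar_types):
        self.session = session
        self.placeholders = [
            tf.placeholder(dtype, shape=(), name=name)
            for name, dtype in zip(scalar_names, scalar_types)
        ]
        summaries = [
            tf.summary.scalar(name, ph)
            for name, ph in zip(scalar_names, self.placeholders)
        ]
        self.merged = tf.summary.merge(summaries)
        self.writer = tf.summary.FileWriter(log_dir, session.graph)

    def create_scalar_log_method(self):
        def write_log(values, step):
            # values must line up with scalar_names, as in write_log([...], iteration)
            feed = dict(zip(self.placeholders, values))
            self.writer.add_summary(self.session.run(self.merged, feed), step)
            self.writer.flush()
        return write_log
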
Example #2
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

        # For tst
        if args.tst_free:
            self.trlog['max_tst_criterion'] = 0.0
            self.trlog['max_tst_criterion_interval'] = 0.
            self.trlog['max_tst_criterion_epoch'] = 0
            self.trlog['tst_criterion'] = args.tst_criterion
Example #3
    def __init__(self, args):
        if args.dataset == 'CUB':
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20)]
        else:
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20), (5, 50)]
        if args.eval_dataset == 'CUB':
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20)]
        else:
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20), (5, 50)]
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0
Example #4
 def __init__(self, name: str, conf: Dict):
     self.plc_name = name
     self.conf = conf
     self.modbus_port = conf['modbus_port']
     self.worker_processes = {}
     self.setup_complete = False
     self.logger = Logger("PLCLogger", "../logger/logs/plc_log.txt", prefix="[{}]".format(self.plc_name))
     self.clock = PLCClock()
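
The Logger(name, path, prefix=...) used in this and the following PLC examples is likewise project code. A plausible minimal stand-in on top of the standard logging module, assuming those semantics:

import logging


def make_logger(name, log_file, prefix=''):
    # Hypothetical helper mirroring Logger("PLCLogger", path, prefix="[plc_1]")
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:  # avoid stacking handlers on repeated calls
        handler = logging.FileHandler(log_file)
        fmt = '%(asctime)s {} %(levelname)s: %(message)s'.format(prefix)
        handler.setFormatter(logging.Formatter(fmt))
        logger.addHandler(handler)
    return logger
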
Example #5
 def __init__(self, config_path, send_port=5000):
     self.controller_ps = []
     self.config = self.read_config(config_path)
     self.selector = selectors.DefaultSelector()
     self.publish_queue = PublishQueue()
     self.udp_send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     self.udp_send_socket.bind(('', send_port))
     self.logger = Logger('InterfaceLogger',
                          '../logger/logs/interface_log.txt')
Example #6
def init_extract():
    global trainer
    global logger
    cfg = ConfigParser()
    configuration_path = Path(__file__).resolve(
        strict=True).parent / 'configs' / 'extract_eval.conf'
    cfg.read(configuration_path)
    logger = Logger(cfg)
    logger.info(f'Configuration parsed: {cfg.sections()}')
    trainer = SpanTrainer(cfg, logger)
Example #7
    def __init__(self, remote='127.0.0.1', port=5050, player: Player = None):
        self.remote = remote
        self.port = port
        self.address = (remote, port)
        self.player = player
        self.app_logger = L.create_logger('CLIENT_MAIN',
                                          logging.INFO,
                                          to_file=True)
        self.communication_logger = L.create_logger('Hangman client',
                                                    logging.INFO)

        self.app_logger.info(f'Starting client... src={SRC}/__init__:41')
        self.communication_logger.info('Welcome to the hangman client!')
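
L.create_logger(name, level, to_file=...) is again a project helper. A sketch under the assumption that it logs to a file when to_file is set and to the console otherwise:

import logging


def create_logger(name, level, to_file=False):
    # Hypothetical version of the helper used in Examples #7 and #8
    logger = logging.getLogger(name)
    logger.setLevel(level)
    handler = logging.FileHandler(name + '.log') if to_file else logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s %(name)s: %(message)s'))
    logger.addHandler(handler)
    return logger
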
Example #8
 def __init__(self, host='', port=5050, max_connections=5):
     """
     Initialise the server.
     :param host: the host's IPv4 address to bind the server to.
     :param port: the port to bind the server to. Default = 5050.
     :param max_connections: the maximum number of connections the server will queue before dropping the next.
     """
     self.logger_file = L.create_logger('SERVER_MAIN', logging.INFO, to_file=True)
     self.logger_info = L.create_logger('SERVER_MAIN_COMM', logging.INFO, to_file=False)
     self.host = host
     self.port = port
     self.address = (host, port)
     self.groups = []
     self.events = []
     self.max_connections = max_connections
Example #9
    def __init__(self, args):
        self.args = args
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_auc'] = 0.0
        self.trlog['max_auc_epoch'] = 0
        self.trlog['max_auc_interval'] = 0.0
Example #10
 def __init__(self, conf: Dict):
     self.sensor_bus = SensorBus(conf)
     self.worker_info = {}
     self.logger = Logger("ActorLogs",
                          "../model/logger/logs/actors_logs.txt")
     self._init_worker_info(conf)
     self.control_graph = ControlGraph()
Example #11
 def __init__(self, server_socket: socket, events: list):
     super().__init__()
     self.socket = server_socket
     self.events = events
     self.logger = Logger.create_logger('SERVER_INPUT',
                                        logging.INFO,
                                        to_file=True)
Example #12
 def __init__(self,
              port,
              localhost=True,
              device_function_codes=None,
              socket_type=socket.SOCK_STREAM,
              failures={}):
     self.port = port
     self.localhost = localhost
     self.logger = Logger('ServerLogger-{}'.format(port),
                          '../logger/logs/server_log.txt',
                          prefix='Server {}'.format(port))
     self.stop = threading.Event()
     self.done = threading.Event()
     self.device_function_codes = device_function_codes
     self._current_connection = None
     self.socket_type = socket_type
     self.failures = failures
     self.lock = threading.RLock()
Example #13
    def __init__(self, args):
        self.args = args
        ensure_path(
            self.args.save_path,
            scripts_to_save=['model/models', 'model/networks', __file__],
        )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0
Example #14
 def __init__(self, attr: Dict, pipe: int):
     self.pipe = pipe
     self.lock = RLock()
     self.attributes = attr
     self.started = False
     self.modbus_receiver = None
     self.modbus_thread = None
     self.logger = Logger('WorkerLogger-{}'.format(attr.get('port', 0)), '../logger/logs/worker_log.txt',
                          prefix='Worker Server {}'.format(attr.get('port', 0)))
     self.previous_readings = collections.deque(maxlen=1000)
     self.num_readings = 0
Example #15
 def __init__(self, agent, env=None):
     self.lock = th.Lock()
     self.agent = agent
     # one and only
     self.farmer = farmer_class(self.agent.para_list)
     self.ep_num = 0
     self.total_steps = 0
     self.history_reward = []
     self.ep_value = {}
     self.value_init()
     self.relative_time = 0
     self.average_steps = self.agent.para_list["max_pathlength"]
     self.log_scalar_name_list = [
         'mean_reward', 'actor_loss', 'critic_loss'
     ]
     self.log_scalar_type_list = [tf.float32, tf.float32, tf.float32]
     self.logger = Logger(self.agent.session,
                          self.agent.para_list["log_path"] + 'train',
                          self.log_scalar_name_list,
                          self.log_scalar_type_list)
     self.write_log = self.logger.create_scalar_log_method()
     self.start_time = time.time()
Example #16
    def __init__(self, assigned_groups: list, lock: Lock, add_event_to_server):
        """
        Initialise the group manager.
        :param assigned_groups: the groups this manager will manage.
        :param lock: a thread lock to ensure the Game thread does not access its group at the same time as this manager.
        :param add_event_to_server: a reference to the corresponding server's add_terminal_event method.
        """
        super().__init__()
        self.terminal_event = Event()
        self.groups = assigned_groups
        self.logger = Logger.create_logger(self.__repr__(),
                                           logging.INFO,
                                           to_file=True)
        self.games = []
        self.lock = lock

        add_event_to_server(self.terminal_event)
Example #17
 def __init__(self, player: Player, lock: Lock, groups: list):
     super().__init__()
     self.player = player
     self.logger = Logger.create_logger(self.__repr__(), logging.INFO, to_file=True)
     self.lock = lock
     self.groups = groups
Example #18
class SimulinkInterface:
    def __init__(self, config_path, send_port=5000):
        self.controller_ps = []
        self.config = self.read_config(config_path)
        self.selector = selectors.DefaultSelector()
        self.publish_queue = PublishQueue()
        self.udp_send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.udp_send_socket.bind(('', send_port))
        self.logger = Logger('InterfaceLogger',
                             '../logger/logs/interface_log.txt')
        # Used for sim time oracle which is currently disabled
        # self.time_oracle = None

    def read_config(self, config_path):
        """
            Read and parse the .yaml configuration file.
        :param config_path: the path to the specific yaml file
        :return:
        """
        with open(config_path, 'r') as stream:
            try:
                config_yaml = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                exit(1)
        # if not config_yaml['settings']:
        #     print("Failed to read config... no settings in file")
        #     exit(1)
        # if not config_yaml['settings']['send_port']:
        #     print("Failed to read config... no send_port in settings")
        #     exit(1)
        return config_yaml
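
    # For illustration only: create_plcs() below expects the YAML to map PLC
    # names to their configurations; key names here are inferred from this
    # file, not taken from a real config:
    #
    #   plc_1:
    #     modbus_port: 5020
    #     workers:
    #       tank_level:
    #         port: 6001
    #         register: 0
    #         respond_to: {host: 127.0.0.1, port: 7001}
    #   time_oracle:
    #     receive_port: 8001
    #     respond_port: 8002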

    def create_plcs(self):
        """
            Creates PLC(s) based on the .yaml configuration file.
            Each PLC spawns several "worker" processes which listen for incoming simulink data
                through a predefined publish / subscribe mechanism.
            The register_workers function creates the server sockets for each worker, and registers them to
            the main selector. These workers also subscribe to receive the data coming to their "listening port"
            by registering themselves with the publish queue.
        :return:
        """
        for plc_name, plc_config in self.config.items():
            if plc_name == 'time_oracle':
                continue
            controller = LogicController(plc_name, plc_config)
            controller.register_workers(self.selector, self.publish_queue)
            controller_ps = multiprocessing.Process(
                target=controller.start_plc, daemon=True, name=plc_name)
            self.controller_ps.append(controller_ps)

    # def init_time_oracle(self):
    #     """
    #     Read simtime oracle source for more information, tries to resolve the difference between real
    #         and simulated time using a pseudo-NTP approach. You can work on this if you'd like to centralize the timer
    #         to the interface instead of using a virtual PLC (oracle PLC) to manage simulation timestamps.
    #     :return:
    #     """
    #     timer_conf = self.config['time_oracle']
    #     time_oracle = simtimeoracle.SimulationTimeOracle(timer_conf['receive_port'], timer_conf['respond_port'])
    #     return time_oracle, multiprocessing.Process(target=time_oracle.start, name='Time Oracle', daemon=True)

    # def _accept_connection(self, sock: socket.socket):
    #     """
    #         !!!! NO LONGER USED IN UDP Implementation !!!!
    #         Upon receiving a new connection from Simulink, register this connection and
    #         its respective data with our selector.
    #     :param sock:
    #     :return:
    #     """
    #     conn, addr = sock.accept()
    #     print("New connection from {}".format(addr))
    #     conn.setblocking(False)
    #     self.selector.register(conn, selectors.EVENT_READ, {"connection_type": "client_connection",
    #                                                         "channel": addr[1]})

    def _read_and_publish(self, connection: socket, channel: str):
        """
            Reads data from Simulink.
            |----------------------|
            |--- 64 bit timestamp -|
            |--- 64 bit reading ---|
            |----------------------|
        :param connection: the connection from a simulink block
        :param channel: the channel to publish this data to on the publish queue
        """
        data = connection.recv(16)  # Should be ready
        if data:
            sim_time, reading = struct.unpack(">dd", data)
            sim_time = int(sim_time)
            self.publish_queue.publish((sim_time, reading), channel)
        else:
            print('closing', connection)
            self.selector.unregister(connection)
            connection.close()
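
    # Sender side, for illustration: a Simulink UDP block (or a test script)
    # would emit the 16-byte payload unpacked above as two big-endian doubles:
    #
    #   import socket, struct
    #   payload = struct.pack(">dd", 12.0, 101.325)  # (sim_time, reading)
    #   socket.socket(socket.AF_INET, socket.SOCK_DGRAM).sendto(
    #       payload, ('127.0.0.1', 6001))            # a worker's listening port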

    def _send_response(self, read_pipe, host: str, port: int):
        """
            Reads from the worker pipe and forwards the data to the respective simulink block
            based on the host and port specified.
        :param read_pipe: a pipe connecting the worker thread to the main simulink selector
        :param host: ip / hostname to send data
        :param port: port number that the host is listening on
        :return:
        """
        response_data = os.read(read_pipe, 128)
        self.logger.info("Sending response {} to {}:{}".format(
            binascii.hexlify(response_data), host, port))
        self.udp_send_socket.sendto(response_data, (host, port))

    def service_connection(self, key):
        """
            Based on the connection type stored in key.data['connection_type'], take the appropriate action.
            For server_sockets read the data and publish to the queue.
            For responses read the appropriate data from the response pipe and forward to simulink.
        :param key: The key associated with the file object registered in the selector
        :return:
        """
        connection = key.fileobj
        connection_type = key.data['connection_type']
        if connection_type == 'server_socket':
            channel = key.data['channel']
            self._read_and_publish(connection, channel)
        if connection_type == 'response':
            read_pipe = key.fileobj
            # The address to respond to should be registered along with the pipe object
            host, port = key.data['respond_to']
            self._send_response(read_pipe, host, port)

    def start_server(self):
        """
            Set up the virtual PLC(s) and their respective worker processes / threads.
            Initialize the time oracle.
            Once setup, start the PLC(s) to begin listening for data.
            Then start the selector loop, waiting for new data and servicing incoming responses.
        :return:
        """
        # Time oracle stuff is now managed in PLC sensors
        # self.time_oracle, time_oracle_ps = self.init_time_oracle()
        # time_oracle_ps.start()

        self.create_plcs()
        for plc in self.controller_ps:
            self.logger.info('Starting controller: {}'.format(plc))
            plc.start()

        while True:
            events = self.selector.select()
            for key, mask in events:
                self.service_connection(key)
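
A minimal entry point for this class might look as follows; the config path is illustrative:

if __name__ == '__main__':
    interface = SimulinkInterface('configs/plcs.yaml', send_port=5000)
    interface.start_server()  # blocks in the selector loop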
Example #19
 def __init__(self, conf: Dict):
     self.sensor_bus = SensorBus(conf)
     self.logger = Logger("ActorLogs",
                          "../model/logger/logs/actors_logs.txt")
Example #20
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        if args.dataset == 'CUB':
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20)]
        else:
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20), (5, 50)]
        if args.eval_dataset == 'CUB':
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20)]
        else:
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20), (5, 50)]
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            if args.eval_all:
                for i, (args.eval_way, args.eval_shot) in enumerate(self.VAL_SETTING):
                    if i == 0:
                        vl, va, vap = self.eval_process(args, epoch)
                    else:
                        self.eval_process(args, epoch)
            else:
                vl, va, vap = self.eval_process(args, epoch)
            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')
            print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'],
                self.trlog['max_acc'],
                self.trlog['max_acc_interval']))

    def eval_process(self, args, epoch):
        valset = self.valset
        if args.model_class in ['QsimProtoNet', 'QsimMatchNet']:
            val_sampler = NegativeSampler(args, valset.label,
                                          args.num_eval_episodes,
                                          args.eval_way, args.eval_shot + args.eval_query)
        else:
            val_sampler = CategoriesSampler(valset.label,
                                            args.num_eval_episodes,
                                            args.eval_way, args.eval_shot + args.eval_query)
        val_loader = DataLoader(dataset=valset,
                                batch_sampler=val_sampler,
                                num_workers=args.num_workers,
                                pin_memory=True)
        vl, va, vap = self.evaluate(val_loader)
        self.logger.add_scalar('%dw%ds_val_loss' % (args.eval_way, args.eval_shot), float(vl),
                               self.train_epoch)
        self.logger.add_scalar('%dw%ds_val_acc' % (args.eval_way, args.eval_shot), float(va),
                               self.train_epoch)
        print('epoch {},{} way {} shot, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(epoch, args.eval_way,
                                                                                   args.eval_shot, vl, va,
                                                                                   vap))
        return vl, va, vap

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch,
                          self.train_step,
                          self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, ' \
                  'forward_timer: {:.2f} sec,' \
                  'backward_timer: {:.2f} sec, ' \
                  'optim_timer: {:.2f} sec'.format(
                self.dt.item(), self.ft.item(),
                self.bt.item(), self.ot.item())
            )
            self.logger.dump()

    def save_model(self, name):
        torch.save(
            dict(params=self.model.state_dict()),
            osp.join(self.args.save_path, name + '.pth')
        )

    def __str__(self):
        return "{}({})".format(
            self.__class__.__name__,
            self.model.__class__.__name__
        )
Example #21
from view.view import View
from controller.controller import Controller
from model.algorithm import Algorithm
from model.state import State
from model.logger import Logger
from model.alert import Alert

if __name__ == "__main__":

    app = QtWidgets.QApplication(sys.argv)
    MainWindow = QtWidgets.QMainWindow()

    algorithm = Algorithm()
    state = State()
    logger = Logger()
    alert = Alert()

    controller = Controller()
    controller.connectToModels(algorithm=algorithm,
                               state=state,
                               logger=logger,
                               alert=alert)

    ui = View(MainWindow)
    ui.connectToController(controller)
    ui.subscribeToModels(algorithm, state, logger, alert)

    # controller = SearchSortUIController(ui)

    MainWindow.show()
Example #22
class LogicController:
    def __init__(self, name: str, conf: Dict):
        self.plc_name = name
        self.conf = conf
        self.modbus_port = conf['modbus_port']
        self.worker_processes = {}
        self.setup_complete = False
        self.logger = Logger("PLCLogger",
                             "../logger/logs/plc_log.txt",
                             prefix="[{}]".format(self.plc_name))
        self.clock = PLCClock()
        self.register_map = {}

    def __str__(self):
        return "{}:\n{}".format(self.plc_name, self.conf)

    def start_plc(self, modbus_port=None):
        if self.setup_complete:
            self.start_workers()
            self.start_modbus_server(modbus_port)
        else:
            self.logger.warning(
                "PLC has not been initialized, rejecting start up")

    def register_workers(self, selector, publish_queue):
        workers_conf = self.conf['workers']
        for worker_name, attr in workers_conf.items():
            # Invoke the factory to create a new worker
            attr['name'] = worker_name
            worker, response_pipe_r = WorkerFactory.create_new_worker(attr)
            if worker is None:
                continue
            # Add the clock to the workers attributes
            attr['clock'] = self.clock
            # If this worker intends to respond to Simulink,
            # link up its pipe to the main selector
            if response_pipe_r:
                respond_to = (attr['respond_to']['host'],
                              attr['respond_to']['port'])
                selector.register(response_pipe_r, selectors.EVENT_READ, {
                    "connection_type": "response",
                    "respond_to": respond_to
                })
            # If this worker listens for Simulink data, it should provide a port
            # A server socket will be set up for this port in the main selector
            # Data destined to this port will be parsed, packaged, and then sent to listening worker processes
            # using the publish_queue
            port = 0
            if attr.get('port', None):
                port = attr['port']
                serverfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                serverfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                serverfd.bind(('', attr['port']))
                selector.register(serverfd, selectors.EVENT_READ, {
                    "connection_type": "server_socket",
                    "channel": attr['port']
                })
                # Unsure whether creating a thread or starting it attaches it to the parent process
                # If there are performance issues in the Simulink interface you can investigate this
            channel = attr.get('channel', port)
            p = threading.Thread(target=worker.run,
                                 args=(publish_queue.register(channel), ))
            self.worker_processes[worker_name] = {
                "process": p,
                "attributes": attr,
                "worker": worker,
            }
            self.register_map[int(attr['register'])] = worker
            self.logger.info("Setting up worker '{}'".format(worker_name))
        self.setup_complete = True

    def start_workers(self):
        for worker_name, info in self.worker_processes.items():
            self.logger.info("Starting up worker '{}'".format(worker_name))
            info['process'].start()

    def start_modbus_server(self, port=None):
        if port is None:
            port = self.conf['modbus_port']

        ENDIANNESS = 'BIG'

        def handle_request(request):
            request_header = request['header']
            request_body = request['body']
            self.logger.debug(
                "Servicing modbus request {}".format(request_header))
            start_register = request_body['address']
            if request_header[
                    'function_code'] == FunctionCodes.WRITE_SINGLE_HOLDING_REGISTER:
                setting = request_body['value']
                worker = self.register_map.get(start_register, None)
                if worker:
                    if hasattr(worker, 'set_reading'):
                        worker.set_reading(setting)
                        self.logger.info(
                            "Setting new pressure reading to {} at {}".format(
                                setting, worker.attributes['name']))
                return modbusencoder.respond_write_registers(
                    request_header, 0, 1, endianness=ENDIANNESS)
            else:
                readings = []
                register_count = request_body['count']
                for current_reg in range(start_register, register_count, 2):
                    worker = self.register_map.get(current_reg, None)
                    if worker:
                        self.logger.info('Retrieving data from {}'.format(
                            worker.attributes['name']))
                        readings.append((worker.get_reading(), 'FLOAT32'))
                self.logger.info(
                    "Responding to request with {}".format(readings))
                return modbusencoder.respond_read_registers(
                    request_header, readings, endianness=ENDIANNESS)

        DEVICE_FUNCTION_CODES = [3, 4, 6, 16]
        modbus_receiver = ModbusReceiver(
            port,
            device_function_codes=DEVICE_FUNCTION_CODES,
            socket_type=socket.SOCK_DGRAM)
        self.logger.info("Starting modbus server for PLC on {}".format(
            self.modbus_port))
        modbus_receiver.start_server(handle_request)
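
To poke the Modbus/UDP server above by hand, a read request can be assembled from the 7-byte MBAP header (transaction id, protocol id, length, unit id) plus a function-code-3 body; the port is illustrative:

import socket
import struct

# MBAP header: transaction=1, protocol=0, length=6 (unit id + 5 body bytes), unit=1
request = struct.pack('>HHHB', 1, 0, 6, 1)
# Body: function code 3 (read holding registers), start register 0, count 2
request += struct.pack('>BHH', 3, 0, 2)

sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.sendto(request, ('127.0.0.1', 5020))
print(sock.recvfrom(256))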
Example #23
class parallel_rollouts():
    def __init__(self, agent, env=None):
        self.lock = th.Lock()
        self.agent = agent
        # one and only
        self.para_list = self.get_parameter_list()
        self.farmer = farmer_class(self.para_list)
        self.ep_num = 0
        self.total_steps = 0
        self.history_reward = []
        self.ep_value = {}
        self.value_init()
        self.relative_time = 0
        self.average_len_of_episode = self.agent.args.max_pathlength
        self.num_rollouts = int(self.agent.args.timesteps_per_batch /
                                self.average_len_of_episode)
        self.rollout_count = 0
        self.rollout_paths = []
        self.iteration = 0
        self.log_scalar_name_list = [
            'reward', 'kl_div', 'entropy', 'surrogate_loss', 'value_loss'
        ]
        self.log_scalar_type_list = [
            tf.float32, tf.float32, tf.float32, tf.float32, tf.float32
        ]
        self.logger = Logger(self.agent.session,
                             self.agent.args.log_path + 'train',
                             self.log_scalar_name_list,
                             self.log_scalar_type_list)
        self.write_log = self.logger.create_scalar_log_method()
        self.start_time = time.time()

    def get_parameter_list(self):
        para_list = {
            "model": self.agent.args.model,
            "task": self.agent.args.task,
            "policy_layer_norm": self.agent.args.policy_layer_norm,
            "value_layer_norm": self.agent.args.value_layer_norm,
            "policy_act_fn": self.agent.args.policy_act_fn,
            "value_act_fn": self.agent.args.value_act_fn,
            "max_pathlength": self.agent.args.max_pathlength,
            "farmer_port": self.agent.args.farmer_port,
            "farm_list_base": self.agent.args.farm_list_base,
            "farmer_debug_print": self.agent.args.farmer_debug_print,
            "farm_debug_print": self.agent.args.farm_debug_print
        }
        return para_list

    def refarm(self):  # rarely used
        del self.farmer
        self.farmer = farmer_class()

    def value_init(self):
        # self.ep_value['actor_loss'], self.ep_value['critic_loss'], self.ep_value['fit_time'] = 0, 0, 0
        # self.ep_value['sample_time'], self.ep_value['actor_train_time'] = 0, 0
        # self.ep_value['critic_train_time'], self.ep_value['target_update_time'] = 0, 0
        pass

    def fit(self):
        # calculate the real learning rate with a linear schedule
        final_learning_rate = 5e-5
        policy_delta_rate = (self.agent.args.policy_learning_rate -
                             final_learning_rate) / self.agent.args.n_steps
        value_delta_rate = (self.agent.args.value_learning_rate -
                            final_learning_rate) / self.agent.args.n_steps
        policy_learning_rate = self.agent.args.policy_learning_rate - policy_delta_rate * self.total_steps
        value_learning_rate = self.agent.args.value_learning_rate - value_delta_rate * self.total_steps
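
        # Worked example of the schedule above (values are illustrative): with
        # policy_learning_rate=3e-4, final_learning_rate=5e-5 and n_steps=1e6,
        # the rate decays linearly from 3e-4 at step 0 to 5e-5 at step 1e6:
        #   lr(t) = 3e-4 - (3e-4 - 5e-5) / 1e6 * t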

        self.ep_value["learning_rate"] = policy_learning_rate

        # vars: update_old_policy_time, train_time, fit_time, surrogate_loss, kl_after, entropy_after, value_loss
        vars = self.agent.fit(self.rollout_paths,
                              [policy_learning_rate, value_learning_rate])

        # update print info:
        self.ep_value["episode_nums"] = self.num_rollouts
        self.ep_value["episode_steps"] = sum(
            [len(path["rewards"]) for path in self.rollout_paths])
        self.ep_value["episode_time"] = sum(
            [path["ep_time"]
             for path in self.rollout_paths]) / self.num_rollouts
        self.ep_value["episode_reward"] = sum(
            [path["rewards"].sum()
             for path in self.rollout_paths]) / self.num_rollouts
        self.ep_value["average_steps"] = int(self.average_len_of_episode)
        self.ep_value["step_time"] = sum(
            [path["step_time"]
             for path in self.rollout_paths]) / self.num_rollouts
        self.ep_value['update_time'], self.ep_value[
            'train_time'], self.ep_value['iter_time'] = vars[0], vars[1], vars[
                2]
        self.ep_value['surrogate_loss'], self.ep_value['kl_after'] = vars[
            3], vars[4]
        self.ep_value['entropy_after'], self.ep_value['value_loss'] = vars[
            5], vars[6]
        self.relative_time = (time.time() - self.start_time) / 60

        # write_log with tensorboard: ['reward', 'kl_div', 'entropy', 'surrogate_loss', 'value_loss']
        self.write_log([
            self.ep_value['episode_reward'], self.ep_value['kl_after'],
            self.ep_value['entropy_after'], self.ep_value['surrogate_loss'],
            self.ep_value['value_loss']
        ], self.iteration)

        # print iteration information:
        episode_print(self.iteration, self.total_steps, self.relative_time,
                      self.ep_value)
        self.history_reward.append(self.ep_value["episode_reward"])

    def rollout_an_episode(self, env):
        global graph
        with graph.as_default():
            path = self.agent.run_an_episode(env)
        self.lock.acquire()
        self.ep_num += 1
        self.rollout_count += 1
        self.rollout_paths.append(path)
        self.total_steps += len(path["actions"])
        self.lock.release()
        env.rel()

    def rollout_if_available(self):
        while True:
            remote_env = self.farmer.acq_env()  # call for a remote_env
            if remote_env is False:  # no free environment
                # time.sleep(0.1)
                pass
            else:
                t = th.Thread(target=self.rollout_an_episode,
                              args=(remote_env, ),
                              daemon=True)
                t.start()
                break

    def rollout(self):
        while self.total_steps < self.agent.args.n_steps:
            if self.num_rollouts == 0:
                raise RuntimeError('num_rollouts is zero; cannot divide')
            self.iteration += 1
            for i in range(self.num_rollouts):
                self.rollout_if_available()

            while self.rollout_count != self.num_rollouts:
                pass

            if (self.iteration + 1) % 10 == 0:
                self.agent.saveModel(0)

            self.average_len_of_episode = sum(
                [len(path["rewards"])
                 for path in self.rollout_paths]) / self.num_rollouts
            if self.average_len_of_episode == 0:
                raise RuntimeError('average episode length is zero; cannot divide')
            self.fit()
            self.rollout_count = 0
            self.rollout_paths = []
            self.num_rollouts = int(self.agent.args.timesteps_per_batch /
                                    self.average_len_of_episode)

        self.agent.saveModel(0)
        return self.history_reward
Example #24
class off_policy_parallel_rollouts():
    def __init__(self, agent, env=None):
        self.lock = th.Lock()
        self.agent = agent
        # one and only
        self.farmer = farmer_class(self.agent.para_list)
        self.ep_num = 0
        self.total_steps = 0
        self.history_reward = []
        self.ep_value = {}
        self.value_init()
        self.relative_time = 0
        self.average_steps = self.agent.para_list["max_pathlength"]
        self.log_scalar_name_list = [
            'mean_reward', 'actor_loss', 'critic_loss'
        ]
        self.log_scalar_type_list = [tf.float32, tf.float32, tf.float32]
        self.logger = Logger(self.agent.session,
                             self.agent.para_list["log_path"] + 'train',
                             self.log_scalar_name_list,
                             self.log_scalar_type_list)
        self.write_log = self.logger.create_scalar_log_method()
        self.start_time = time.time()

    def refarm(self):  # rarely used
        del self.farmer
        self.farmer = farmer_class()

    def value_init(self):
        self.ep_value['actor_loss'], self.ep_value[
            'critic_loss'], self.ep_value['fit_time'] = 0, 0, 0
        self.ep_value['sample_time'], self.ep_value['actor_train_time'] = 0, 0
        self.ep_value['critic_train_time'], self.ep_value[
            'target_update_time'] = 0, 0

    def fit(self):
        vars = [0] * 7
        for i in range(self.average_steps):
            var = self.agent.fit()
            for j, v in enumerate(var):
                vars[j] += v / self.average_steps
        self.ep_value['actor_loss'], self.ep_value[
            'critic_loss'], self.ep_value['fit_time'] = vars[0], vars[1], vars[
                2]
        self.ep_value['sample_time'], self.ep_value['actor_train_time'] = vars[
            3], vars[4]
        self.ep_value['critic_train_time'], self.ep_value[
            'target_update_time'] = vars[5], vars[6]

    def process_path(self, path):
        ep_step = len(path["rewards"])
        ep_time = path["ep_time"]
        step_time = path["step_time"]
        ep_reward = sum(path["rewards"])
        ep_memory = []
        for i in range(ep_step):
            t = [
                path["obs"][i], path["actions"][i], path["rewards"][i],
                path["dones"][i], path["obs"][i + 1]
            ]
            ep_memory.append(t)
        return ep_step, ep_reward, ep_time, step_time, ep_memory

    def rollout_an_episode(self, noise_level,
                           env):  # this function is thread safe
        global graph
        with graph.as_default():
            path = rollout_an_episode(self.agent, env, self.lock, noise_level)
            self.fit()

        ep_step, ep_reward, ep_time, step_time, ep_memory = self.process_path(
            path)
        for t in ep_memory:
            self.agent.feed_one(t)

        self.lock.acquire()
        self.ep_num += 1
        self.total_steps += ep_step  # self.ep_value["episode_steps"]
        self.average_steps = int(self.total_steps / self.ep_num)
        self.ep_value["average_steps"] = self.average_steps
        self.ep_value['episode_steps'], self.ep_value[
            'episode_reward'] = ep_step, ep_reward
        self.ep_value['episode_time'], self.ep_value[
            'step_time'] = ep_time, step_time
        self.history_reward.append(
            ep_reward)  # (self.ep_value["episode_reward"])
        self.relative_time = (time.time() - self.start_time) / 60.0
        off_policy_episode_print(self.ep_num, self.total_steps,
                                 self.relative_time, self.ep_value)
        self.write_log([
            ep_reward, self.ep_value["actor_loss"],
            self.ep_value["critic_loss"]
        ], self.ep_num)
        self.lock.release()
        env.rel()

    def rollout_if_available(self, noise_level):
        while True:
            remote_env = self.farmer.acq_env()  # call for a remote_env
            if remote_env is False:  # no free environment
                # time.sleep(0.1)
                pass
            else:
                t = th.Thread(target=self.rollout_an_episode,
                              args=(noise_level, remote_env),
                              daemon=True)
                t.start()
                break

    def rollout(self):
        for i in range(self.agent.para_list["n_episodes"]):
            nl = noise_level_schedule(self.agent.para_list, self.total_steps)
            self.rollout_if_available(nl)
            if (i + 1) % 100 == 0:
                self.agent.saveModel(0)
            if self.total_steps > self.agent.para_list["n_steps"]:
                break

        self.agent.saveModel(0)
        return self.history_reward
Example #25
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            vl, va, vap = self.evaluate(self.val_loader)
            self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
            self.logger.add_scalar('val_acc', float(va), self.train_epoch)
            print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                epoch, vl, va, vap))

            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print(
                'epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                .format(self.train_epoch, self.train_step, self.max_steps,
                        tl1.item(), tl2.item(), ta.item(),
                        self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(),
                                   self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '     \
                  'forward_timer: {:.2f} sec,'   \
                  'backward_timer: {:.2f} sec, ' \
                  'optim_timer: {:.2f} sec'.format(
                        self.dt.item(), self.ft.item(),
                        self.bt.item(), self.ot.item())
                  )
            self.logger.dump()

    def save_model(self, name):
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)
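
To make the abstract-method contract concrete, here is a skeleton subclass; the bodies are placeholders, but the (loss, accuracy, interval) return shape matches what try_evaluate unpacks:

class MyTrainer(Trainer):
    def train(self):
        for epoch in range(1, self.args.max_epoch + 1):
            self.train_epoch = epoch
            # ... run the training episodes for this epoch, then:
            self.try_evaluate(epoch)

    def evaluate(self, data_loader):
        # must return (val_loss, val_acc, confidence-interval half-width)
        return 0.0, 0.0, 0.0

    def evaluate_test(self, data_loader):
        return self.evaluate(data_loader)

    def final_record(self):
        print(self)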
Example #26
class ModbusReceiver:
    def __init__(self,
                 port,
                 localhost=True,
                 device_function_codes=None,
                 socket_type=socket.SOCK_STREAM,
                 failures={}):
        self.port = port
        self.localhost = localhost
        self.logger = Logger('ServerLogger-{}'.format(port),
                             '../logger/logs/server_log.txt',
                             prefix='Server {}'.format(port))
        self.stop = threading.Event()
        self.done = threading.Event()
        self.device_function_codes = device_function_codes
        self._current_connection = None
        self.socket_type = socket_type
        self.failures = failures
        self.lock = threading.RLock()

    '''
        Dispatches packet data for decoding based on its function code.
        ModbusDecoder handles decoding of the packets and returns a Dict containing
        appropriate data. Invalid function codes lead to an invalid_function_code message, which
        is also created by the modbus decoder.
    '''

    def _dissect_packet(self, packet_data) -> Dict:
        function_code = packet_data[0]
        # Check that the device supports this function code
        if self.device_function_codes:
            if function_code not in self.device_function_codes:
                return modbusdecoder.invalid_function_code(packet_data)
        switch = {
            1: modbusdecoder.read_coils,
            2: modbusdecoder.read_discrete_inputs,
            3: modbusdecoder.read_holding_registers,
            4: modbusdecoder.read_input_registers,
            5: modbusdecoder.write_single_coil,
            6: modbusdecoder.write_single_holding_register,
            15: modbusdecoder.write_multiple_coils,
            16: modbusdecoder.write_multiple_holding_registers
        }
        function = switch.get(function_code,
                              modbusdecoder.invalid_function_code)
        return function(packet_data)

    def _start_server_tcp(self, request_handler: Callable) -> None:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if self.localhost:
                s.bind(('localhost', self.port))
            else:
                s.bind((socket.gethostname(), self.port))
            s.listen(5)
            self.logger.info('Server started {}:{}'.format(
                socket.gethostname(), self.port))
            while not self.stop.is_set():
                if self.failures.get('disconnected', False):
                    sleep(1)
                    continue
                self._current_connection, address = s.accept()
                self.logger.info('New connection accepted {}'.format(
                    self._current_connection.getpeername()))
                with self._current_connection:
                    while not self.stop.is_set():
                        try:
                            buffer = self._current_connection.recv(7)
                            if buffer == b'' or len(buffer) <= 0:
                                self.logger.debug(
                                    'Initial read was empty, peer connection was likely closed'
                                )
                                break
                            header = buffer
                            self.logger.debug(
                                'MB:{} Header DATA like: {}'.format(
                                    self.port, header))
                            # Modbus length is in bytes 4 & 5 of the header according to spec (pg 25)
                            # https://www.prosoft-technology.com/kb/assets/intro_modbustcp.pdf
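                            # MBAP layout (7 bytes): transaction id (2),
                            # protocol id (2), length (2, counts unit id +
                            # PDU), unit id (1); e.g. the header
                            # b'\x00\x01\x00\x00\x00\x06\x01' followed by a
                            # 5-byte PDU carries length 6.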
                            header = modbusdecoder.dissect_header(header)
                            length = header['length']
                            if length == 0:
                                self.logger.debug(
                                    'A length 0 header was read, closing connection'
                                )
                                break
                            data = self._current_connection.recv(length - 1)
                            StatisticsCollector.increment_packets_received()
                            response_start = time.time()
                            is_error, dissection = self._dissect_packet(data)
                            if is_error:
                                self.logger.debug(
                                    'MB:{} Header appears like: {}'.format(
                                        self.port, header))
                                self.logger.debug('MB:{} Request: {}'.format(
                                    self.port, hexlify(buffer + data)))
                                self.logger.debug(
                                    'MB:{} An error was found in the modbus request {}'
                                    .format(self.port, hexlify(dissection)))
                                self._current_connection.sendall(dissection)
                                response_stop = time.time()
                                StatisticsCollector.increment_error_packets_sent(
                                )
                                StatisticsCollector.increment_responses_sent()
                                StatisticsCollector.increment_avg_response(
                                    response_stop - response_start)
                                continue
                            else:
                                dissection['type'] = 'request'
                                header['function_code'] = data[0]
                                response = request_handler({
                                    'header': header,
                                    'body': dissection
                                })
                                self.logger.debug(
                                    'MB:{} Header: {} Body:{}'.format(
                                        self.port, header, dissection))
                                self.logger.debug('MB:{} Request: {}'.format(
                                    self.port, hexlify(buffer + data)))
                                self.logger.debug(
                                    'MB:{} Responding: {}'.format(
                                        self.port, hexlify(response)))
                                # add failures to the receiver
                                if not self.simulate_failures():
                                    continue
                                self._current_connection.sendall(response)
                                response_stop = time.time()
                                StatisticsCollector.increment_responses_sent()
                                StatisticsCollector.increment_avg_response(
                                    response_stop - response_start)
                        except IOError as e:
                            self.logger.warning(
                                'An IO error occurred when reading the socket {}'
                                .format(e))
                            self.logger.debug('Closing connection')
                            self._current_connection.close()
                            StatisticsCollector.increment_socket_errors()
                            break
            self.done.set()

    def _start_server_udp(self, request_handler: Callable) -> None:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if self.localhost:
                s.bind(('localhost', self.port))
                self.logger.info('Starting UDP server at localhost:{}'.format(
                    self.port))
            else:
                s.bind((socket.gethostname(), self.port))
                self.logger.debug('Starting UDP server at {}:{}'.format(
                    socket.gethostname(), self.port))
            while not self.stop.is_set():
                try:
                    if self.failures.get('disconnected', False):
                        sleep(1)
                        continue
                    buffer, address = s.recvfrom(256)
                    self.logger.debug(
                        'Message received from: {}'.format(address))
                    StatisticsCollector.increment_packets_received()
                    response_start = time.time()
                    if buffer == b'' or len(buffer) <= 0:
                        self.logger.debug(
                            'Initial read was empty, peer connection was likely closed'
                        )
                        continue
                    header = buffer[:7]
                    header = modbusdecoder.dissect_header(header)
                    length = header['length']
                    if length == 0:
                        self.logger.debug('Length 0 message received')
                        continue
                    data = buffer[7:7 + length - 1]
                    is_error, dissection = self._dissect_packet(data)
                    if is_error:
                        self.logger.debug(
                            'An error was found in the modbus request {}'.
                            format(dissection))
                        self.logger.debug(
                            'Header appears like: {}'.format(header))
                        self.logger.debug('Buffer: {}'.format(hexlify(buffer)))
                        s.sendto(dissection, address)
                        response_stop = time.time()
                        StatisticsCollector.increment_avg_response(
                            response_stop - response_start)
                        StatisticsCollector.increment_error_packets_sent()
                        StatisticsCollector.increment_responses_sent()
                        continue
                    else:
                        dissection['type'] = 'request'
                        header['function_code'] = data[0]
                        response = request_handler({
                            'header': header,
                            'body': dissection
                        })
                        # add failures to the receiver
                        if not self.simulate_failures():
                            continue
                        s.sendto(response, address)
                        response_stop = time.time()
                        StatisticsCollector.increment_avg_response(
                            response_stop - response_start)
                        StatisticsCollector.increment_responses_sent()
                        self.logger.debug('MB:{} Request: {}'.format(
                            self.port, hexlify(buffer[:7 + length])))
                        self.logger.debug('MB:{} Header: {} Body:{}'.format(
                            self.port, header, dissection))
                        self.logger.debug('MB:{} Responding: {}'.format(
                            self.port, hexlify(response)))
                except IOError as e:
                    self.logger.warning(
                        'An IO error occurred with the socket {}'.format(e))
                    StatisticsCollector.increment_socket_errors()
                    continue
        self.done.set()

    # Return False to not respond
    def simulate_failures(self):
        with self.lock:
            if self.failures.get('stop-responding', False):
                self.logger.info('MB:{} Simulating no-response'.format(
                    self.port))
                return False
            elif self.failures.get('flake-response'):
                val = random.choice([1, 2, 3])
                if val == 1:
                    upper_bound = self.failures['flake-response']
                    sleep_time = random.randint(0, upper_bound) * 0.01
                    self.logger.info(
                        'MB:{} Simulating flake-response "delayed" {}ms'.
                        format(self.port, sleep_time))
                    time.sleep(sleep_time)
                elif val == 2:
                    self.logger.info(
                        'MB:{} Simulating flake-response "no-response"'.format(
                            self.port))
                    return False
            elif self.failures.get('delay-response', False):
                upper_bound = self.failures['delay-response']
                sleep_time = random.randint(0, upper_bound) * 0.01
                self.logger.info('MB:{} Simulating delay-response {:.0f}ms'.format(
                    self.port, sleep_time * 1000))
                time.sleep(sleep_time)
            return True

    def set_failures(self, failures):
        with self.lock:
            self.failures = failures

    def start_server(self, request_handler: Callable) -> None:
        '''
        Starts the Modbus server and listens for packets over a TCP/IP connection. By default it binds to
        localhost at the port given in the constructor. Upon receiving a Modbus message it decodes the header
        and passes the function code and data to _dissect_packet for further processing. Error packets get an
        immediate error-code response, while valid requests are answered with whatever the request handler
        returns. (A framing sketch follows this example.)
        '''
        if self.socket_type == socket.SOCK_STREAM:
            self._start_server_tcp(request_handler)
        else:
            self._start_server_udp(request_handler)

    def stop_server(self) -> None:
        '''
        Breaks the server out of its blocking accept or recv call and sets the stop flag.
        To do so it opens a 'dummy' connection and sends a 'dummy' message that makes the
        serving loop drop the request and exit.
        **NOTE**: This is especially useful when the server runs on its own thread.
        '''
        self.logger.info('Stopping server now')
        self.stop.set()
        if self._current_connection:
            self._current_connection.close()
        time.sleep(0.5)
        # To stop the server we have to interrupt the blocking
        # socket.accept(), so we open a connection and send a header
        # announcing a zero-length packet.
        if not self.done.is_set() and self.socket_type == socket.SOCK_STREAM:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                if self.localhost:
                    s.connect(('localhost', self.port))
                else:
                    s.connect((socket.gethostname(), self.port))
                s.sendall(b'\x00\x01\x00\x00\x00\x00\x00')
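
The serving loop above relies on standard Modbus/TCP framing: a 7-byte MBAP header (transaction id, protocol id, length, unit id) whose length field counts the unit-id byte plus the PDU. That is why the body is sliced as buffer[7:7 + length - 1], and why the 7-byte frame stop_server sends, with a zero length field, makes the loop log a zero-length message and fall through to the stop check. Below is a minimal sketch of that arithmetic using only the standard library; the request frame is a hypothetical read-holding-registers example, not something taken from this codebase:

import struct

# Hypothetical request: 7-byte MBAP header + 5-byte PDU
# (function 0x03 "read holding registers", address 0, count 2)
frame = (struct.pack('>HHHB', 0x0001, 0x0000, 0x0006, 0x11)  # tx, proto, length, unit
         + struct.pack('>BHH', 0x03, 0x0000, 0x0002))

transaction_id, protocol_id, length, unit_id = struct.unpack('>HHHB', frame[:7])
assert length == 6                # unit id (1 byte) + PDU (5 bytes)
pdu = frame[7:7 + length - 1]     # the same slice the server loop uses
assert pdu[0] == 0x03             # function code is the first PDU byte
assert len(frame) == 6 + length   # a full frame is 6 header bytes + `length` more

Failure injection plugs into the same loop: for example, server.set_failures({'delay-response': 50}) would delay each response by a random 0-500 ms, {'flake-response': 50} delays roughly a third of responses and drops another third, and {'stop-responding': True} suppresses responses entirely until the failure dict is replaced.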
Example #27
0
class on_policy_parallel_rollouts():
    def __init__(self, agent, env=None):
        self.lock = th.Lock()
        self.agent = agent
        # one and only
        self.farmer = farmer_class(self.agent.para_list)
        self.ep_num = 0
        self.total_steps = 0
        self.history_reward = []
        self.ep_value = {}
        self.relative_time = 0
        self.average_len_of_episode = self.agent.para_list["max_pathlength"]
        self.num_rollouts = int(self.agent.para_list["timesteps_per_batch"] /
                                self.average_len_of_episode)
        self.rollout_count = 0
        self.rollout_paths = []
        self.iteration = 0
        self.log_scalar_name_list = [
            'mean_reward', 'kl_div', 'entropy', 'surrogate_loss', 'value_loss'
        ]
        self.log_scalar_type_list = [
            tf.float32, tf.float32, tf.float32, tf.float32, tf.float32
        ]
        self.logger = Logger(self.agent.session,
                             self.agent.para_list["log_path"] + 'train',
                             self.log_scalar_name_list,
                             self.log_scalar_type_list)
        self.write_log = self.logger.create_scalar_log_method()
        self.start_time = time.time()

    def refarm(self):  # rarely needed
        del self.farmer
        self.farmer = farmer_class(self.agent.para_list)

    def fit(self):
        # calculate the real learning rate with a linear schedule
        learning_rate = learning_rate_schedule(self.agent.para_list,
                                               self.total_steps)

        # fit_stats: update_old_policy_time, train_time, fit_time,
        # surrogate_loss, kl_after, entropy_after, value_loss
        fit_stats = self.agent.fit(self.rollout_paths, learning_rate)

        # update print info:
        self.ep_value["learning_rate"] = (learning_rate[0] +
                                          learning_rate[1]) / 2
        self.ep_value["episode_nums"] = self.num_rollouts
        self.ep_value["episode_steps"] = sum(
            [len(path["rewards"]) for path in self.rollout_paths])
        self.ep_value["episode_time"] = sum(
            [path["ep_time"]
             for path in self.rollout_paths]) / self.num_rollouts
        self.ep_value["episode_reward"] = sum(
            [path["rewards"].sum()
             for path in self.rollout_paths]) / self.num_rollouts
        self.ep_value["average_steps"] = int(self.average_len_of_episode)
        self.ep_value["step_time"] = sum(
            [path["step_time"]
             for path in self.rollout_paths]) / self.num_rollouts
        # copy fit info into ep_value
        self.ep_value['update_time'] = fit_stats[0]
        self.ep_value['train_time'] = fit_stats[1]
        self.ep_value['iter_time'] = fit_stats[2]
        self.ep_value['surrogate_loss'] = fit_stats[3]
        self.ep_value['kl_after'] = fit_stats[4]
        self.ep_value['entropy_after'] = fit_stats[5]
        self.ep_value['value_loss'] = fit_stats[6]
        self.relative_time = (time.time() - self.start_time) / 60

        # write_log with tensorboard: ['mean_reward', 'kl_div', 'entropy',
        # 'surrogate_loss', 'value_loss']
        self.write_log([
            self.ep_value['episode_reward'], self.ep_value['kl_after'],
            self.ep_value['entropy_after'], self.ep_value['surrogate_loss'],
            self.ep_value['value_loss']
        ], self.iteration)

        # print iteration information:
        on_policy_iteration_print(self.iteration, self.total_steps,
                                  self.relative_time, self.ep_value)
        self.history_reward.append(self.ep_value["episode_reward"])

    def rollout_an_episode(self, env):
        global graph
        with graph.as_default():
            path = rollout_an_episode(self.agent, env, self.lock)
        with self.lock:
            self.ep_num += 1
            self.rollout_count += 1
            self.rollout_paths.append(path)
            self.total_steps += len(path["actions"])
        env.rel()  # release the remote environment back to the farm

    def rollout_if_available(self):
        while True:
            remote_env = self.farmer.acq_env()  # request a remote environment
            if remote_env is False:  # no free environment yet
                time.sleep(0.1)  # poll instead of busy-waiting
            else:
                t = th.Thread(target=self.rollout_an_episode,
                              args=(remote_env, ),
                              daemon=True)
                t.start()
                break

    def rollout(self):
        while self.total_steps < self.agent.para_list["n_steps"]:
            if self.num_rollouts == 0:
                raise RuntimeError('num_rollouts is 0: would divide by zero')
            self.iteration += 1
            for i in range(self.num_rollouts):
                self.rollout_if_available()

            # wait until every rollout thread has finished
            while self.rollout_count != self.num_rollouts:
                time.sleep(0.01)

            if (self.iteration + 1) % 10 == 0:
                self.agent.saveModel(0)

            self.average_len_of_episode = sum(
                [len(path["rewards"])
                 for path in self.rollout_paths]) / self.num_rollouts
            if self.average_len_of_episode == 0:
                raise RuntimeError('average episode length is 0: would divide by zero')
            self.fit()
            self.rollout_count = 0
            self.rollout_paths = []
            self.num_rollouts = int(
                self.agent.para_list["timesteps_per_batch"] /
                self.average_len_of_episode)

        self.agent.saveModel(0)
        return self.history_reward
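
learning_rate_schedule is called in fit() above but not defined in this example; from the call site it takes the parameter dict and the running step count and returns an indexable pair of rates (fit() averages learning_rate[0] and learning_rate[1] for display). A minimal linear-decay sketch consistent with that contract; the "policy_learning_rate" and "value_learning_rate" keys are assumptions for illustration, while "n_steps" does appear in the example:

def learning_rate_schedule(para_list, total_steps):
    # Linearly anneal a (policy, value) learning-rate pair toward zero
    # over the n_steps training budget. Sketch only: the real helper
    # may use a different schedule or different keys.
    frac = max(0.0, 1.0 - total_steps / para_list["n_steps"])
    return (para_list["policy_learning_rate"] * frac,
            para_list["value_learning_rate"] * frac)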
Example #28
0
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

        # For tst
        if args.tst_free:
            self.trlog['max_tst_criterion'] = 0.0
            self.trlog['max_tst_criterion_interval'] = 0.
            self.trlog['max_tst_criterion_epoch'] = 0
            self.trlog['tst_criterion'] = args.tst_criterion

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def print_metric_summaries(self, metric_summaries, prefix='\t'):
        for key, (mean, std) in metric_summaries.items():
            print('{}{}: {:.4f} +/- {:.4f}'.format(prefix, key, mean, std))

    def log_metric_summaries(self, metric_summaries, epoch, prefix=''):
        for key, (mean, std) in metric_summaries.items():
            self.logger.add_scalar('{}{}'.format(prefix, key), mean, epoch)

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:

            if not args.tst_free:
                vl, va, vap = self.evaluate(self.val_loader)
                self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
                self.logger.add_scalar('val_acc', float(va), self.train_epoch)
                print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                    epoch, vl, va, vap))
            else:
                vl, va, vap, metrics = self.evaluate(self.val_loader)
                self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
                self.logger.add_scalar('val_acc', float(va), self.train_epoch)
                print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                    epoch, vl, va, vap))
                self.print_metric_summaries(metrics, prefix='\tval_')
                self.log_metric_summaries(metrics, epoch=epoch, prefix='val_')

            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')

            # Probably a different criterion for TST -> optimize here.
            if args.tst_free and args.tst_criterion:
                assert args.tst_criterion in metrics, 'Criterion {} not found in {}'.format(
                    args.tst_criterion, metrics.keys())
                criterion, criterion_interval = metrics[args.tst_criterion]
                if criterion >= self.trlog['max_tst_criterion']:
                    self.trlog['max_tst_criterion'] = criterion
                    self.trlog[
                        'max_tst_criterion_interval'] = criterion_interval
                    self.trlog['max_tst_criterion_epoch'] = self.train_epoch
                    self.save_model('max_tst_criterion')
                    print(
                        'Found new best model at Epoch {} : Validation {} = {:.4f} +/- {:.4f}'
                        .format(self.train_epoch, args.tst_criterion,
                                criterion, criterion_interval))

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print(
                'epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                .format(self.train_epoch, self.train_step, self.max_steps,
                        tl1.item(), tl2.item(), ta.item(),
                        self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(),
                                   self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec, '
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)
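
Trainer is abstract, so nothing runs until a subclass implements the four hooks, and try_evaluate additionally expects the subclass to provide self.val_loader, self.model and self.optimizer. A skeletal subclass with hypothetical stubs, just to show where the hooks attach:

class MyTrainer(Trainer):
    def train(self):
        for epoch in range(1, self.args.max_epoch + 1):
            self.train_epoch = epoch
            # ... one epoch of optimisation here, calling
            # self.try_logging(tl1, tl2, ta) inside the step loop ...
            self.try_evaluate(epoch)  # checkpoints 'max_acc' on improvement

    def evaluate(self, data_loader):
        # must return (loss, acc, interval); with args.tst_free it must
        # also return a metrics dict (see try_evaluate above)
        return 0.0, 0.0, 0.0

    def evaluate_test(self, data_loader):
        return 0.0, 0.0, 0.0

    def final_record(self):
        pass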