コード例 #1
0
ファイル: broker.py プロジェクト: Tolea86/PAD-Lab1
    def group_message(self, message_content):
        """Relay a group message to every member of the group except the sender.

        ``message_content`` is read with ``.get`` for the group name, the
        sender username and the message text.  If the group name is not in
        ``self.groups``, a 'no such group' notice is sent back to the sender
        instead.

        The outgoing payload is the same for every recipient, so it is built
        and made sendable once instead of once per member.  Recipients for
        which ``_get_user_info`` returns nothing are skipped with a warning
        rather than aborting the remaining deliveries.
        """
        import warnings  # function scope: this block did not need it before

        group_name = message_content.get(MessageFields.GROUP_NAME)
        sender = message_content.get(MessageFields.SENDER_USERNAME)
        text = message_content.get(MessageFields.MESSAGE_TEXT)

        if group_name in self.groups:
            message = {
                MessageFields.MESSAGE_ID: MessageId.GROUP_MESSAGE,
                MessageFields.MESSAGE_CONTENT: {
                    MessageFields.SENDER_USERNAME: sender,
                    MessageFields.MESSAGE_TEXT: text,
                    MessageFields.GROUP_NAME: group_name
                }
            }
            # The sender does not get an echo of their own message.
            recipients = [member for member in self.groups[group_name]
                          if member != sender]
        else:
            message = {
                MessageFields.MESSAGE_ID: MessageId.RECEIVE_MESSAGE,
                MessageFields.MESSAGE_CONTENT: {
                    MessageFields.SENDER_USERNAME: '******',
                    MessageFields.MESSAGE_TEXT: 'no such group'
                }
            }
            recipients = [sender]

        if not recipients:
            return

        message_to_send = MessageBuilder.make_sendable(message)
        for recipient in recipients:
            user = self._get_user_info(recipient)
            if not user:
                warnings.warn('no such user exists')
                continue
            self.writing_connection.send_message(
                message_to_send, user.ip, user.port)
コード例 #2
0
    def send_login_message(self):
        # NOTE(review): this snippet is damaged. The `input(...)` line below has
        # a '******' redaction artifact fused into it, and the code that should
        # sit between the username prompt and the 'online users' output is
        # missing. `online_users` and `addr` are never bound in the visible
        # code, so the method cannot run as shown; restore it from the
        # original project before use.
        name = ''
        while not name:
            name = input('enter your username: '******'online users: {}'.format(str(online_users)))

        self.broker_address = addr
コード例 #3
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def set_battery_capacity(self):
     """Send the battery capacity typed into the settings panel.

     The text field is parsed as a float and packed into a
     SET_BATTERY_CAPACITY message.  Any failure is printed together with
     a warning instead of being raised.
     """
     try:
         builder = MessageBuilder(ToUlink.SET_BATTERY_CAPACITY)
         raw_capacity = self.ui_root.settings.battery_capacity.text
         builder.add_float(float(raw_capacity))
         self.send(builder.to_bytes())
     except Exception as err:
         print (err)
         print ('WARNING: wrong battery capacity value.')
コード例 #4
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def set_balance(self):
     """Send the balance typed into the settings panel.

     The text field is parsed as an int and packed as a uint32 into a
     SET_BALANCE message.  Any failure is printed together with a warning
     instead of being raised.
     """
     try:
         builder = MessageBuilder(ToUlink.SET_BALANCE)
         raw_balance = self.ui_root.settings.box_balance.text
         builder.add_uint32(int(raw_balance))
         self.send(builder.to_bytes())
     except Exception as err:
         print (err)
         print ('WARNING: invalid balance value.')
コード例 #5
0
ファイル: broker.py プロジェクト: Tolea86/PAD-Lab1
    def create_group(self, message_content):
        """Add the requester and the known requested members to a group, then notify them.

        ``message_content`` is read with ``.get`` for the member list
        (default ``[]``), the group name and the sender username.  Requested
        members are added only if they are keys of ``self.users_dict``; the
        requester is always added.

        Changes from the previous version: the group table is printed once
        after the membership update instead of once per requested member, the
        notification (identical for every member) is built and made sendable
        once, and members for which ``_get_user_info`` returns nothing are
        skipped with a warning instead of raising.
        """
        import warnings  # function scope: this block did not need it before

        group_members = message_content.get(MessageFields.GROUP_MEMBERS, [])
        group_name = message_content.get(MessageFields.GROUP_NAME)

        members = self.groups[group_name]

        # Add the user that requested the creation of the group
        members.add(message_content.get(MessageFields.SENDER_USERNAME))

        # Only users the broker already knows about can join
        members.update(member for member in group_members
                       if member in self.users_dict)

        print(self.groups)

        # Send a message to every member of the newly created group
        text = 'you\'re in group: {}. with members: {}'.format(
            group_name, str(members))
        message = {
            MessageFields.MESSAGE_ID: MessageId.RECEIVE_MESSAGE,
            MessageFields.MESSAGE_CONTENT: {
                MessageFields.SENDER_USERNAME: '******',
                MessageFields.MESSAGE_TEXT: text,
            }
        }
        message_to_send = MessageBuilder.make_sendable(message)
        for group_member in members:
            user = self._get_user_info(group_member)
            if not user:
                warnings.warn('no such user exists')
                continue
            self.writing_connection.send_message(message_to_send, user.ip,
                                                 user.port)
コード例 #6
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def override_demand_response_state(self):
     """Send the demand-response override chosen in the debug panel.

     The state name from the text field is mapped to a pair of bytes,
     which are appended to an OVERRIDE_DEMAND_REPONSE message.  An
     unknown state name is reported on stdout and nothing is sent.
     """
     builder = MessageBuilder(ToUlink.OVERRIDE_DEMAND_REPONSE)
     state_to_command = {
         'none':   [0, 0],
         'green':  [1, 0],
         'yellow': [1, 1],
         'red':    [1, 2],
         'off':    [1, 3],
         'on':     [1, 4],
     }
     target_state = self.ui_root.debug_panel.dr_state.text
     if target_state in state_to_command:
         for command_byte in state_to_command[target_state]:
             builder.add_byte(command_byte)
         self.send(builder.to_bytes())
     else:
         print ('%s not acceptable. Must be one of %s..' % (target_state, state_to_command.keys()))
コード例 #7
0
ファイル: broker.py プロジェクト: Tolea86/PAD-Lab1
    def send_message(self, message_content):
        """Deliver a direct message, or queue it when the receiver is offline.

        ``message_content`` is read with ``.get`` for the receiver username,
        the sender username and the message text.

        * Receiver unknown: a warning is issued and nothing is sent.
        * Receiver online: the message is forwarded to the receiver.
        * Receiver offline: the message is appended to
          ``self.unsent_messages[receiver]`` and the sender gets an
          'is offline' notice.

        The message is queued *before* the sender is notified, so it is no
        longer lost when the sender's user info cannot be found; in that case
        a warning is issued instead of raising.
        """
        receiver_name = message_content.get(MessageFields.RECEIVER_USERNAME)
        sender_name = message_content.get(MessageFields.SENDER_USERNAME)
        message_text = message_content.get(MessageFields.MESSAGE_TEXT)

        user = self._get_user_info(receiver_name)
        if not user:
            warnings.warn('no such user exists')
            return

        if user.is_online:
            message = {
                MessageFields.MESSAGE_ID: MessageId.RECEIVE_MESSAGE,
                MessageFields.MESSAGE_CONTENT: {
                    MessageFields.SENDER_USERNAME: sender_name,
                    MessageFields.MESSAGE_TEXT: message_text,
                }
            }
            message_to_send = MessageBuilder.make_sendable(message)
            self.writing_connection.send_message(message_to_send, user.ip,
                                                 user.port)
            return

        # Receiver is offline: store the message first so it survives any
        # problem while notifying the sender.
        self.unsent_messages[receiver_name].append({
            MessageFields.SENDER_USERNAME: sender_name,
            MessageFields.MESSAGE_TEXT: message_text
        })
        print(self.unsent_messages)

        sender = self._get_user_info(sender_name)
        if not sender:
            warnings.warn('no such user exists')
            return

        text = '{} is offline. The user will receive the message when the user comes online'.format(
            receiver_name)
        message = {
            MessageFields.MESSAGE_ID: MessageId.RECEIVE_MESSAGE,
            MessageFields.MESSAGE_CONTENT: {
                MessageFields.SENDER_USERNAME: '******',
                MessageFields.MESSAGE_TEXT: text
            }
        }
        message_to_send = MessageBuilder.make_sendable(message)
        self.writing_connection.send_message(message_to_send,
                                             sender.ip, sender.port)
コード例 #8
0
    def reading_sel(self):
        """Receive one message and dispatch it to the method named by its id.

        The message id is used as an attribute name on ``self``; the matching
        method is called with the message content (``{}`` when absent).
        Messages whose id is missing, not a string, or does not name a
        callable attribute are ignored silently.

        The handler lookup is kept separate from the handler call, so an
        ``AttributeError`` raised *inside* a handler is no longer swallowed
        as if the message type were unknown.
        """
        data, _ = self.reading_connection.receive_message()
        message = MessageBuilder.make_readable(data)

        message_type = message.get(MessageFields.MESSAGE_ID)
        message_content = message.get(MessageFields.MESSAGE_CONTENT, {})

        # NOTE(review): the attribute name comes straight from the received
        # message, so the peer can trigger any method of this object.
        # Consider an explicit table of allowed handlers.
        handler = None
        if isinstance(message_type, str):
            handler = getattr(self, message_type, None)

        if callable(handler):
            handler(message_content)
コード例 #9
0
ファイル: broker.py プロジェクト: Tolea86/PAD-Lab1
    def _process_login(self):
        """Handle one login request arriving on the broadcasting connection.

        Registers the user under the sender's IP and the port given in the
        message, replies with an ACKNOWLEDGE_LOG_IN message carrying the
        current user list, then delivers any messages stored for that user.
        """
        data, addr = self.broadcasting_connection.receive_message()
        login_request = MessageBuilder.make_readable(data)
        DebugMessage.debug(login_request, addr)

        ip = addr[0]
        content = login_request.get(MessageFields.MESSAGE_CONTENT, {})
        username = content.get(MessageFields.SENDER_USERNAME)
        port = content.get(MessageFields.SENDER_PORT)
        self._add_new_user(username, ip, port)

        # The list is fetched after registration, as in the reply below.
        user_list = self.get_list()
        acknowledgement = {
            MessageFields.MESSAGE_ID: MessageId.ACKNOWLEDGE_LOG_IN,
            MessageFields.MESSAGE_CONTENT: {MessageFields.USER_LIST: user_list},
        }
        sendable = MessageBuilder.make_sendable(acknowledgement)
        self.writing_connection.send_message(sendable, ip, port)

        self._send_missed_messages(username)
コード例 #10
0
ファイル: broker.py プロジェクト: Tolea86/PAD-Lab1
    def _process_receive(self):
        """Receive one message and dispatch it to the method named by its id.

        The message id is used as an attribute name on ``self``; the matching
        method is called with the message content (``{}`` when absent).  When
        the id is missing, not a string, or does not name a callable
        attribute, the message is printed and a warning is issued.

        The handler lookup is kept separate from the handler call, so an
        ``AttributeError`` raised *inside* a handler is no longer misreported
        as a rogue message.
        """
        data, addr = self.reading_connection.receive_message()
        message = MessageBuilder.make_readable(data)
        DebugMessage.debug(message, addr)

        message_type = message.get(MessageFields.MESSAGE_ID)
        message_content = message.get(MessageFields.MESSAGE_CONTENT, {})

        # NOTE(review): the attribute name comes straight from the received
        # message, so a client can trigger any method of this object,
        # including underscore-prefixed ones.  Consider an explicit table of
        # allowed handlers.
        handler = None
        if isinstance(message_type, str):
            handler = getattr(self, message_type, None)

        if not callable(handler):
            print(message)
            warnings.warn('oups, seems like a rogue message')
            return

        handler(message_content)
コード例 #11
0
ファイル: broker.py プロジェクト: Tolea86/PAD-Lab1
 def get_online_users(self, message_content):
     """Reply to the requesting user with the current user list.

     The requester is looked up by the sender username in
     ``message_content``.  If no user info is found, a warning is issued
     and nothing is sent.
     """
     requester = message_content.get(MessageFields.SENDER_USERNAME)
     address = self._get_user_info(requester)
     online_users = self.get_list()

     if not address:
         warnings.warn('no such user exists')
         return

     reply = {
         MessageFields.MESSAGE_ID: MessageId.ONLINE_USERS,
         MessageFields.MESSAGE_CONTENT: {MessageFields.USER_LIST: online_users},
     }
     sendable = MessageBuilder.make_sendable(reply)
     # The user info is indexed positionally: first item, then second.
     self.writing_connection.send_message(sendable, address[0], address[1])
コード例 #12
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def set_state_of_charge(self):
     """Send the state of charge and its uncertainty from the settings panel.

     Both text fields are parsed as floats and appended, in that order,
     to a SET_STATE_OF_CHARGE message.  Any failure is printed together
     with a warning instead of being raised.
     """
     try:
         builder = MessageBuilder(ToUlink.SET_STATE_OF_CHARGE)
         settings = self.ui_root.settings
         charge = float(settings.state_of_charge.text)
         charge_uncertainty = float(settings.uncertainty_of_charge.text)
         for value in (charge, charge_uncertainty):
             builder.add_float(value)
         self.send(builder.to_bytes())
     except Exception as err:
         print (err)
         print ('WARNING: wrong state of charge value.')
コード例 #13
0
ファイル: broker.py プロジェクト: Tolea86/PAD-Lab1
    def _send_missed_messages(self, username):
        """Deliver and discard every message stored for ``username``.

        The stored entries are removed from ``self.unsent_messages`` (an
        absent entry counts as no messages) and each one is sent to the user
        as a RECEIVE_MESSAGE carrying the original sender and text.
        """
        pending = self.unsent_messages.pop(username, [])
        user = self._get_user_info(username)

        for stored in pending:
            content = {
                MessageFields.SENDER_USERNAME:
                stored.get(MessageFields.SENDER_USERNAME),
                MessageFields.MESSAGE_TEXT:
                stored.get(MessageFields.MESSAGE_TEXT),
            }
            envelope = {
                MessageFields.MESSAGE_ID: MessageId.RECEIVE_MESSAGE,
                MessageFields.MESSAGE_CONTENT: content,
            }
            sendable = MessageBuilder.make_sendable(envelope)
            self.writing_connection.send_message(sendable, user.ip, user.port)
コード例 #14
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
    def set_balance_update(self):
        """Send the periodic balance-update schedule from the settings panel.

        Reads hours, minutes and amount from their text fields as ints and
        requires hours in 0..23, minutes in 0..59 and a positive amount.
        Hours and minutes are appended with ``add_int`` and the amount with
        ``add_uint32`` to a SET_BALANCE_UPDATE message.

        Range checks raise ``ValueError`` explicitly rather than using
        ``assert``, so they still run when Python is started with ``-O``.
        Any failure is printed together with a warning instead of being
        raised.
        """
        try:
            mb = MessageBuilder(ToUlink.SET_BALANCE_UPDATE)

            balance_update_hours   = int(self.ui_root.settings.balance_update_hours.text)
            balance_update_minutes = int(self.ui_root.settings.balance_update_minutes.text)
            balance_update_ammount = int(self.ui_root.settings.balance_update_ammount.text)

            if not 0 <= balance_update_hours < 24:
                raise ValueError(
                    'balance update hours must be in 0..23, got %d'
                    % balance_update_hours)
            if not 0 <= balance_update_minutes < 60:
                raise ValueError(
                    'balance update minutes must be in 0..59, got %d'
                    % balance_update_minutes)
            if balance_update_ammount <= 0:
                raise ValueError(
                    'balance update amount must be positive, got %d'
                    % balance_update_ammount)

            mb.add_int(balance_update_hours)
            mb.add_int(balance_update_minutes)
            mb.add_uint32(balance_update_ammount)

            self.send(mb.to_bytes())
        except Exception as e:
            print (e)
            print ('WARNING: wrong values for balance update')
コード例 #15
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
    def read_eeprom(self, mem_type):
        """Request an EEPROM read at the address typed into the debug panel.

        Args:
            mem_type: one of ``'float'``, ``'uint32'`` or ``'byte'``; its
                position in that list is sent as the type byte.

        The address text field is parsed as an int.  If that fails, the
        error and a warning are printed and no request is sent.  Previously
        the method carried on and passed ``None`` to ``add_uint32``.
        """
        mem_types = [ 'float', 'uint32', 'byte']
        # Guards against a programming error in the caller, not user input.
        assert mem_type in mem_types
        mem_type_idx = mem_types.index(mem_type)

        try:
            addr = int(self.ui_root.debug_panel.eeprom_addr.text)
        except Exception as e:
            print (e)
            print ('WARNING: invalid eeprom_addr')
            return

        mb = MessageBuilder(ToUlink.READ_EEPROM)
        mb.add_byte(mem_type_idx)
        mb.add_uint32(addr)
        self.send(mb.to_bytes())
コード例 #16
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def set_uid_node_type(self):
     """Send the box uid and node type typed into the settings panel.

     The uid must parse as an int in 0..255.  The node type is
     upper-cased and must be ``'A'`` or ``'B'``; its character code is
     sent as a byte after the uid.

     Validation raises ``ValueError`` explicitly rather than using
     ``assert``, so it still runs when Python is started with ``-O``.
     Any failure is printed together with a warning instead of being
     raised.
     """
     mb = MessageBuilder(ToUlink.SET_UID_NODE_TYPE)
     try:
         uid = int(self.ui_root.settings.box_uid.text)
         if not 0 <= uid <= 255:
             raise ValueError('uid must be in 0..255, got %d' % uid)

         node_type = self.ui_root.settings.box_node_type.text.upper()
         if node_type not in ('A', 'B'):
             raise ValueError(
                 "node type must be 'A' or 'B', got %r" % node_type)

         mb.add_byte(uid)
         mb.add_byte(ord(node_type))
         self.send(mb.to_bytes())
     except Exception as e:
         print (e)
         print ('WARNING: invalid uid or node type value.')
コード例 #17
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
    def set_thresholds(self):
        """Send the off/red/yellow thresholds typed into the settings panel.

        The three text fields are parsed as floats and appended, in the order
        off, red, yellow, to a SET_THRESHOLDS message.  Any failure is
        printed together with a warning instead of being raised.

        The warning text now names the thresholds; it used to repeat the
        'state of charge' wording copied from another setter.
        """
        try:
            mb = MessageBuilder(ToUlink.SET_THRESHOLDS)

            settings = self.ui_root.settings
            off_threshold = float(settings.off_threshold.text)
            red_threshold = float(settings.red_threshold.text)
            yellow_threshold = float(settings.yellow_threshold.text)

            mb.add_float(off_threshold)
            mb.add_float(red_threshold)
            mb.add_float(yellow_threshold)

            self.send(mb.to_bytes())
        except Exception as e:
            print (e)
            print ('WARNING: wrong threshold value.')
コード例 #18
0
ファイル: eval.py プロジェクト: andrewssobral/neuralnet
  def eval_loop(self, last_global_step):
    """Run the evaluation loop once.

    Asks ``self.get_checkpoint`` for the latest checkpoint.  If there is
    none, or its global step equals ``last_global_step``, sleeps for
    ``self.wait`` seconds and returns ``last_global_step`` unchanged.
    Otherwise restores the checkpoint into a new session, runs the fetches
    repeatedly until ``tf.errors.OutOfRangeError`` is raised, writes the
    final accuracy/loss/epoch summaries and returns the evaluated global
    step.
    """

    latest_checkpoint, global_step = self.get_checkpoint(
      last_global_step)
    logging.info("latest_checkpoint: {}".format(latest_checkpoint))

    # Nothing new to evaluate: wait and report the same step back.
    if latest_checkpoint is None or global_step == last_global_step:
      time.sleep(self.wait)
      return last_global_step

    with tf.Session(config=self.config) as sess:
      logging.info("Loading checkpoint for eval: {}".format(latest_checkpoint))

      # Restores from checkpoint
      self.saver.restore(sess, latest_checkpoint)
      sess.run(tf.local_variables_initializer())

      epoch = get_epoch(
                global_step,
                FLAGS.train_num_gpu,
                FLAGS.train_batch_size,
                self.reader.n_train_files)

      # Update ops are fetched together with the values they update;
      # presumably 'loss'/'accuracy' are running metrics -- TODO confirm.
      fetches = OrderedDict(
         loss_update_op=tf_get('loss_update_op'),
         acc_update_op=tf_get('acc_update_op'),
         loss=tf_get('loss'),
         accuracy=tf_get('accuracy'))

      while True:
        try:

          batch_start_time = time.time()
          values = sess.run(list(fetches.values()))
          # Re-attach the fetch names to the returned values.
          values = dict(zip(fetches.keys(), values))
          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = self.batch_size / seconds_per_batch

          message = MessageBuilder()
          message.add('epoch', epoch, format='.2f')
          message.add('step', global_step)
          message.add('accuracy', values['accuracy'], format='.5f')
          message.add('avg loss', values['loss'], format='.5f')
          message.add('imgs/sec', examples_per_second, format='.0f')
          logging.info(message.get_message())

        except tf.errors.OutOfRangeError:
          # Input exhausted: `values` holds the result of the last
          # successful batch.
          # NOTE(review): if the very first sess.run raises, `values` was
          # never assigned and the lines below fail on an unbound name --
          # confirm an empty evaluation input cannot happen.

          if self.best_accuracy is None or self.best_accuracy < values['accuracy']:
            self.best_global_step = global_step
            self.best_accuracy = values['accuracy']

          make_summary("accuracy", values['accuracy'], self.summary_writer, global_step)
          make_summary("loss", values['loss'], self.summary_writer, global_step)
          make_summary("epoch", epoch, self.summary_writer, global_step)
          self.summary_writer.flush()

          message = MessageBuilder()
          message.add('final: epoch', epoch, format='.2f')
          message.add('step', global_step)
          message.add('accuracy', values['accuracy'], format='.5f')
          message.add('avg loss', values['loss'], format='.5f')
          logging.info(message.get_message())
          logging.info("Done with batched inference.")

          # Count completed evaluations only when `stopped_at_n` is set.
          if self.stopped_at_n:
           self.counter += 1

          break

      return global_step
コード例 #19
0
ファイル: eval.py プロジェクト: andrewssobral/neuralnet
  def eval_attack(self):
    """Run the evaluation under attack.

    Restores the checkpoint returned by ``self.get_best_checkpoint`` into a
    new session, hands that session to ``self.attack``, and runs the
    fetches repeatedly until ``tf.errors.OutOfRangeError`` is raised.

    Returns:
      The ``accuracy`` and ``accuracy_adv`` values of the last successful
      batch, as a tuple.
    """

    best_checkpoint, global_step = self.get_best_checkpoint()

    # NOTE(review): `epoch` is computed but not used again in this method.
    epoch = get_epoch(
             global_step,
             FLAGS.train_num_gpu,
             FLAGS.train_batch_size,
             self.reader.n_train_files)

    with tf.Session(config=self.config) as sess:
      logging.info("Evaluation under attack:")

      # Restores from checkpoint
      self.saver.restore(sess, best_checkpoint)
      sess.run(tf.local_variables_initializer())

      # pass session to attack class for Carlini Attack
      self.attack.sess = sess

      fetches = OrderedDict(
         loss_update_op=tf_get('loss_update_op'),
         acc_update_op=tf_get('acc_update_op'),
         loss_adv_update_op=tf_get('loss_adv_update_op'),
         acc_adv_update_op=tf_get('acc_adv_update_op'),
         mean_norm_l1_update_op=tf_get('mean_norm_l1_update_op'),
         mean_norm_l2_update_op=tf_get('mean_norm_l2_update_op'),
         mean_norm_linf_update_op=tf_get('mean_norm_linf_update_op'),
         images=tf_get('images_batch'),
         images_adv=tf_get('images_adv_batch'),
         loss=tf_get('loss'),
         accuracy=tf_get('accuracy'),
         predictions=tf_get('predictions'),
         predictions_adv=tf_get('predictions_adv'),
         loss_adv=tf_get('loss_adv'),
         accuracy_adv=tf_get('accuracy_adv'),
         labels_batch=tf_get('labels_batch'),
         mean_l1=tf_get('mean_norm_l1'),
         mean_l2=tf_get('mean_norm_l2'),
         mean_linf=tf_get('mean_norm_linf'))

      # Number of examples processed so far, advanced by batch_size per run.
      count = 0
      dump = DumpFiles(self.train_dir)
      while True:
        try:

          batch_start_time = time.time()
          values = sess.run(list(fetches.values()))
          # Re-attach the fetch names to the returned values.
          values = dict(zip(fetches.keys(), values))
          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = self.batch_size / seconds_per_batch
          count += self.batch_size

          # dump images and images_adv
          if FLAGS.dump_files:
            dump.files(values)

          message = MessageBuilder()
          message.add('', [count, self.reader.n_test_files])
          message.add('acc img/adv',
                      [values['accuracy'], values['accuracy_adv']], format='.5f')
          message.add('avg loss', [values['loss'], values['loss_adv']], format='.5f')
          message.add('imgs/sec', examples_per_second, format='.3f')
          if FLAGS.eval_under_attack:
            norms_mean = [values['mean_l1'], values['mean_l2'], values['mean_linf']]
            message.add('l1/l2/linf mean', norms_mean, format='.2f')
          logging.info(message.get_message())

        except tf.errors.OutOfRangeError:
          # Input exhausted: `values` holds the result of the last
          # successful batch.
          # NOTE(review): if the very first sess.run raises, `values` was
          # never assigned and both this handler and the final return fail
          # on an unbound name -- confirm an empty input cannot happen.

          message = MessageBuilder()
          message.add('Final: images/adv',
                      [values['accuracy'], values['accuracy_adv']], format='.5f')
          message.add('avg loss', [values['loss'], values['loss_adv']], format='.5f')
          logging.info(message.get_message())
          logging.info("Done evaluation of adversarial examples.")
          break

    return values['accuracy'], values['accuracy_adv']
コード例 #20
0
 def _send_message(self, message):
     """Make ``message`` sendable and send it to the first item of ``self.broker_address``."""
     sendable = MessageBuilder.make_sendable(message)
     broker_host = self.broker_address[0]
     self.writing_connection.send_message(sendable, broker_host)
コード例 #21
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def sync_time(self):
     """Send the current time, truncated to whole seconds, in a SET_TIME message."""
     now_seconds = int(time.time())
     builder = MessageBuilder(ToUlink.SET_TIME)
     builder.add_uint32(now_seconds)
     payload = builder.to_bytes()
     self.send(payload)
コード例 #22
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def get_connected_nodes(self):
     """Send a GET_CONNECTED_NODES message with no payload."""
     payload = MessageBuilder(ToUlink.GET_CONNECTED_NODES).to_bytes()
     self.send(payload)
コード例 #23
0
ファイル: data_logger.py プロジェクト: nivwusquorum/microgrid
    def pull(self):
        """Request and log the data logger's contents in three phases.

        1. Repeatedly send EXTRACT_GENERAL until ``self.schema_name`` is set.
        2. Repeatedly send EXTRACT_COLUMN for every column not yet marked
           done in ``self.column_done``.
        3. Repeatedly send EXTRACT_DATA for every entry not yet marked done
           in ``self.entry_done``, in batches of ``BATCH_SIZE`` entries,
           logging each batch once it is complete.

        Every polling loop also stops when ``self.please_stop`` becomes true.

        NOTE(review): nothing in this method sets ``schema_name``,
        ``n_entries``, ``n_columns``, ``column_done`` or ``entry_done`` to a
        finished state; presumably a response handler elsewhere does -- confirm.
        If ``please_stop`` ends a phase early, the following code still runs
        with whatever state exists at that point.
        """
        MSG_TIMEOUT = 0.0001   # sleep (seconds) after each individual request
        BATCH_TIMEOUT = 0.5    # sleep (seconds) after each round of requests
        BATCH_SIZE = 4         # number of entries requested per batch

        # Phase 1: general information.
        while self.schema_name is None and not self.please_stop:
            mb = MessageBuilder(ToUlink.DATA_LOGGER)
            mb.add_byte(DataLoggerMessages.EXTRACT_GENERAL)
            self.send(mb.to_bytes())
            time.sleep(BATCH_TIMEOUT)

        self.log("Found %d entries (in %d columns)" % (self.n_entries, self.n_columns))

        for line in self.general_info():
            self.log(line)

        # Phase 2: per-column descriptions.
        self.column_done = [False for _ in range(self.n_columns)]
        self.column_name = [None for _ in range(self.n_columns)]
        self.column_type = [None for _ in range(self.n_columns)]

        while not all(self.column_done) and not self.please_stop:
            # Re-request only the columns still missing.
            for column in range(self.n_columns):
                if not self.column_done[column]:
                    mb = MessageBuilder(ToUlink.DATA_LOGGER)
                    mb.add_byte(DataLoggerMessages.EXTRACT_COLUMN)
                    mb.add_byte(column)
                    self.send(mb.to_bytes())
                    time.sleep(MSG_TIMEOUT)
            time.sleep(BATCH_TIMEOUT)

        self.log(self.column_line())

        # Phase 3: the entries themselves, batch by batch.
        self.entry_done = [False for _ in range(self.n_entries)]
        self.entry = [None for _ in range(self.n_entries)]

        for batch_start in range(0, self.n_entries, BATCH_SIZE):
            if VERBOSE: print ('batch_start')
            # The last batch may be shorter than BATCH_SIZE.
            batch_end = min(batch_start + BATCH_SIZE, self.n_entries)
            while not all(self.entry_done[batch_start:batch_end]) and not self.please_stop:
                for entry in range(batch_start, batch_end):
                    if not self.entry_done[entry]:
                        if VERBOSE: print ('msg_start')
                        mb = MessageBuilder(ToUlink.DATA_LOGGER)
                        mb.add_byte(DataLoggerMessages.EXTRACT_DATA)
                        mb.add_uint32(entry)
                        self.send(mb.to_bytes())
                        if VERBOSE: print ('msg_end')
                        time.sleep(MSG_TIMEOUT)
                time.sleep(BATCH_TIMEOUT)

            for entry in range(batch_start, batch_end):
                self.log(self.entry_line(entry))
            if VERBOSE: print ('batch_end')
コード例 #24
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def cron_stats(self):
     """Send a CRON_STATS message with no payload."""
     payload = MessageBuilder(ToUlink.CRON_STATS).to_bytes()
     self.send(payload)
コード例 #25
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def print_data_logs(self):
     """Send a PRINT_DATA_LOGS message with no payload."""
     payload = MessageBuilder(ToUlink.PRINT_DATA_LOGS).to_bytes()
     self.send(payload)
コード例 #26
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def print_local_time(self):
     """Send a PRINT_LOCAL_TIME message with no payload."""
     payload = MessageBuilder(ToUlink.PRINT_LOCAL_TIME).to_bytes()
     self.send(payload)
コード例 #27
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def test_leds(self):
     """Send a TEST_LEDS message with no payload."""
     payload = MessageBuilder(ToUlink.TEST_LEDS).to_bytes()
     self.send(payload)
コード例 #28
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def send_ping(self):
     """Send a PING message with no payload."""
     payload = MessageBuilder(ToUlink.PING).to_bytes()
     self.send(payload)
コード例 #29
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def factory_reset(self):
     """Send a FACTORY_RESET message with no payload."""
     payload = MessageBuilder(ToUlink.FACTORY_RESET).to_bytes()
     self.send(payload)
コード例 #30
0
ファイル: data_logger.py プロジェクト: nivwusquorum/microgrid
 def status(self):
     """Send a DATA_LOGGER message carrying the GET_STATUS sub-command byte."""
     builder = MessageBuilder(ToUlink.DATA_LOGGER)
     builder.add_byte(DataLoggerMessages.GET_STATUS)
     payload = builder.to_bytes()
     self.send(payload)
コード例 #31
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def update_settings(self, *largs):
     """Send a GET_SETTINGS message; any positional arguments are ignored."""
     payload = MessageBuilder(ToUlink.GET_SETTINGS).to_bytes()
     self.send(payload)
コード例 #32
0
ファイル: data_logger.py プロジェクト: nivwusquorum/microgrid
 def reset(self):
     """Send a DATA_LOGGER message carrying the RESET sub-command byte."""
     builder = MessageBuilder(ToUlink.DATA_LOGGER)
     builder.add_byte(DataLoggerMessages.RESET)
     payload = builder.to_bytes()
     self.send(payload)
コード例 #33
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def get_memory(self):
     """Send a GET_MEMORY message with no payload."""
     payload = MessageBuilder(ToUlink.GET_MEMORY).to_bytes()
     self.send(payload)
コード例 #34
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Args:
          start_new_model: when true, and this task is the master and
            ``self.train_dir`` exists, the training directory is removed
            before training starts.

        The method has no ``return`` statement, so it returns ``None``.
        """
        if self.is_master and start_new_model and exists(self.train_dir):
            self.remove_training_directory(self.train_dir)

        pp = pprint.PrettyPrinter(indent=2, compact=True)
        logging.info(pp.pformat(FLAGS.values()))

        # Save the flags next to the logs, unless a copy is already there.
        # NOTE(review): `log_folder` is not created here; open() fails if the
        # folder does not exist yet -- confirm it is created elsewhere.
        model_flags_dict = FLAGS.to_json()
        log_folder = '{}_logs'.format(self.train_dir)
        flags_json_path = join(log_folder, "model_flags.json")
        if not exists(flags_json_path):
            # Write the file.
            with open(flags_json_path, "w") as fout:
                fout.write(model_flags_dict)

        # NOTE(review): `target` is not used again in this method.
        target, device_fn = self.start_server_if_distributed()
        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:
            # Either recover the saved graph or build a fresh one; both
            # paths provide `saver`.
            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                # The ops and tensors used below are taken from named graph
                # collections (first element of each).
                global_step = tf.train.get_global_step()
                loss = tf.get_collection("loss")[0]
                logits = tf.get_collection("logits")[0]
                labels = tf.get_collection("labels")[0]
                learning_rate = tf.get_collection("learning_rate")[0]
                train_op = tf.get_collection("train_op")[0]
                summary_op = tf.get_collection("summary_op")[0]
                init_op = tf.global_variables_initializer()

                gradients_norm = tf.get_collection("gradients_norm")[0]

            scaffold = tf.train.Scaffold(
                saver=saver,
                init_op=init_op,
                summary_op=summary_op,
            )

            hooks = [
                tf.train.NanTensorHook(loss),
                tf.train.StopAtStepHook(num_steps=self.max_steps),
            ]

            # NOTE(review): checkpoints go to FLAGS.train_dir while the code
            # above works with self.train_dir -- confirm they are the same.
            session_args = dict(
                is_chief=self.is_master,
                scaffold=scaffold,
                checkpoint_dir=FLAGS.train_dir,
                hooks=hooks,
                save_checkpoint_steps=FLAGS.save_checkpoint_steps,
                save_summaries_steps=10,
                save_summaries_secs=None,
                log_step_count_steps=0,
                config=self.config,
            )

            logging.info("Start training")
            with tf.train.MonitoredTrainingSession(**session_args) as sess:

                summary_writer = tf.summary.FileWriterCache.get(
                    FLAGS.train_dir)

                if FLAGS.profiler:
                    profiler = tf.profiler.Profiler(sess.graph)

                global_step_val = 0
                while not sess.should_stop():

                    make_profile = False
                    profile_args = {}

                    # With the profiler enabled, trace every step whose
                    # previous global step value is a multiple of 1000
                    # (including the very first iteration).
                    if global_step_val % 1000 == 0 and FLAGS.profiler:
                        make_profile = True
                        run_meta = tf.RunMetadata()
                        profile_args = {
                            'options':
                            tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            'run_metadata':
                            run_meta
                        }

                    fetches = OrderedDict(train_op=train_op,
                                          global_step=global_step,
                                          loss=loss,
                                          learning_rate=learning_rate,
                                          logits=logits,
                                          labels=labels)

                    # The collection entry is compared with 0; presumably 0
                    # is stored when no gradient norm is tracked -- TODO confirm.
                    if gradients_norm != 0:
                        fetches['gradients_norm'] = gradients_norm
                    else:
                        grad_norm_val = 0

                    batch_start_time = time.time()
                    values = sess.run(list(fetches.values()), **profile_args)
                    # Re-attach the fetch names to the returned values.
                    fetches_values = OrderedDict(zip(fetches.keys(), values))
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = self.batch_size / seconds_per_batch

                    global_step_val = fetches_values['global_step']
                    loss_val = fetches_values['loss']
                    learning_rate_val = fetches_values['learning_rate']
                    predictions_val = fetches_values['logits']
                    labels_val = fetches_values['labels']

                    if gradients_norm != 0:
                        grad_norm_val = fetches_values['gradients_norm']

                    if FLAGS.gradients['compute_hessian'] and global_step_val != 0 and \
                       global_step_val % FLAGS.gradients['hessian_every_n_step'] == 0:
                        compute_hessian_and_summary(sess, summary_writer,
                                                    global_step_val)

                    if make_profile and FLAGS.profiler:
                        profiler.add_step(global_step_val, run_meta)

                        # Profile the parameters of your model.
                        profiler.profile_name_scope(
                            options=(tf.profiler.ProfileOptionBuilder.
                                     trainable_variables_parameter()))

                        # Or profile the timing of your model operations.
                        opts = tf.profiler.ProfileOptionBuilder.time_and_memory(
                        )
                        profiler.profile_operations(options=opts)

                        # Or you can generate a timeline:
                        # NOTE(review): the '~' in the output path is passed
                        # through as-is; presumably it is not expanded to the
                        # home directory -- verify.
                        opts = (tf.profiler.ProfileOptionBuilder(
                            tf.profiler.ProfileOptionBuilder.time_and_memory(
                            )).with_step(global_step_val).with_timeline_output(
                                '~/profile.logs').build())
                        profiler.profile_graph(options=opts)

                    # Log on the master every `frequency_log_steps` steps,
                    # and on every task at step 1.
                    to_print = global_step_val % FLAGS.frequency_log_steps == 0
                    if (self.is_master and to_print) or global_step_val == 1:
                        epoch = ((global_step_val * self.batch_size) /
                                 self.reader.n_train_files)

                        message = MessageBuilder()
                        message.add("epoch", epoch, format="4.2f")
                        message.add("step",
                                    global_step_val,
                                    width=5,
                                    format=".0f")
                        message.add("lr", learning_rate_val, format=".6f")
                        message.add("loss", loss_val, format=".4f")
                        # GAP is reported only for readers whose class name
                        # contains "YT8M".
                        if "YT8M" in self.reader.__class__.__name__:
                            gap = eval_util.calculate_gap(
                                predictions_val, labels_val)
                            message.add("gap", gap, format=".3f")
                        message.add("imgs/sec",
                                    examples_per_second,
                                    width=5,
                                    format=".0f")
                        if FLAGS.gradients['perturbed_gradients']:
                            message.add("grad norm",
                                        grad_norm_val,
                                        format=".4f")
                        logging.info(message.get_message())

                # End training
                logging.info(
                    "{}: Done training -- epoch limit reached.".format(
                        task_as_string(self.task)))
                if FLAGS.profiler:
                    profiler.advise()
        logging.info("{}: Exited training loop.".format(
            task_as_string(self.task)))
コード例 #35
0
ファイル: controller.py プロジェクト: nivwusquorum/microgrid
 def reset_pic(self):
     """Send a RESET_PIC message with no payload."""
     payload = MessageBuilder(ToUlink.RESET_PIC).to_bytes()
     self.send(payload)