Ejemplo n.º 1
0
  def update_runtime_meta(self, msg):
    '''
    Persist the gidx/tidx/vidx counters of ``msg`` to the runtime meta file
    (LOGGING_PATH / RT_META_FILE) so a later run can restore from them.

    Args:
      msg: pipeline message; its 'message_info|gidx', 'message_info|tidx'
        and 'message_info|vidx' entries are copied.
    '''
    meta = self.msgfactory.create_runtime_meta()
    for counter in ('gidx', 'tidx', 'vidx'):
      utils.msg_ud(meta, counter, utils.msg_gt(msg, 'message_info|' + counter))

    meta_path = params.LOGGING_PATH / params.RT_META_FILE
    with meta_path.open('w') as meta_file:
      json.dump(meta, meta_file)
Ejemplo n.º 2
0
  def logging_hv(self, msg):
    '''
    Heavy logging: save the images carried by ``msg`` to disk, then refresh
    the light-weight view via ``self.logging_lt(msg)``.

    Images whose message entry is None are skipped.  The remaining images are
    combined with ``utils.image_comp`` and written as:

      * GMODE 'train':     VISUAL_PATH/<epoch:010d>_<batch:010d>_<t|v>.jpg
      * GMODE 'inference': VISUAL_PATH/<lidx:010d>_i.jpg, plus one PNG per
        batch element and image kind under VISUAL_PATH/<lidx:010d>/.

    Raises:
      KeyError: in 'train' GMODE when the message mode is not 'train'/'valid'.
      RuntimeError: when ``params.GMODE`` is neither 'train' nor 'inference'.
    '''
    # Each label is also the suffix of its 'image|<label>' message key; this
    # order is the order the images are passed to the composite.
    saved_ilabel, saved_images = [], []
    for label in ('orig_covr', 'steg', 'dcpt_covr', 'orig_hide', 'dcpt_hide'):
      image = utils.msg_gt(msg, 'image|' + label)
      # Some entries may be None (e.g. 'image|dcpt_covr'); leave them out.
      if image is not None:
        saved_ilabel.append(label)
        saved_images.append(image)

    gmode, mode = params.GMODE, utils.msg_gt(msg, 'message_info|mode')

    if gmode == 'train':
      epoch = utils.msg_gt(msg, 'message_info|epoch')
      batch = utils.msg_gt(msg, 'message_info|batch')
      suffix = {'train': 't', 'valid': 'v'}[mode]

      im = utils.norm2pil(utils.image_comp(saved_images))
      im.save('{}/{:010d}_{:010d}_{}.jpg'.format(params.VISUAL_PATH, epoch, batch, suffix))
    elif gmode == 'inference':
      lidx = utils.msg_gt(msg, 'message_info|lidx')
      suffix = 'i'

      im = utils.norm2pil(utils.image_comp(saved_images))
      im.save('{}/{:010d}_{}.jpg'.format(params.VISUAL_PATH, lidx, suffix))

      # Individual samples go into a sub-directory named after lidx.
      detail_path = params.VISUAL_PATH / '{:010d}'.format(lidx)
      detail_path.mkdir(parents=True, exist_ok=True)
      for images, ilabel in zip(saved_images, saved_ilabel):
        # assumes axis 0 of each image array is the batch axis — confirm
        for idx in range(images.shape[0]):
          im = utils.norm2pil(images[idx])
          im.save('{}/{:010d}_{:03d}_{}.png'.format(detail_path, lidx, idx, ilabel))
    else:
      raise RuntimeError('Unexpected global mode: %s' % gmode)
    self.logging_lt(msg)
Ejemplo n.º 3
0
    def apply(self, msg, runtime):
        '''
        Run one step for ``msg``.

        In train GMODE with a 'train' message this performs a training step
        (``self.train_once``) and, when the message requests heavy logging,
        saves a checkpoint.  Every other combination runs ``self.inference``.

        Returns:
          The message returned by ``train_once`` / ``inference``.
        '''
        gmode = params.GMODE
        mode = utils.msg_gt(msg, 'message_info|mode')
        heavy_logging = utils.msg_gt(msg, 'message_info|heavy_logging')

        if not (gmode == 'train' and mode == 'train'):
            return self.inference(msg, runtime)

        msg = self.train_once(msg, runtime)
        if heavy_logging:
            # Checkpoint together with every heavy-logging training step.
            ckpt_prefix = str(params.CKPT_PATH / 'model')
            runtime.saver.save(
                runtime.sess,
                ckpt_prefix,
                global_step=self.g_step,
                latest_filename=params.CKPT_FILE,
                meta_graph_suffix='meta',
                write_meta_graph=True,
                write_state=True,
            )
        return msg
Ejemplo n.º 4
0
    def apply(self, queue):
        '''
        Preprocessor main loop.

        Takes messages from ``queue['generate']``. For each message it records
        the output queue size, copies the cover/hide images for the message's
        mode ('image|covr/<mode>', 'image|hide/<mode>') into
        'image|orig_covr' / 'image|orig_hide', and puts the message on
        ``queue[self.name]``.  The loop stops once SHOULD_FINISH is
        b'generate' and the input queue is empty. The output queue is then
        closed and SHOULD_FINISH is set to ``self.bname``.

        Raises:
          RuntimeError: if a message's mode is neither 'train' nor 'valid'.
        '''
        # The image shape maybe updated before preprocess
        for shared_dim in (params.INROWS, params.INCOLS, params.INCNLS):
            shared_dim.value = shared_dim.value

        iqueue, oqueue = queue['generate'], queue[self.name]

        while True:
            if params.SHOULD_FINISH.value == b'generate' and iqueue.empty():
                break
            try:
                msg = iqueue.get(timeout=params.QUEUE_TIMEOUT)
            except Q.Empty:
                continue

            utils.msg_ud(msg, 'queue_info|prep', oqueue.qsize())

            mode = utils.msg_gt(msg, 'message_info|mode')
            if mode not in ('train', 'valid'):
                raise RuntimeError('Invalid mode: %s' % mode)
            covr = utils.msg_gt(msg, 'image|covr/' + mode)
            hide = utils.msg_gt(msg, 'image|hide/' + mode)

            utils.msg_ud(msg, 'image|orig_covr', covr)
            utils.msg_ud(msg, 'image|orig_hide', hide)
            oqueue.put(msg)

        # Flush the output queue's feeder thread before signalling completion.
        oqueue.close()
        oqueue.join_thread()
        params.SHOULD_FINISH.value = self.bname
        utils.eprint('preprocessor: exit')
Ejemplo n.º 5
0
  def __init__(self, message):
    '''
    Initialise the message counters, restoring them from the runtime meta
    file (LOGGING_PATH / RT_META_FILE) unless RESTART is set or the file
    does not exist.

    Counters:
      gidx: global message index (resumes at the saved gidx + 1).
      lidx: always reset to 0.
      tidx, vidx: number of train / validation messages preceding gidx,
        recomputed from gidx and TRAIN_INTERVAL / VALID_INTERVAL.
    '''
    self.message = message
    self.lidx = 0

    rt_meta_path = params.LOGGING_PATH / params.RT_META_FILE
    if params.RESTART or not rt_meta_path.exists():
      self.gidx = 0
      self.tidx = 0
      self.vidx = 0
      return

    with rt_meta_path.open('r') as meta_f:
      meta = json.load(meta_f)
    # Resume with the message after the last one recorded.
    self.gidx = utils.msg_gt(meta, 'gidx') + 1

    # Assumes (as the tidx formula does) that each cycle is TI train messages
    # followed by VI validation messages.
    TI, VI = params.TRAIN_INTERVAL, params.VALID_INTERVAL
    cycle, offst = divmod(self.gidx, TI + VI)
    self.tidx = cycle * TI + min(offst, TI)
    # Validation messages in the current cycle start only after the TI train
    # ones, so count offst - TI of them, not the whole offset.
    self.vidx = cycle * VI + max(offst - TI, 0)
Ejemplo n.º 6
0
  def logging_lt(self, msg):
    '''
    Light logging: remember ``msg`` as the latest train message (mode
    'train') or latest validation message (any other mode). Then redraw
    ``self.stdscr`` with both remembered messages formatted by
    ``self.log_one_msg``.
    '''
    if utils.msg_gt(msg, 'message_info|mode') == 'train':
      self.train_last_msg = msg
    else:
      self.valid_last_msg = msg

    screen = self.stdscr
    screen.erase()
    sections = (
        ('Last Validation:\n', '  Validation:\n', 'valid_last_msg'),
        ('Last Train:\n', '  Train:\n', 'train_last_msg'),
    )
    for title, subtitle, attr_name in sections:
      screen.addstr(title)
      screen.addstr(subtitle)
      last_msg = getattr(self, attr_name)
      if last_msg:
        screen.addstr(self.log_one_msg(last_msg, indent=4))
      screen.addstr('\n')
    screen.refresh()
Ejemplo n.º 7
0
  def apply(self, queue):
    '''
    Main Logging Entry

    Consumes messages from ``queue['post']``. Each one is logged with
    ``logging_hv`` when its 'message_info|heavy_logging' flag is truthy,
    otherwise with ``logging_lt``. Polling stops once SHOULD_FINISH is
    b'post' and the queue is empty. The last message received (if any) is
    then logged heavily once more, and SHOULD_FINISH is set to
    ``self.bname``.
    '''
    iqueue = queue['post']

    last_msg = None
    while not (params.SHOULD_FINISH.value == b'post' and iqueue.empty()):
      try:
        last_msg = iqueue.get(timeout=params.QUEUE_TIMEOUT)
      except Q.Empty:
        continue

      heavy = utils.msg_gt(last_msg, 'message_info|heavy_logging')
      log = self.logging_hv if heavy else self.logging_lt
      log(last_msg)

    # Final heavy log so the last state is always written out.
    if last_msg:
      self.logging_hv(last_msg)
    params.SHOULD_FINISH.value = self.bname
    utils.eprint('logger: exit')
Ejemplo n.º 8
0
    def inference(self, msg, runtime):
        '''
        Inference model

        Run the graph without an optimiser over the cover/hide images in
        ``msg``, tile by tile.  ``utils.ImageSlicer`` cuts model-sized
        fragments out of the full image batch. Each fragment pair is fed
        through ``runtime.sess`` and the steg/decrypted outputs are written
        back into full-size buffers.  The scalar losses/variances are averaged
        over all tiles.

        Args:
          msg: pipeline message; reads 'image|orig_covr' and 'image|orig_hide'.
          runtime: object exposing ``sess`` (the session to run).

        Returns:
          ``msg`` updated with 'running|timing', 'running|train_cycle_timing'
          (None), 'image|steg', 'image|dcpt_covr' (None), 'image|dcpt_hide'
          and the averaged 'post_info|*' values.
        '''
        covr_img_v = utils.msg_gt(msg, 'image|orig_covr')
        hide_img_v = utils.msg_gt(msg, 'image|orig_hide')

        batch_size = params.BATCH_SIZE

        image_shape = [
            params.INROWS.value, params.INCOLS.value, params.INCNLS.value
        ]
        model_shape = [
            params.MNROWS.value, params.MNCOLS.value, params.MNCNLS.value
        ]
        # in* = full input image size, mn* = model tile size.  Unpacking them
        # the other way round makes the tile loop run MNROWS // INROWS times.
        inrows, incols, _ = image_shape
        mnrows, mncols, _ = model_shape

        # assumes ImageSlicer takes the full-image dims first, then the tile
        # dims — confirm against utils.ImageSlicer
        slicer = utils.ImageSlicer(inrows, incols, mnrows, mncols)

        t_beg = timeit.default_timer()
        steg_img_v = np.zeros(shape=(batch_size, *image_shape))
        dcpt_img_v = np.zeros(shape=(batch_size, *image_shape))
        loss_va = []
        rcst_loss_va, rcst_vars_va, dcpt_loss_va, dcpt_vars_va = [], [], [], []
        for row_idx in range(inrows // mnrows):
            for col_idx in range(incols // mncols):
                # slice an image fragment at (row_idx, col_idx)
                covr_img_vs = slicer.slice(covr_img_v, row_idx, col_idx)
                hide_img_vs = slicer.slice(hide_img_v, row_idx, col_idx)

                loss_v, \
                rcst_loss_v, rcst_vars_v, dcpt_loss_v, dcpt_vars_v, \
                steg_img_vs, dcpt_img_vs = runtime.sess.run([
                    self.loss,
                    self.rcst_loss, self.rcst_vars, self.dcpt_loss, self.dcpt_vars,
                    self.steg_img, self.dcpt_img
                ], feed_dict={self.covr_img: covr_img_vs, self.hide_img: hide_img_vs})

                # write the fragment outputs back at the same tile position
                slicer.slice_assign(steg_img_v, row_idx, col_idx, steg_img_vs)
                slicer.slice_assign(dcpt_img_v, row_idx, col_idx, dcpt_img_vs)

                loss_va.append(loss_v)
                rcst_loss_va.append(rcst_loss_v)
                rcst_vars_va.append(rcst_vars_v)
                dcpt_loss_va.append(dcpt_loss_v)
                dcpt_vars_va.append(dcpt_vars_v)
        run_time = timeit.default_timer() - t_beg

        utils.msg_ud(msg, 'running|timing', run_time)
        utils.msg_ud(msg, 'running|train_cycle_timing', None)
        utils.msg_st(msg, 'image|steg', steg_img_v)
        utils.msg_st(msg, 'image|dcpt_covr', None)
        utils.msg_st(msg, 'image|dcpt_hide', dcpt_img_v)
        utils.msg_st(msg, 'post_info|loss', np.average(loss_va))
        utils.msg_st(msg, 'post_info|rcst_loss', np.average(rcst_loss_va))
        utils.msg_st(msg, 'post_info|rcst_vars', np.average(rcst_vars_va))
        utils.msg_st(msg, 'post_info|dcpt_loss', np.average(dcpt_loss_va))
        utils.msg_st(msg, 'post_info|dcpt_vars', np.average(dcpt_vars_va))

        return msg
Ejemplo n.º 9
0
    def train_once(self, msg, runtime):
        '''
        Train once

        Run one pass over every model-sized tile of the batch in ``msg``.
        For a 'train' message in train GMODE each tile runs ``self.optm``;
        otherwise it runs ``self.dummy``.  The outputs are written back into
        full-size buffers, and the scalar losses/variances are averaged over
        tiles.  On a real training step the last tile's summary is written
        to ``runtime.summary_writer`` and the global step is advanced.

        Args:
          msg: pipeline message; reads 'image|orig_covr', 'image|orig_hide',
            'message_info|mode' and 'message_info|heavy_logging'.
          runtime: object exposing ``sess`` and ``summary_writer``.

        Returns:
          ``msg`` updated with 'running|*' timing, 'image|steg',
          'image|dcpt_covr' (None), 'image|dcpt_hide' and the averaged
          'post_info|*' values.
        '''
        utils.msg_ud(msg, 'running|task', 'train_once')

        covr_img_v = utils.msg_gt(msg, 'image|orig_covr')
        hide_img_v = utils.msg_gt(msg, 'image|orig_hide')

        gmode = params.GMODE
        mode = utils.msg_gt(msg, 'message_info|mode')
        heavy_logging = utils.msg_gt(msg, 'message_info|heavy_logging')
        batch_size = params.BATCH_SIZE
        is_train_step = gmode == 'train' and mode == 'train'

        # The fetched ops do not depend on the tile, so choose them once.
        optm = self.optm if is_train_step else self.dummy
        smry = self.smry_hv if heavy_logging else self.smry_lt

        image_shape = [
            params.INROWS.value, params.INCOLS.value, params.INCNLS.value
        ]
        model_shape = [
            params.MNROWS.value, params.MNCOLS.value, params.MNCNLS.value
        ]
        # in* = full input image size, mn* = model tile size.  Unpacking them
        # the other way round makes the tile loop run MNROWS // INROWS times.
        inrows, incols, _ = image_shape
        mnrows, mncols, _ = model_shape

        # assumes ImageSlicer takes the full-image dims first, then the tile
        # dims — confirm against utils.ImageSlicer
        slicer = utils.ImageSlicer(inrows, incols, mnrows, mncols)

        t_beg = timeit.default_timer()
        steg_img_v = np.zeros(shape=(batch_size, *image_shape))
        dcpt_img_v = np.zeros(shape=(batch_size, *image_shape))
        loss_va = []
        rcst_loss_va, rcst_vars_va, dcpt_loss_va, dcpt_vars_va = [], [], [], []
        # Stays None only if the tile loop below runs zero times.
        smry_v = None
        for row_idx in range(inrows // mnrows):
            for col_idx in range(incols // mncols):
                # slice an image fragment at (row_idx, col_idx)
                covr_img_vs = slicer.slice(covr_img_v, row_idx, col_idx)
                hide_img_vs = slicer.slice(hide_img_v, row_idx, col_idx)

                _, smry_v, \
                loss_v, \
                rcst_loss_v, rcst_vars_v, dcpt_loss_v, dcpt_vars_v, \
                steg_img_vs, dcpt_img_vs = runtime.sess.run([
                    optm, smry,
                    self.loss,
                    self.rcst_loss, self.rcst_vars, self.dcpt_loss, self.dcpt_vars,
                    self.steg_img, self.dcpt_img
                ], feed_dict={self.covr_img: covr_img_vs, self.hide_img: hide_img_vs})

                # write the fragment outputs back at the same tile position
                slicer.slice_assign(steg_img_v, row_idx, col_idx, steg_img_vs)
                slicer.slice_assign(dcpt_img_v, row_idx, col_idx, dcpt_img_vs)

                loss_va.append(loss_v)
                rcst_loss_va.append(rcst_loss_v)
                rcst_vars_va.append(rcst_vars_v)
                dcpt_loss_va.append(dcpt_loss_v)
                dcpt_vars_va.append(dcpt_vars_v)
        run_time = timeit.default_timer() - t_beg

        utils.msg_ud(msg, 'running|timing', run_time)
        utils.msg_ud(msg, 'running|train_cycle_timing', run_time)
        utils.msg_st(msg, 'image|steg', steg_img_v)
        utils.msg_st(msg, 'image|dcpt_covr', None)
        utils.msg_st(msg, 'image|dcpt_hide', dcpt_img_v)
        utils.msg_st(msg, 'post_info|loss', np.average(loss_va))
        utils.msg_st(msg, 'post_info|rcst_loss', np.average(rcst_loss_va))
        utils.msg_st(msg, 'post_info|rcst_vars', np.average(rcst_vars_va))
        utils.msg_st(msg, 'post_info|dcpt_loss', np.average(dcpt_loss_va))
        utils.msg_st(msg, 'post_info|dcpt_vars', np.average(dcpt_vars_va))

        if is_train_step:
            # Only write a summary if at least one tile produced one.
            if smry_v is not None:
                step_v = runtime.sess.run(self.g_step)
                runtime.summary_writer.add_summary(smry_v, step_v)
            runtime.sess.run(self.g_next_step)

        return msg