Exemplo n.º 1
0
    def path_message(self):
        """Build a framed text listing of this object's ``*_DIR`` attributes.

        Every attribute returned by ``dir(self)`` that is not a dunder, is not
        callable, and whose last ``_``-separated part is ``DIR`` is listed,
        filtered by ``self.mode``:
            'Train': names containing 'RESULT' or 'TEST' are skipped
            'Test':  names containing 'TRAIN' or 'VAL' are skipped
            other:   nothing is listed (only the header and footer remain)
        Returns:
            the message string (it is not printed here).
        """
        def add_message(owner, attr):
            # One "NAME: value" row. The name is left-padded so that its ':'
            # ends at column round(TOTAL_LENGTH / 3) when the name is short
            # enough. If the plain row leaves at least 6 spare columns, it is
            # framed as ">>> ... <<<" and padded out to TOTAL_LENGTH characters.
            # The value is concatenated with str literals, so it is assumed
            # to be a str.
            value = getattr(owner, attr)
            indent = round(TOTAL_LENGTH / 3) - len(attr + ':')
            head = " " * indent
            tail = value + '\n'
            if TOTAL_LENGTH >= len(head + attr + ': ' + tail) + 6:
                head = ">>>" + " " * (indent - 3)
                filler = round(TOTAL_LENGTH - 3 -
                               len(head + attr + ': ' + value))
                tail = value + " " * filler + "<<<\n"
            return head + attr + ': ' + tail

        left_bar, right_bar = cal_equal(len(self.mode) + 8)
        message = ('\n' + '=' * left_bar + ' ' + self.mode + ' Paths ' +
                   '=' * right_bar + '\n')
        # Substrings that exclude an attribute for the current mode; None
        # means the mode lists no attributes at all.
        excluded = {
            'Train': ('RESULT', 'TEST'),
            'Test': ('TRAIN', 'VAL'),
        }.get(self.mode)
        for attr in dir(self):
            if excluded is None or attr.startswith("__"):
                continue
            if callable(getattr(self, attr)):
                continue
            if attr.split('_')[-1] != 'DIR':
                continue
            if any(word in attr for word in excluded):
                continue
            message += add_message(self, attr)
        left_bar, right_bar = cal_equal(5)
        message += '=' * left_bar + ' End ' + '=' * right_bar
        return message
Exemplo n.º 2
0
def print_test_info(cfg, step, metrics):
    """Show testing progress and log the final metrics.

    Inputs:
        cfg: options object; reads cfg.opts.save_test_log,
             cfg.opts.test_label and cfg.CHECKPOINT_DIR
        step: a list includes --> [per_step, total_step]
        metrics: the current batch testing metrics, passed to prepare_metrics
    On the first step (per_step == 0) a "Start Testing" banner is printed and,
    if logging is enabled, written to Test_log.txt with mode 'w'. Every call
    invokes progress_bar(..., display=False). On the last step, when
    cfg.opts.test_label is not the string 'None', the text built by
    prepare_metrics is appended to the log (if logging is enabled).
    """
    def write_log(text, mode):
        # The log path is only built when logging is actually requested.
        wrote_txt_file(os.path.join(cfg.CHECKPOINT_DIR, 'Test_log.txt'),
                       text,
                       mode=mode,
                       show=False)

    if step[0] == 0:
        left_bar, right_bar = cal_equal(15)
        banner = "\n" + "=" * left_bar + " Start Testing " + "=" * right_bar
        print(banner)
        if cfg.opts.save_test_log:
            write_log(banner, 'w')
    progress_bar(step[0], step[1], display=False)

    # Metrics are only reported when labels exist and this is the last step.
    if cfg.opts.test_label == 'None' or step[0] + 1 < step[1]:
        return
    report, message = prepare_metrics("", "", metrics)
    # NOTE(review): the metrics text is only written to the log here; nothing
    # is printed to the screen unless prepare_metrics does so itself - confirm.
    if cfg.opts.save_test_log:
        write_log(message + "\n" + report, 'a')
Exemplo n.º 3
0
def print_train_info(val_flag, cfg, epoch, step, loss, lr, metrics):
    """Print training progress (epoch header, learning rate, step, loss, metrics).

    Inputs:
        val_flag: whether print the val info according to the epoch(bool)
        cfg: options object; reads cfg.opts.print_epoch, cfg.opts.print_step,
             cfg.opts.save_train_log and cfg.CHECKPOINT_DIR
        epoch: a list includes --> [start_epoch, per_epoch, total_epoch]
        step: a list includes --> [per_step, total_step]
        loss: a float includes --> train loss value
        lr: the current learning rate
        metrics: the current batch training metrics (indexed by name;
                 values are printed as percentages)
    Returns:
        val_flag, set to True whenever an epoch header is emitted on the first
        step of that epoch, otherwise returned unchanged.
    Anything produced is printed and, if cfg.opts.save_train_log is set,
    written to Train_Log.txt (mode 'w' at epoch 1 step 1, else 'a').
    """
    cur_epoch = epoch[1]
    cur_step = step[0]
    first_step = cur_step - 1 == 0
    parts = []

    # '==== Start Training ====' or '==== Continue Training ===='
    if cur_epoch == 1 and first_step:
        left_bar, right_bar = cal_equal(16)
        parts.append("\n" + "=" * left_bar + " Start Training " +
                     "=" * right_bar)
    elif cur_epoch == epoch[0] and cur_epoch != 1 and first_step:
        left_bar, right_bar = cal_equal(19)
        parts.append("\n" + "=" * left_bar + " Continue Training " +
                     "=" * right_bar)

    # '---- Epoch [1/10] ----' followed by the learning rate
    on_print_epoch = (cur_epoch % cfg.opts.print_epoch == 0
                      or cur_epoch == epoch[0] or cur_epoch == epoch[2])
    if on_print_epoch and first_step:
        val_flag = True
        title = " Epoch [{}/{}] ".format(cur_epoch, epoch[2])
        left_bar, right_bar = cal_equal(len(title))
        parts.append("\n" + "-" * left_bar + title + "-" * right_bar)
        parts.append("\n>>> Learning rate {:.7f}\n".format(lr))

    # time, per_step, loss and metrics
    if val_flag and (first_step or cur_step == step[1]
                     or cur_step % cfg.opts.print_step == 0):
        stamp = time.strftime("%m-%d %H:%M:%S", time.localtime())
        parts.append("{}  Step:[{}/{}]".format(stamp, cur_step, step[1]))
        # Pad so the loss column lines up regardless of the step's digit count.
        pad = len(str(step[1])) - len(str(cur_step))
        parts.append(" " * pad + "  Loss:{:.4f}  ".format(loss))
        parts.extend("{}:{:.3f}%  ".format(name, metrics[name] * 100)
                     for name in metrics)

    message = "".join(parts)
    if message:
        print(message)
        if cfg.opts.save_train_log:
            log_mode = 'w' if cur_epoch == 1 and cur_step == 1 else 'a'
            wrote_txt_file(os.path.join(cfg.CHECKPOINT_DIR, 'Train_Log.txt'),
                           message,
                           mode=log_mode,
                           show=False)
    return val_flag
Exemplo n.º 4
0
 def param_message(self, args):
     """Build a framed text listing of the parsed option values.

     Inputs:
         args: object whose vars() are listed, sorted by name; any option
               whose name contains 'checkpoint' is omitted
     Each row is "NAME : value". The name is left-padded so its " :" ends at
     column round(TOTAL_LENGTH / 2) when short enough. Rows that leave at
     least 6 spare columns are framed as ">>> ... <<<" and padded out to
     TOTAL_LENGTH characters.
     Returns:
         the message string (it is not printed here).
     """
     head_l, head_r = cal_equal(len(self.mode) + 10)
     pieces = ['\n' + '=' * head_l + ' ' + self.mode + ' Options ' +
               '=' * head_r + '\n']
     column = round(TOTAL_LENGTH / 2)
     for key, value in sorted(vars(args).items()):
         if 'checkpoint' in key:
             continue
         name, text = str(key), str(value)
         indent = column - len(name + " :")
         left = " " * indent
         right = " " + text + "\n"
         if TOTAL_LENGTH >= len(left + name + " :" + right) + 6:
             left = ">>>" + " " * (indent - 3)
             gap = round(TOTAL_LENGTH - 3 - len(left + name + " : " + text))
             right = " " + text + " " * gap + "<<<\n"
         pieces.append(left + name + " :" + right)
     tail_l, tail_r = cal_equal(5)
     pieces.append('=' * tail_l + ' End ' + '=' * tail_r)
     return ''.join(pieces)
Exemplo n.º 5
0
 def _display_network(self, verbose=False):
     """Print a summary of the initialized network.

     Prints a banner, the total parameter count (sum of numel() over
     self.network.parameters(), in millions) and the init type. In 'Train'
     mode it also prints the learning-rate scheduler, then a closing banner.
     Inputs:
         verbose: also print the network object itself (bool)
     """
     left_bar, right_bar = cal_equal(22)
     print("\n" + "=" * left_bar + " Networks Initialized " +
           "=" * right_bar)
     total_params = sum(param.numel() for param in self.network.parameters())
     if verbose:
         print(self.network)
     tag = self.opts.net_name
     print('>>> [%s] Total size of parameters : %.3f M' %
           (tag, total_params / 1e6))
     print('>>> [%s] Weights initialize with : %s ' %
           (tag, self.opts.init_type))
     if self.cfg.mode == 'Train':
         print('>>> [%s] Learning rate scheduler : %s ' %
               (tag, self.opts.lr_scheduler))
     left_bar, right_bar = cal_equal(6)
     print('%s Done %s' % ('=' * left_bar, '=' * right_bar))
     print('>>> [%s] was created ...' % type(self).__name__)
Exemplo n.º 6
0
 def _display_network(self, verbose=False):
     """Announce that the network was initialized.

     Inputs:
         verbose: also call self.network.summary() (bool)
     In 'Train' mode the configured learning-rate scheduler is printed too.
     """
     left_bar, right_bar = cal_equal(22)
     banner = "=" * left_bar + " Networks Initialized " + "=" * right_bar
     print("\n" + banner)
     if verbose:
         self.network.summary()
     if self.cfg.mode == 'Train':
         scheduler_line = ('>>> [%s] Learning rate scheduler : %s ' %
                           (self.opts.net_name, self.opts.lr_scheduler))
         print(scheduler_line)
     print('>>> [%s] was created ...' % type(self).__name__)
Exemplo n.º 7
0
    def load_data(self, mode, shuffle=False):
        """Load the train or val or test dataset.

        Inputs:
            mode: 'Train' or 'Val' select those splits; any other value
                  selects the test split
            shuffle: shuffle the samples before loading (bool)
        When the split's label option is the string 'None', every entry of
        the data directory is loaded with a placeholder label of 0, in sorted
        name order. Otherwise samples come from self._open_data_file (assumed
        to return an indexable, len()-able sequence - confirm). In 'Test' mode
        at most self.opts.num_test samples are loaded. Each sample is handed to
        self._add_to_database(index, sample, data_dir).
        """
        if mode == 'Train':
            label_name, label_dir, data_dir = (
                self.opts.train_label, self.cfg.TRAIN_LABEL_DIR,
                self.cfg.TRAIN_DATA_DIR)
        elif mode == 'Val':
            label_name, label_dir, data_dir = (
                self.opts.val_label, self.cfg.VAL_LABEL_DIR,
                self.cfg.VAL_DATA_DIR)
        else:
            label_name, label_dir, data_dir = (
                self.opts.test_label, self.cfg.TEST_LABEL_DIR,
                self.cfg.TEST_DATA_DIR)

        if label_name == 'None':
            # os.listdir returns entries in arbitrary order; sort them so the
            # load order (and anything keyed on the index) is reproducible.
            label_data = [[name, 0] for name in sorted(os.listdir(data_dir))]
        else:
            label_data = self._open_data_file(label_name, label_dir)
        if shuffle:
            random.shuffle(label_data)

        # Number of samples to load, computed once instead of per iteration.
        # A num_test <= 0 loads nothing, as before.
        if mode == 'Test':
            length = min(self.opts.num_test, len(label_data))
        else:
            length = len(label_data)
        for index, data_set in enumerate(label_data):
            if index >= length:
                break
            progress_bar(index, length, "Loading {} dataset".format(mode))
            self._add_to_database(index, data_set, data_dir)

        equal_left, equal_right = cal_equal(6)
        print('\n%s Done %s' % ('=' * equal_left, '=' * equal_right))
Exemplo n.º 8
0
def print_test_info(cfg, step, loss=100.0, metric=0.0):
    """Show testing progress and, on the last step, the averaged results.

    Inputs:
        cfg: options object; reads cfg.opts.save_test_log,
             cfg.opts.test_label and cfg.CHECKPOINT_DIR
        step: a list includes --> [per_step, total_step]
        loss: accumulated test loss; divided by total_step when reported
        metric: accumulated test metric; divided by total_step and shown as
                a percentage
    On the first step (per_step == 0) a "Start Testing" banner is printed and,
    if logging is enabled, written to Test_log.txt with mode 'w'. On the last
    step, when cfg.opts.test_label is not the string 'None', the averaged loss
    and accuracy are printed and appended to the log.
    """
    log_name = 'Test_log.txt'
    if step[0] == 0:
        left_bar, right_bar = cal_equal(15)
        banner = "\n" + "=" * left_bar + " Start Testing " + "=" * right_bar
        print(banner)
        if cfg.opts.save_test_log:
            wrote_txt_file(os.path.join(cfg.CHECKPOINT_DIR, log_name),
                           banner, mode='w', show=False)
    progress_bar(step[0], step[1], display=False)

    # Results are only reported when labels exist and this is the last step.
    if cfg.opts.test_label == 'None' or step[0] + 1 < step[1]:
        return
    total = step[1]
    summary = "\n>>> Loss:{:.4f}  ACC:{:.3f}%  ".format(
        loss / total, metric / total * 100)
    print(summary)
    if cfg.opts.save_test_log:
        wrote_txt_file(os.path.join(cfg.CHECKPOINT_DIR, log_name),
                       summary, mode='a', show=False)