def run(self):
    utils.mkdir_if_not_exist(self.outp_path)
    if not self.trimmo_parameters:
        self.trimmo_parameters = self.default_trimmo_parameters
    # Earlier variant with the MAXINFO/CROP/SLIDINGWINDOW steps, kept for reference:
    # cmd = 'java -jar {tool_path} PE -basein {fastq_file} -baseout {outp_full_path} ' \
    #       'ILLUMINACLIP:{adapter_full_path}:2:30:10 MAXINFO:{maxinfo} ' \
    #       'CROP:{crop} LEADING:{leading} ' \
    #       'TRAILING:{trailing} HEADCROP:{headcrop} ' \
    #       'SLIDINGWINDOW:{slidingwindow}'.format(tool_path=self.tool_path,
    #                                              outp_full_path=self.outp_full_path,
    #                                              adapter_full_path=self.adapter_full_path,
    #                                              fastq_file=self.inp_data[0],
    #                                              **self.trimmo_parameters)
    cmd = 'java -jar {tool_path} PE -basein {fastq_file} -baseout {outp_full_path} ' \
          'ILLUMINACLIP:{adapter_full_path}:2:30:10 ' \
          'LEADING:{leading} ' \
          'TRAILING:{trailing} ' \
          'HEADCROP:{headcrop}'.format(tool_path=self.tool_path,
                                       outp_full_path=self.outp_full_path,
                                       adapter_full_path=self.adapter_full_path,
                                       fastq_file=self.inp_data[0],
                                       **self.trimmo_parameters)

    call(cmd, shell=True)
    with self.out_trimmo().open('w') as outfile:
        res = [trimmo_res for trimmo_res in glob.glob(self.outp_path + '*')]
        json.dump(res, outfile)
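Every snippet on this page leans on a small mkdir_if_not_exist helper that is never shown. A minimal sketch consistent with the call sites here (most pass a single path; Example #9 passes a list), not the verbatim original:

import os

def mkdir_if_not_exist(paths):
    # Accept a single path or a list of paths (Example #9 below passes a list).
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        if not os.path.exists(path):
            os.makedirs(path)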
Example #2
def save(self, save_path):
    full_path = save_path + '/policy/'
    mkdir_if_not_exist(full_path)
    torch.save(self.qnet_target.state_dict(), full_path + 'model_q.pth')
    torch.save(self.optimizer_q.state_dict(), full_path + 'optim_q.pth')
    torch.save(self.hnet_target.state_dict(), full_path + 'model_h.pth')
    torch.save(self.optimizer_h.state_dict(), full_path + 'optim_h.pth')
Example #3
def save(self, save_path):
    full_path = save_path + '/policy/'
    mkdir_if_not_exist(full_path)
    np.save(full_path + 'qtable.npy', self.qvalues)
    if self.intent:
        # The dicts are stored as 0-d object arrays; loading them back
        # requires np.load(..., allow_pickle=True).
        np.savez(full_path + 'htable.npz', dict(self.hvalues))
        np.savez(full_path + 'intent.npz', dict(self.intention))
Example #4
def resize_images(input_dir, output_dir, new_height, new_width):
    files = glob.glob(os.path.join(input_dir, '*.jpg'))
    utils.mkdir_if_not_exist(output_dir)
    for fi, full_file in enumerate(files):
        print('Resizing image {} of {}'.format(fi + 1, len(files)))
        im = Image.open(full_file)
        im_resized = im.resize((new_width, new_height), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
        out_full_file = os.path.join(output_dir, os.path.basename(full_file))
        im_resized.save(out_full_file, 'JPEG', quality=90)
Example #5
def load(self, load_path, env=None, **kwargs):
    self.policy.load(load_path)
    full_path = load_path + '/params/'
    mkdir_if_not_exist(full_path)
    with open(full_path + 'params.pkl', 'rb') as f:
        obj = pickle.load(f)
        # restore every pickled attribute onto the instance
        for key, item in obj.items():
            self.__dict__[key] = item
Example #6
def run(self):
    utils.mkdir_if_not_exist(self.outp_path)
    cmd = 'java -jar {tool_path} PE -basein {fastq_file} -baseout {outp_full_path} ' \
          'ILLUMINACLIP:{adapter_full_path}:2:30:10 LEADING:3 TRAILING:3 SLIDINGWINDOW:4:32 ' \
          'MINLEN:36'.format(tool_path=self.tool_path,
                             outp_full_path=self.outp_full_path,
                             adapter_full_path=self.adapter_full_path,
                             fastq_file=self.inp_data[0])
    # logging.log.info("COMMAND TO EXECUTE: " + cmd)
    call(cmd, shell=True)
Example #7
def run(self):
    utils.mkdir_if_not_exist(self.outp_path)

    fastqs = [file for file in self.in_data()]
    cmd = '{tool_path} -o {out_path} {inp_1} {inp_2} {inp_3} {inp_4}'.format(
        tool_path=self.tool_path,
        out_path=self.outp_path,
        inp_1=fastqs[0],
        inp_2=fastqs[1],
        inp_3=fastqs[2],
        inp_4=fastqs[3])
    call(cmd, shell=True)
Example #8
def apply_selective_search(manga109_path,
                           target_path,
                           threshold_pixels=2000,
                           threshold_ratio=1.8):
    manga_dirs = os.listdir(os.path.join(manga109_path, 'images'))
    for single_manga in manga_dirs:
        if single_manga.startswith('.'):
            continue
        single_manga_path = os.path.join(manga109_path, 'images', single_manga)
        # skip this title if another process already produced its output,
        # so several jobs can run concurrently
        if os.path.exists(os.path.join(target_path, single_manga)):
            continue
        mkdir_if_not_exist(os.path.join(target_path, single_manga))
        pages = os.listdir(single_manga_path)
        for page in pages:
            if page.startswith('.'):
                continue

            page_num = page.split('.')[0]

            # load the page image
            img = cv2.imread(os.path.join(single_manga_path, page))

            # perform selective search
            img_lbl, regions = selectivesearch.selective_search(img,
                                                                scale=500,
                                                                sigma=0.9,
                                                                min_size=10)

            candidates = set()
            for r in regions:
                # excluding same rectangle (with different segments)
                if r['rect'] in candidates:
                    continue
                # exclude regions smaller than threshold_pixels
                if r['size'] < threshold_pixels:
                    continue
                # exclude distorted rects (extreme aspect ratio)
                x, y, w, h = r['rect']
                if w / (h + 0.001) > threshold_ratio or h / (
                        w + 0.001) > threshold_ratio:
                    continue
                candidates.add(r['rect'])

            # save the cropped sub-region files
            for i, (x, y, w, h) in enumerate(candidates):
                cv2.imwrite(
                    os.path.join(target_path, single_manga,
                                 page_num + '_%04d.jpg' % i),
                    img[y:y + h, x:x + w])
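A hypothetical driver for the function above; the paths are placeholders, not from the original source:

# Hypothetical invocation; point the paths at a local Manga109 copy.
apply_selective_search(manga109_path='/data/Manga109',
                       target_path='/data/Manga109_regions',
                       threshold_pixels=2000,
                       threshold_ratio=1.8)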
Example #9
def process_config(json_file):
    """
    Parse a JSON config file.
    :param json_file: config file path
    :return: config object
    """
    config, _ = get_config_from_json(json_file)
    config.tb_dir = os.path.join("experiments", config.exp_name, "logs/")  # logs
    config.cp_dir = os.path.join("experiments", config.exp_name, "checkpoints/")  # checkpoints
    config.img_dir = os.path.join("experiments", config.exp_name, "images/")  # images

    mkdir_if_not_exist([config.tb_dir, config.cp_dir, config.img_dir])  # create the folders
    return config
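A hypothetical call, assuming the JSON file carries at least an exp_name field:

# config.json (hypothetical contents): {"exp_name": "baseline"}
config = process_config('config.json')
print(config.tb_dir)  # experiments/baseline/logs/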
Example #10
def run(self):
    sam_file = utils.deserialize(self.in_data().path, load)[0]
    utils.mkdir_if_not_exist(self.outp_path)
    with utils.cd(self.working_dir):
        cmds = ['{tool} view -bS {sam_file} > res.bam'.format(tool=self.tool_path, sam_file=sam_file),
                '{tool} sort res.bam -o sorted.bam -O BAM'.format(tool=self.tool_path),
                '{tool} index sorted.bam'.format(tool=self.tool_path),
                # make an index for the reference
                '{tool} faidx {ref_fasta}'.format(tool=self.tool_path, ref_fasta=self.ref_fasta)]
        for cmd in cmds:
            try:
                run(cmd, shell=True, check=True)
            except CalledProcessError as e:
                raise RuntimeError('command failed: %s' % cmd) from e
        utils.serialize(glob('*'), self.out_data().path, dump)
Example #11
def run(self):
    utils.mkdir_if_not_exist(self.outp_path)
    with open(self.in_data().path, 'r') as infile:
        inp_data = json.load(infile)
    fastqs = [file for file in inp_data]
    cmd = '{tool_path} -o {out_path} {inp_1} {inp_2} {inp_3} {inp_4}'.format(
        tool_path=self.tool_path,
        out_path=self.outp_path,
        inp_1=fastqs[0],
        inp_2=fastqs[1],
        inp_3=fastqs[2],
        inp_4=fastqs[3])
    call(cmd, shell=True)
    with self.out_fastq().open('w') as outfile:
        res = [fastq_res for fastq_res in glob.glob(self.outp_path + '*')]
        json.dump(res, outfile)
Example #12
def run(self):
    utils.mkdir_if_not_exist(self.outp_path)
    cmd = 'java -jar {tool_path} PE -basein {fastq_file} -baseout {outp_full_path} ' \
          'ILLUMINACLIP:{adapter_full_path}:2:30:10 LEADING:3 TRAILING:3 SLIDINGWINDOW:4:32 ' \
          'MINLEN:36'.format(tool_path=self.tool_path,
                             outp_full_path=self.outp_full_path,
                             adapter_full_path=self.adapter_full_path,
                             fastq_file=self.inp_data[0])
    # logging.log.info("COMMAND TO EXECUTE: " + cmd)
    call(cmd, shell=True)
    with self.out_trimmo().open('w') as outfile:
        res = [trimmo_res for trimmo_res in glob.glob(self.outp_path + '*')]
        json.dump(res, outfile)
Example #13
def test_data_loading_and_preprocess():
  fig = plt.figure()
  ax = fig.add_subplot(111)

  def _visualize_example(save_path, image, gt_rboxes, mean_subtracted=True):
    ax.clear()
    # convert image
    image_display = vis.convert_image_for_visualization(
        image, mean_subtracted=mean_subtracted)
    # draw image
    ax.imshow(image_display)
    # draw groundtruths
    image_h = image_display.shape[0]
    image_w = image_display.shape[1]
    vis.visualize_rboxes(ax, gt_rboxes,
        edgecolor='yellow', facecolor='none', verbose=False)
    # save plot
    plt.savefig(save_path)

  n_batches = 10
  batch_size = 32

  save_dir = '../vis/example'
  utils.mkdir_if_not_exist(save_dir)

  streams = data.input_stream('../data/synthtext_train.tf')
  pstreams = data.train_preprocess(streams)
  batches = tf.train.shuffle_batch(pstreams, batch_size, capacity=2000, min_after_dequeue=20,
                                   num_threads=1)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize_all_variables is deprecated
    tf.train.start_queue_runners(sess=sess)
    for i in range(n_batches):
      fetches = {'images': batches['image'],
                 'gt_rboxes': batches['rboxes'],
                 'gt_counts': batches['count']}
      sess_outputs = sess.run(fetches)
      for j in range(batch_size):
        save_path = os.path.join(save_dir, '%04d_%d.jpg' % (i, j))
        gt_count = sess_outputs['gt_counts'][j]
        _visualize_example(save_path,
                           sess_outputs['images'][j],
                           sess_outputs['gt_rboxes'][j, :gt_count],
                           mean_subtracted=True)
        print('Visualization saved to %s' % save_path)
Example #14
def run(self):
    with utils.cd(self.working_dir):
        utils.mkdir_if_not_exist(self.outp_path)
        res = utils.deserialize(self.in_data().path, load)
        cmds = ['java -jar {tool} MarkDuplicates INPUT=sorted.bam OUTPUT=dedup_reads.bam '
                'METRICS_FILE=metrics.txt'.format(tool=self.tool_path),
                # add read-group info
                'java -jar {tool} AddOrReplaceReadGroups I=dedup_reads.bam '
                'O=dedup_reads_w_readgroups.bam RGID=4 RGLB=lib1 '
                'RGPL=illumina RGPU=unit1 RGSM=20'.format(tool=self.tool_path)]
        for cmd in cmds:
            print('Command', cmd)
            try:
                run(cmd, shell=True, check=True)
            except CalledProcessError:
                print('command failed:', cmd)
            else:
                utils.serialize(glob('*'), self.out_data().path, dump)
Example #15
def run(self):
    with utils.cd(self.working_dir):
        utils.mkdir_if_not_exist(self.outp_path)
        res = utils.deserialize(self.in_data().path, load)

        # Analyze patterns of covariation in the sequence dataset
        # TODO: add indexing for the dedup_reads_w... file
        cmds = ['java -jar {tool_path} -T BaseRecalibrator '
                '-R {ref_fasta} -I dedup_reads_w_readgroups.bam '
                '-knownSites {dbsnp} '
                '-knownSites {indels} '
                '-o recal_data.table'.format(tool_path=self.tool_path,
                                             ref_fasta=self.ref_fasta,
                                             dbsnp=self.dbsnp,
                                             indels=self.indels),
                # Do a second pass to analyze covariation remaining after recalibration
                'java -jar {tool_path} -T BaseRecalibrator '
                '-R {ref_fasta} -I dedup_reads_w_readgroups.bam '
                '-knownSites {dbsnp} '
                '-knownSites {indels} '
                '-BQSR recal_data.table -o post_recal_data.table'.format(tool_path=self.tool_path,
                                                                         ref_fasta=self.ref_fasta,
                                                                         dbsnp=self.dbsnp,
                                                                         indels=self.indels),
                # Generate before/after plots
                'java -jar {tool_path} -T AnalyzeCovariates -R {ref_fasta} '
                '-L chr20 -before recal_data.table '
                '-after post_recal_data.table '
                '-plots recalibration_plots.pdf'.format(tool_path=self.tool_path,
                                                        ref_fasta=self.ref_fasta),
                # Apply the recalibration to your sequence data
                'java -jar {tool_path} -T PrintReads -R {ref_fasta} '
                '-I dedup_reads_w_readgroups.bam '
                '-BQSR recal_data.table -o recal_reads.bam'.format(tool_path=self.tool_path,
                                                                   ref_fasta=self.ref_fasta)]
        # note: cmds[2:] skips both BaseRecalibrator passes and runs only
        # AnalyzeCovariates and PrintReads
        for cmd in cmds[2:]:
            print('Command', cmd)
            try:
                run(cmd, shell=True, check=True)
            except CalledProcessError:
                print('command failed:', cmd)
                raise
            else:
                utils.serialize(glob('*'), self.out_data().path, dump)
Example #16
def run(self):
    print(self.in_data)
    fastqs = utils.deserialize(self.in_data().path, load)
    utils.mkdir_if_not_exist(self.outp_path)
    args_dict = OrderedDict()
    args_dict['-o'] = self.outp_path
    cmd = '{tool} {args} {inp1} {inp2}'.format(
        tool=self.tool_path,
        args=utils.make_args(args_dict),
        inp1=fastqs[0],
        inp2=fastqs[1])
    print(cmd)
    try:
        run(cmd, shell=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError('command failed: %s' % cmd) from e

    # collect the task's outputs
    utils.serialize(glob(self.outp_path + '*'), self.out_fastq().path, dump)
Example #17
def run(self):
    outp_full_path = self.outp_path + 'trimmo'
    print('Input:', self.in_data().path)
    fastqs = utils.deserialize(self.in_data().path, load)
    utils.mkdir_if_not_exist(self.outp_path)
    if not self.trimmo_parameters:
        self.trimmo_parameters = self.default_trimmo_args
    self.trimmo_parameters['-basein'] = fastqs[0]
    self.trimmo_parameters['-baseout'] = outp_full_path
    cmd = 'java -jar {tool_path} {args}'.format(
        tool_path=self.tool_path,
        args=utils.make_args(self.trimmo_parameters))
    print('Command', cmd)
    try:
        run(cmd, shell=True, check=True)
    except CalledProcessError:
        print('command failed:', cmd)
    else:
        utils.serialize(glob(self.outp_path + '*'),
                        self.out_trimmo().path, dump)
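utils.make_args is not shown on this page; judging from this call site and Example #16, it flattens an ordered flag-to-value mapping into a command-line fragment. A guessed minimal sketch, not the library's actual code:

def make_args(args_dict):
    # e.g. {'-basein': 'reads.fastq', '-baseout': 'out'} -> '-basein reads.fastq -baseout out'
    return ' '.join('{} {}'.format(flag, value) for flag, value in args_dict.items())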
Example #18
def generate_augment(params):
    """
    Audio augmentation.
    :param params: (file path, audio ID) pair
    :return: None
    """
    file_path, name_id = params

    folder = os.path.join(ROOT_DIR, 'data', 'npy_data')
    mkdir_if_not_exist(folder)

    try:
        saved_path = os.path.join(folder, name_id + '.npy')
        if os.path.exists(saved_path):
            print("[INFO] file %s already exists" % name_id)
            return

        y_o, sr = librosa.load(file_path)
        y, _ = librosa.effects.trim(y_o, top_db=40)  # trim leading/trailing silence

        if not np.any(y):
            print('[Exception] audio %s is empty' % name_id)
            return

        duration = len(y) / sr
        if duration < 4:  # filter out audio shorter than 4 seconds
            print('[INFO] audio %s too short: %0.4f' % (name_id, duration))
            return

        features = get_feature(y, sr)
        if check_error_features(features):
            print('[Exception] audio %s has bad features' % name_id)
            return

        np.save(saved_path, features)  # store the source file's features as .npy
    except Exception as e:
        print('[Exception] %s' % e)
        return

    print('[INFO] audio ID ' + name_id)
    return
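A hypothetical driver that fans generate_augment out over a process pool; the paths and IDs are placeholders:

from multiprocessing import Pool

if __name__ == '__main__':
    # Hypothetical (file_path, audio ID) pairs.
    params_list = [('/data/audio/0001.mp3', '0001'),
                   ('/data/audio/0002.mp3', '0002')]
    with Pool(processes=4) as pool:
        pool.map(generate_augment, params_list)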
Example #19
def save(self, save_path, **kwargs):
    mkdir_if_not_exist(save_path)
    self.policy.save(save_path)
    if self.replay_buffer is not None:
        self.replay_buffer.save(save_path)

    # drop attributes that should not be pickled
    excluded = self.excluded_params()
    to_save = self.__dict__.copy()
    for key in excluded:
        if key in to_save:
            del to_save[key]

    full_path = save_path + '/params/'
    mkdir_if_not_exist(full_path)
    with open(full_path + 'params.pkl', 'wb') as f:
        pickle.dump(to_save, f)
Example #20
def process_config(json_file):
    """
    Parse a JSON config file.
    :param json_file: config file path
    :return: config object
    """
    config, _ = get_config_from_json(json_file)

    exp_dir = os.path.join(ROOT_DIR, "experiments")
    mkdir_if_not_exist(exp_dir)  # create the folder

    exp_name = str(config.exp_name)
    print('[INFO] project name: %s' % exp_name)

    config.tb_dir = os.path.join(
        exp_dir, exp_name,
        "logs_%s" % time_2_readable(time.time(), fs='%Y%m%d%H%M%S'))  # logs
    config.cp_dir = os.path.join(exp_dir, exp_name, "checkpoints")  # checkpoints
    config.img_dir = os.path.join(exp_dir, exp_name, "images")  # images

    mkdir_if_not_exist(config.tb_dir)  # create the folders
    mkdir_if_not_exist(config.cp_dir)
    mkdir_if_not_exist(config.img_dir)
    return config
Example #21
def __init__(self, output_path, x_images, y_images):
    self.x_images = x_images
    self.y_images = y_images
    self.output_path = output_path
    utils.mkdir_if_not_exist(self.output_path)
Example #22
def main():
    np.random.seed(0)
    utils.mkdir_if_not_exist(income_const['proc_dir'])
    preprocess_train_val()
    preprocess_test()
Example #23
def train_model(model_const, datasets, exp_const):
    torch.manual_seed(exp_const.seed)
    np.random.seed(exp_const.seed)

    utils.mkdir_if_not_exist(exp_const.exp_dir)
    utils.mkdir_if_not_exist(exp_const.log_dir)
    utils.mkdir_if_not_exist(exp_const.model_dir)

    print('Create tensorboard writer ...')
    tb_writer = SummaryWriter(log_dir=exp_const.log_dir)

    print('Creating model ...')
    model = IncomeClassifier(model_const)
    print(model)

    print('Creating dataloaders ...')
    dataloaders = {}
    dataloaders['train'] = DataLoader(datasets['train'],
                                      batch_size=exp_const.batch_size,
                                      shuffle=True,
                                      num_workers=exp_const.num_workers)
    dataloaders['val'] = DataLoader(datasets['val'],
                                    batch_size=exp_const.batch_size,
                                    shuffle=False,
                                    num_workers=exp_const.num_workers)

    print('Creating optimizer ...')
    opt = optim.Adam(model.parameters(),
                     lr=exp_const.lr,
                     weight_decay=exp_const.weight_decay)

    if exp_const.loss == 'cross_entropy':
        print('Cross Entropy Loss selected for training')
        criterion = nn.CrossEntropyLoss()
    elif exp_const.loss == 'focal':
        print('Focal Loss selected for training')
        criterion = FocalLoss(exp_const.gamma)
    else:
        raise ValueError('Requested loss not implemented: %s' % exp_const.loss)

    best_val_acc = 0
    step = 0
    for epoch in range(exp_const.num_epochs):
        for it, data in enumerate(dataloaders['train']):
            model.train()

            logits, probs = model(data['feat'])

            if exp_const.loss == 'cross_entropy':
                loss = criterion(logits, data['label'].long())
            elif exp_const.loss == 'focal':
                loss = criterion(probs, data['label'])
            else:
                raise ValueError('Requested loss not implemented: %s' % exp_const.loss)

            opt.zero_grad()
            loss.backward()
            opt.step()

            if step % 20 == 0:
                to_log = {
                    'Epoch': epoch,
                    'Iter': it,
                    'Step': step,
                    'Loss': round(loss.item(), 4),
                }
                log_str = '[train] '
                for k, v in to_log.items():
                    log_str += f'{k}: {v} | '

                print(log_str)
                tb_writer.add_scalar('Loss/TrainBatch', to_log['Loss'], step)

            if step % 100 == 0:
                print('-' * 100)
                print('Evaluation')
                with torch.no_grad():
                    train_loss, train_acc = validation(model,
                                                       dataloaders['train'],
                                                       exp_const, epoch, it,
                                                       step, 'train')
                    val_loss, val_acc = validation(model, dataloaders['val'],
                                                   exp_const, epoch, it, step,
                                                   'val')
                print('-' * 100)

                tb_writer.add_scalar('Loss/Train', train_loss, step)
                tb_writer.add_scalar('Loss/Val', val_loss, step)
                tb_writer.add_scalar('Accuracy/Train', train_acc, step)
                tb_writer.add_scalar('Accuracy/Val', val_acc, step)

                if val_acc > best_val_acc:
                    to_save = {
                        'State': model.state_dict(),
                        'Accuracy': {
                            'val': val_acc,
                            'train': train_acc
                        },
                        'Iter': it,
                        'Step': step,
                        'Epoch': epoch
                    }

                    model_path = os.path.join(exp_const.model_dir,
                                              'best_model')
                    torch.save(to_save, model_path)

                    best_val_acc = val_acc

            step += 1

    tb_writer.close()
Example #24
def save(self, save_path):
    full_path = save_path + '/buffer/'
    mkdir_if_not_exist(full_path)
    with open(full_path + 'buffer.pkl', 'wb') as f:
        pickle.dump(self.__dict__, f)
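A plausible load() counterpart, mirroring the pattern of Example #5; this is an assumption, not shown on this page:

def load(self, load_path):
    full_path = load_path + '/buffer/'
    with open(full_path + 'buffer.pkl', 'rb') as f:
        # restore every pickled attribute onto the instance
        self.__dict__.update(pickle.load(f))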
Example #25
import os
import cv2
import numpy as np
from utils import mkdir_if_not_exist

rootpath = '/usr/guandai/birdnest400+/datasets/VOC2007/VOC2007/'
sourcename = 'JPEGImages_depth'
sourcepath = os.path.join(rootpath, sourcename)
targetpath = os.path.join(rootpath, sourcename + '_3channel')
mkdir_if_not_exist(targetpath)

# replicate the single depth channel into a 3-channel image
for file in os.listdir(sourcepath):
    img = cv2.imread(os.path.join(sourcepath, file), cv2.IMREAD_UNCHANGED)
    size = img.shape
    img3 = np.zeros((size[0], size[1], 3), dtype=np.uint8)
    img3[:, :, 0] = img
    img3[:, :, 1] = img
    img3[:, :, 2] = img
    cv2.imwrite(os.path.join(targetpath, file), img3)
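When the depth image is already 8-bit, the three per-channel assignments collapse to a one-liner; note that cv2.merge keeps the input dtype, while the preallocated array above forces uint8:

img3 = cv2.merge([img, img, img])  # equivalent to the manual channel copies for uint8 input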

Example #26
os.environ['data_root'] = data_root = './dataset/data'
os.environ['im2rec'] = "python /home/wk/anaconda3/envs/gluon/lib/python3.6/site-packages/mxnet/tools/im2rec.py"

resize = (512, 512)
os.environ['resize'] = resize_str = str(resize[0]) + '_' + str(resize[1])
# for imgName in os.listdir(data_root+'/img'):
#     imgPath = data_root + '/img/' + imgName
#     img = Image.open(imgPath)
#     img = img.resize(resize,Image.BILINEAR)
#     print(data_root + '/' + 'img%d_%d'%(resize[0],resize[1]))
#     mkdir_if_not_exist(data_root + '/' + 'img%d_%d'%(resize[0],resize[1]))
#     img.save(data_root + '/'+'img%d_%d'%(resize[0],resize[1]) + '/' + imgName)

# generate lst
mkdir_if_not_exist(data_root + '/rec')
os.system(
    '$im2rec --list --train-ratio 0.9 ${data_root}/rec/img_$resize ${data_root}/img%d_%d'
    % (resize[0], resize[1]))

# modify lst
'''
lst format:
    idx \t header_width \t label_width \t [labels] \t filename
    label_width is the width of each label.
    header_width is the width of the data after idx and before the labels;
    it is usually 2, covering the label_width and label-data fields.
    label: class_idx \t xmin \t ymin \t xmax \t ymax (one group per box)
    e.g. one box (illustrative): 0 \t 2 \t 5 \t 0 \t 0.1 \t 0.2 \t 0.9 \t 0.8 \t img1.jpg
'''

new_lst_content = ''
with open(data_root + '/rec/img_%s_train.lst' % (resize_str)) as f:
Example #27
import wget

from global_constants import income_const
import utils

utils.mkdir_if_not_exist(income_const['download_dir'])
for data_type in income_const['urls']:
    wget.download(income_const['urls'][data_type]['url'],
                  income_const['download_dir'])
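income_const itself lives in global_constants; a hypothetical shape inferred from this loop and Example #22:

# Hypothetical structure; the real URLs and directories are defined in global_constants.
income_const = {
    'download_dir': 'downloads/',
    'proc_dir': 'processed/',
    'urls': {
        'train': {'url': 'https://example.com/adult.data'},
        'test': {'url': 'https://example.com/adult.test'},
    },
}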
Example #28
    orig_size = sess_outputs['orig_size'][j]
    save_path = os.path.join(result_dir, result_fname)
    with open(save_path, 'w') as f:
      lines = []
      for k in range(bboxes.shape[0]):
        bbox_str = list(bboxes[k])
        bbox_str = [str(o) for o in bbox_str]
        bbox_str = ','.join(bbox_str)
        lines.append(bbox_str)
      # remove duplicated lines
      lines = list(set(lines))
      f.write('\r\n'.join(lines))
      #logging.info('Detection results written to {}'.format(save_path))

    # save images and lexicon list for post-processing
    if FLAGS.save_image_and_lexicon:
      sess_outputs['']


if __name__ == '__main__':
  # create logging dir if not existed
  utils.mkdir_if_not_exist(FLAGS.log_dir)
  # set up logging
  log_file_name = FLAGS.log_prefix + time.strftime('%Y%m%d_%H%M%S') + '.log'
  log_file_path = os.path.join(FLAGS.log_dir, log_file_name)
  utils.setup_logger(log_file_path)
  utils.log_flags(FLAGS)
  #utils.log_git_version()
  # run test
  evaluate()
Example #29
def evaluate(dataset_dir, out_img_dir, cuda_flag=True, savetxt='psnr.txt'):

    denoisenet = torch.load(model_path)
    if isinstance(denoisenet, torch.nn.DataParallel):
        denoisenet = denoisenet.module

    if cuda_flag:
        denoisenet.cuda().eval()
    else:
        denoisenet.eval()

    if SAVE_PNG:
        mkdir_if_not_exist(out_img_dir)
        mkdir_if_not_exist(out_img_dir + 'DTMC-HD/')
        mkdir_if_not_exist(out_img_dir + 'DTMC-HD/%d/' % sigma)

    test_img_list = os.listdir(dataset_dir)
    str_format = 'im%d.png'

    total_count = 483 + 7 * 6
    count = 0

    pre = datetime.datetime.now()
    psnr_val = 0
    ssim_val = 0

    for file in test_img_list:
        seq_noisy = dataset_dir + file + '/Gauss_%d/noisy/' % sigma
        num = len(os.listdir(seq_noisy))
        psnr = 0
        ssim = 0

        if SAVE_PNG:
            mkdir_if_not_exist(out_img_dir + 'DTMC-HD/%d/%s/' % (sigma, file))

        noisy_frames = []
        for j in range(1, num + 1):
            noisy_path = seq_noisy + 'im%d.png' % j
            noisy_frames.append(plt.imread(noisy_path))

        ######### pad ########
        h, w, c = noisy_frames[0].shape
        nh = (h // scale + 1) * scale
        noisy_frames = np.array(noisy_frames)
        noisy_frames_padded = np.lib.pad(noisy_frames,
                                         pad_width=((N // 2,
                                                     N // 2), ((nh - h) // 2,
                                                               (nh - h) // 2),
                                                    (0, 0), (0, 0)),
                                         mode='constant')
        noisy_frames_padded = np.transpose(noisy_frames_padded, (0, 3, 1, 2))

        for i in range(num):
            reference_path = dataset_dir + file + '/ori/' + str_format % (i +
                                                                          1)
            reference_frame = plt.imread(reference_path)

            input_frames = noisy_frames_padded[i:i + N]
            input_frames = torch.from_numpy(input_frames).cuda()

            input_frames = input_frames.view(1, input_frames.size(0),
                                             input_frames.size(1),
                                             input_frames.size(2),
                                             input_frames.size(3))

            x_list = denoisenet(input_frames)
            predicted_img = x_list[-1][0, :, (nh - h) // 2:-(nh - h) // 2]

            Img = predicted_img.permute(1, 2, 0).data.cpu().numpy().astype(
                np.float32)

            count += 1

            ######## compare PSNR and SSIM ########
            psnr += compare_psnr(Img, reference_frame, data_range=1.)
            ssim += compare_ssim(Img,
                                 reference_frame,
                                 data_range=1.,
                                 multichannel=True)

            ######## save output images ########
            if SAVE_PNG:
                plt.imsave(
                    out_img_dir + 'DTMC-HD/%d/%s/im%d.png' %
                    (sigma, file, i + 1), np.clip(Img, 0.0, 1.0))

            cur = datetime.datetime.now()
            processing_time = (cur - pre).seconds / count
            print('%.2fs per frame.\t%.2fs left.' %
                  (processing_time, processing_time * (total_count - count)))

        print('video %s, psnr %.4f, ssim %.4f.\n' %
              (file, psnr / num, ssim / num))
        psnr_val += psnr / num
        ssim_val += ssim / num
        # append per-video results to the results file
        with open(savetxt, 'a') as txtfile:
            txtfile.write('video %s, psnr %.4f, ssim %.4f.\n' %
                          (file, psnr / num, ssim / num))

    ave_psnr_val = psnr_val / len(test_img_list)
    ave_ssim_val = ssim_val / len(test_img_list)
    print('PSNR_val: %.4fdB, SSIM_val: %.4f.\n' % (ave_psnr_val, ave_ssim_val))
    with open(savetxt, 'a') as txtfile:
        txtfile.write('Average psnr %.4f, ssim %.4f.\n\n' %
                      (ave_psnr_val, ave_ssim_val))
Example #30
def evaluate():
  with tf.device('/cpu:0'):
    # input data
    streams = data.input_stream(FLAGS.test_dataset)
    pstreams = data.test_preprocess(streams)
    if FLAGS.test_resize_method == 'dynamic':
      # each test image is resized to a different size
      # test batch size must be 1
      assert(FLAGS.test_batch_size == 1)
      batches = tf.train.batch(pstreams,
                               FLAGS.test_batch_size,
                               capacity=1000,
                               num_threads=1,
                               dynamic_pad=True)
    else:
      # resize every image to the same size
      batches = tf.train.batch(pstreams,
                               FLAGS.test_batch_size,
                               capacity=1000,
                               num_threads=1)
    image_size = tf.shape(batches['image'])[1:3]

  fetches = {}
  fetches['images'] = batches['image']
  fetches['image_name'] = batches['image_name']
  fetches['resize_size'] = batches['resize_size']
  fetches['orig_size'] = batches['orig_size']

  # detector
  detector = model.SegLinkDetector()
  all_maps = detector.build_model(batches['image'])

  # decode local predictions
  all_nodes, all_links, all_reg = [], [], []
  for i, maps in enumerate(all_maps):
    cls_maps, lnk_maps, reg_maps = maps
    reg_maps = tf.multiply(reg_maps, data.OFFSET_VARIANCE)

    # segments classification
    cls_prob = tf.nn.softmax(tf.reshape(cls_maps, [-1, 2]))
    cls_pos_prob = cls_prob[:, model.POS_LABEL]
    cls_pos_prob_maps = tf.reshape(cls_pos_prob, tf.shape(cls_maps)[:3])
    # node status is 1 where probability is higher than threshold
    node_labels = tf.cast(tf.greater_equal(cls_pos_prob_maps, FLAGS.node_threshold),
                          tf.int32)

    # link classification
    lnk_prob = tf.nn.softmax(tf.reshape(lnk_maps, [-1, 2]))
    lnk_pos_prob = lnk_prob[:, model.POS_LABEL]
    lnk_shape = tf.shape(lnk_maps)
    lnk_pos_prob_maps = tf.reshape(lnk_pos_prob,
                                   [lnk_shape[0], lnk_shape[1], lnk_shape[2], -1])
    # link status is 1 where probability is higher than threshold
    link_labels = tf.cast(tf.greater_equal(lnk_pos_prob_maps, FLAGS.link_threshold),
                          tf.int32)

    all_nodes.append(node_labels)
    all_links.append(link_labels)
    all_reg.append(reg_maps)

    fetches['link_labels_%d' % i] = link_labels

  # decode segments and links
  segments, group_indices, segment_counts = ops.decode_segments_links(
    image_size, all_nodes, all_links, all_reg,
    anchor_sizes=list(detector.anchor_sizes))
  fetches['segments'] = segments
  fetches['group_indices'] = group_indices
  fetches['segment_counts'] = segment_counts

  # combine segments
  combined_rboxes, combined_counts = ops.combine_segments(
    segments, group_indices, segment_counts)
  fetches['combined_rboxes'] = combined_rboxes
  fetches['combined_counts'] = combined_counts

  sess_config = tf.ConfigProto()
  with tf.Session(config=sess_config) as sess:
    # load model
    model_loader = tf.train.Saver()
    model_loader.restore(sess, FLAGS.test_model)

    batch_size = FLAGS.test_batch_size
    n_batches = int(math.ceil(FLAGS.num_test / batch_size))

    # result directory
    result_dir = os.path.join(FLAGS.log_dir, 'results' + FLAGS.result_suffix)
    utils.mkdir_if_not_exist(result_dir)

    intermediate_result_path = os.path.join(FLAGS.log_dir, 'intermediate.pkl')
    if FLAGS.load_intermediate:
      all_batches = joblib.load(intermediate_result_path)
      logging.info('Intermediate result loaded from {}'.format(intermediate_result_path))
    else:
      # run all batches and store results in a list
      all_batches = []
      with slim.queues.QueueRunners(sess):
        for i in range(n_batches):
          if i % 10 == 0:
            logging.info('Evaluating batch %d/%d' % (i+1, n_batches))
          sess_outputs = sess.run(fetches)
          all_batches.append(sess_outputs)
      if FLAGS.save_intermediate:
        joblib.dump(all_batches, intermediate_result_path, compress=5)
        logging.info('Intermediate result saved to {}'.format(intermediate_result_path))

    # # visualize local rboxes (TODO)
    # if FLAGS.save_vis:
    #   vis_save_prefix = os.path.join(save_dir, 'localpred_batch_%d_' % i)
    #   pred_rboxes_counts = []
    #   for j in range(len(all_maps)):
    #     pred_rboxes_counts.append((sess_outputs['segments_det_%d' % j],
    #                               sess_outputs['segment_counts_det_%d' % j]))
    #   _visualize_layer_det(sess_outputs['images'],
    #                       pred_rboxes_counts,
    #                       vis_save_prefix)

    # # visualize joined rboxes (TODO)
    # if FLAGS.save_vis:
    #   vis_save_prefix = os.path.join(save_dir, 'batch_%d_' % i)
    #   # _visualize_linked_det(sess_outputs, save_prefix)
    #   _visualize_combined_rboxes(sess_outputs, vis_save_prefix)

    if FLAGS.result_format == 'icdar_2015_inc':
      postprocess_and_write_results_ic15(all_batches, result_dir)
    elif FLAGS.result_format == 'icdar_2013':
      postprocess_and_write_results_ic13(all_batches, result_dir)
    else:
      logging.critical('Unknown result format: {}'.format(FLAGS.result_format))
      sys.exit(1)
  
  logging.info('Evaluation done.')