Exemplo n.º 1
0
def main():
    """Run a tracker over a dataset and log summary metrics.

    Parses options from the module-level ``parser`` and loads the config via
    ``load_config``. Only ``--arch Custom`` is accepted; anything else calls
    ``parser.error``. The model optionally loads weights from ``--resume``, is
    put in eval mode and moved to CUDA when available.

    Each video of the dataset is then tracked:

    * ``track_vos`` (mask output) when the dataset is DAVIS2016, DAVIS2017 or
      ytb_vos AND ``--mask`` is set;
    * ``track_vot`` otherwise.

    When ``--video`` is non-empty, only that dataset key is processed.

    Logs per-threshold mIoU (VOS, using the module-level ``thrs``) or the total
    lost count (VOT), then the mean speed in FPS. If no video was processed
    (e.g. ``--video`` names a key not in the dataset), a warning is logged and
    the function returns instead of failing on empty result lists.

    Side effects: assigns the module globals ``args``, ``logger`` and ``v_id``.
    ``v_id`` is the 1-based index of the current video.
    NOTE(review): presumably read by the tracking functions — confirm.
    """
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # setup dataset
    dataset = load_dataset(args.dataset)

    # VOS (mask output) only for segmentation datasets, and only with --mask.
    vos_enable = bool(args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask)

    # Optional tracker hyper-parameters from the config; loop-invariant.
    hp = cfg.get('hp')

    total_lost = 0  # VOT
    iou_lists = []  # VOS
    speed_list = []

    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            iou_list, speed = track_vos(model, dataset[video], hp, args.mask, args.refine,
                                        args.dataset in ['DAVIS2017', 'ytb_vos'], device=device)
            iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model, dataset[video], hp, args.mask, args.refine,
                                    device=device)
            total_lost += lost
        speed_list.append(speed)

    # Guard: np.concatenate([]) raises and np.mean([]) yields nan when nothing ran.
    if not speed_list:
        logger.warning('No video was processed (dataset: {}, video filter: {!r})'.format(
            args.dataset, args.video))
        return

    # report final result
    if vos_enable:
        for thr, iou in zip(thrs, np.mean(np.concatenate(iou_lists), axis=0)):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
Exemplo n.º 2
0
"""

import os
#import importlib
import sys
#importlib.reload(sys)
#sys.setdefaultencoding('utf8')
# Python 2 only: reload(sys) is the usual trick to regain access to
# sys.setdefaultencoding, used here to force the default encoding to UTF-8.
if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding("utf-8")
# Make modules in the parent directory and in ../utils importable. Note these
# relative entries resolve against the current working directory, not this file.
sys.path.append('..')
sys.path.append('../utils')
# NOTE: duplicate of the earlier ``import os``; harmless but redundant.
import os
# Silence TensorFlow's C++ log output. This must be set before TensorFlow is
# imported — presumably by one of the project modules imported below (TODO confirm).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import log_helper
# Configure the 'brc' logger before the remaining project modules are imported,
# presumably so their loggers pick up this configuration.
log_helper.init_log('brc')

import pickle
import argparse
import logging
from dataset import BRCDataset
from vocab import Vocab
from rc_model import RCModel
from model_helper_v1 import ModelHelper
import gensim

logger = logging.getLogger("brc")

def str2bool(v):
    """Interpret a yes/no style string as a boolean.

    Case-insensitive: 'yes', 'true', 't', 'y' and '1' return True.

    NOTE(review): only the truthy branch is visible in this snippet, and the
    definition appears truncated. As shown, any other string falls through
    and the function returns None — confirm the false/invalid handling.
    """
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
Exemplo n.º 3
0
#encoding: utf-8

import logging
from log_helper import init_log
init_log('brc')
logger = logging.getLogger('brc')
import pickle
import torch
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import os
#logger = logging.getLogger('global')

class BRCDataLoader(torch.utils.data.DataLoader):#torch.utils.data.DataLoader这个似乎是一个经常被使用的类,有必要细看一哈
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, pin_memory=False, drop_last=False):
        super(BRCDataLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
                                        num_workers, self._collate_fn, pin_memory, drop_last)
    def _collate_fn(self, batch):
        '''
        Args:
            batch: list of dict like
        one_piece = { 'question_token_ids': sample['question_token_ids'],
                      'question_length': len(sample['question_token_ids']),
                      'start_id': sample['answer_spans'][0][0] if 'answer_spans' in sample else None,
                      'end_id': sample['answer_spans'][0][1] if 'answer_spans' in sample else None,
                      'answer_passage_id': sample['answer_passages'][0] if 'answer_passages' in sample else None, 
                      'passage_token_ids': [[],[],[]],
                      'passage_length': [0,0,0]}
Exemplo n.º 4
0
def main():
    """Build a tracking model and data loaders from CLI options, then train.

    Steps:

    1. Parses options from the module-level ``parser``.
    2. Sets up the 'global' logger (plus an optional file handler), then logs
       the environment, arguments and config.
    3. Creates a TensorBoard ``SummaryWriter`` (or ``Dummy`` when no
       ``--log_dir`` is given).
    4. Builds the data loaders and the model selected by ``--arch``
       ('Custom' or 'Custom_Sky'). Any other value is reported through
       ``parser.error``, which exits with a non-zero status.
    5. Optionally loads pretrained weights and wraps the model in
       ``DataParallel`` across all visible GPUs.
    6. Builds the optimizer and LR scheduler and optionally resumes from a
       checkpoint. Resuming overwrites ``args.start_epoch`` and ``best_acc``.
    7. Calls ``train``.

    Side effects: assigns the module globals ``args``, ``best_acc`` (only when
    resuming), ``tb_writer`` and ``logger``.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # Dummy presumably acts as a no-op writer when TensorBoard logging is off.
    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        model = Custom(anchors=cfg['anchors'])
    elif args.arch == 'Custom_Sky':
        model = Custom_Sky(anchors=cfg['anchors'])
    else:
        # Unknown architecture: report a usage error and exit with a non-zero
        # status, rather than exiting silently with status 0.
        parser.error('invalid architecture: {}'.format(args.arch))
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        logger.info('resume from {}'.format(args.resume))
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, _arch = restore_from(
            model, optimizer, args.resume)
        # Re-wrap, since restore_from returns a (possibly new) model object.
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Exemplo n.º 5
0
def init_logs():
    """Call ``log.init_log`` once for each logger name below, in order."""
    for logger_name in ('main', 'database', 'logging', 'sql', 'gui'):
        log.init_log(logger_name)