Пример #1
0
def main():
    """Distributed training entry point: set up rank-0 logging, build the
    data loader / model / optimizer, optionally restore weights, then train."""
    rank, world_size = dist_init()
    cfg.merge_from_file(args.cfg)
    if rank == 0:
        # Only rank 0 creates the log dir and file handler to avoid
        # duplicate log files across workers.
        if not os.path.exists(cfg.TRAIN.LOG_DIR):
            os.makedirs(cfg.TRAIN.LOG_DIR)
        init_log('global', logging.INFO)
        if cfg.TRAIN.LOG_DIR:
            add_file_handler('global',
                             os.path.join(cfg.TRAIN.LOG_DIR, 'logs.txt'),
                             logging.INFO)
        logger.info("Version Information: \n{}\n".format(commit()))
        logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    logger.info('dist init done!')
    train_dataloader = build_data_loader()
    model = get_model('BaseSiamModel').cuda().train()
    # NOTE(review): dist_model is re-created at the end of this function;
    # this first wrap appears to exist only so the optimizer can be built
    # from dist_model.module (== model). Confirm whether DistModule() has
    # side effects (e.g. parameter broadcast) that make the double wrap
    # intentional before simplifying.
    dist_model = DistModule(model)
    optimizer, lr_scheduler = build_optimizer_lr(dist_model.module,
                                                 cfg.TRAIN.START_EPOCH)
    if cfg.TRAIN.BACKBONE_PRETRAIN:
        # Backbone-only weights; loaded in place, so the optimizer still
        # references the same parameter tensors.
        logger.info('load backbone from {}.'.format(cfg.TRAIN.BACKBONE_PATH))
        model.backbone = load_pretrain(model.backbone, cfg.TRAIN.BACKBONE_PATH)
        logger.info('load backbone done!')
    if cfg.TRAIN.RESUME:
        # Resume restores model + optimizer state and the starting epoch.
        logger.info('resume from {}'.format(cfg.TRAIN.RESUME_PATH))
        model, optimizer, cfg.TRAIN.START_EPOCH = restore_from(
            model, optimizer, cfg.TRAIN.RESUME_PATH)
        logger.info('resume done!')
    elif cfg.TRAIN.PRETRAIN:
        # Full-model pretrained weights (mutually exclusive with RESUME).
        logger.info('load pretrain from {}.'.format(cfg.TRAIN.PRETRAIN_PATH))
        model = load_pretrain(model, cfg.TRAIN.PRETRAIN_PATH)
        logger.info('load pretrain done')
    # Re-wrap after any weight loading so all ranks share the final state.
    dist_model = DistModule(model)
    train(train_dataloader, dist_model, optimizer, lr_scheduler)
Пример #2
0
def main():
    """Evaluate a tracking model over a dataset, using the VOS (mask)
    protocol for mask-capable datasets and the VOT protocol otherwise.

    Fixes vs. original: the optional ``cfg['hp']`` lookup used the
    ``'hp' in cfg.keys()`` anti-pattern and was re-evaluated for every
    video even though it is loop-invariant; it is now computed once with
    the idiomatic ``'hp' in cfg`` test.
    """
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    device = torch.device('cuda' if (torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset)

    # VOS or VOT?  Mask (VOS) evaluation only for datasets with mask labels.
    if args.dataset in ['DAVIS', 'DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    # Optional tracker hyper-parameters from the config; loop-invariant,
    # so compute it once here instead of per video.
    hp = cfg['hp'] if 'hp' in cfg else None

    total_lost = 0  # VOT
    iou_lists = []  # VOS
    speed_list = []

    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            # Multi-object mode is enabled for DAVIS2017 / ytb_vos.
            iou_list, speed = track_vos(model, dataset[video], hp,
                                 args.mask, args.refine, args.dataset in ['DAVIS2017', 'ytb_vos'], device=device)
            iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model, dataset[video], hp,
                             args.mask, args.refine, device=device)
            total_lost += lost
        speed_list.append(speed)

    # report final result
    if vos_enable:
        # NOTE(review): `thrs` is not defined in this function — presumably a
        # module-level array of segmentation thresholds; confirm.
        for thr, iou in zip(thrs, np.mean(np.concatenate(iou_lists), axis=0)):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
Пример #3
0
def main():
    """Fine-tune a pruned Siamese model: set up logging, load the snapshot
    produced by the pruning stage, prune, then run the training loop."""
    cfg.merge_from_file(args.cfg)
    log_dir = cfg.PRUNING.FINETUNE.LOG_DIR
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    init_log('global', logging.INFO)
    if log_dir:
        add_file_handler('global', os.path.join(log_dir, 'logs.txt'),
                         logging.INFO)
    logger.info(f"Version Information: \n{commit()}\n")
    logger.info(f"config \n{json.dumps(cfg, indent=4)}")

    dataloader = build_data_loader()
    model = PruningSiamModel()
    # Weights come from the snapshot written by the pruning stage.
    logger.info(f'load pretrain from {cfg.PRUNING.FINETUNE.PRETRAIN_PATH}.')
    model = load_pretrain(model, cfg.PRUNING.FINETUNE.PRETRAIN_PATH)
    logger.info('load pretrain done')
    logger.info('begin to pruning the model')
    model = prune_model(model).cuda().train()
    logger.info('pruning finished!')

    optimizer, lr_scheduler = build_optimizer_lr(
        model, cfg.PRUNING.FINETUNE.START_EPOCH)
    if cfg.PRUNING.FINETUNE.RESUME:
        logger.info(f'resume from {cfg.PRUNING.FINETUNE.RESUME_PATH}')
        model, optimizer, cfg.PRUNING.FINETUNE.START_EPOCH = restore_from(
            model, optimizer, cfg.PRUNING.FINETUNE.RESUME_PATH)
        logger.info('resume done!')
    train(dataloader, model, optimizer, lr_scheduler)
Пример #4
0
def main():
    """Single-machine (DataParallel) training entry point: set up logging
    and TensorBoard, build data/model/optimizer, optionally resume, train."""
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()  # parse command-line arguments

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')  # fetch the logger initialised above
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(
        cfg, indent=4)))  # dump the config as JSON, indented 4 spaces

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()  # no-op writer when no log dir is given

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()  # move the model to GPU
    dist_model = torch.nn.DataParallel(
        model, list(range(torch.cuda.device_count()))).cuda()  # multi-GPU training

    # NOTE(review): original comment asked what this does — presumably it
    # re-applies the backbone unfreezing schedule when resuming mid-run;
    # confirm against features.unfix().
    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args,
                                           args.start_epoch)  # build optimizer and LR schedule
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        # Re-wrap after restore so DataParallel sees the restored weights.
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Пример #5
0
def main():
    """Set up logging, data, model and optimizer, then train with
    torch.nn.DataParallel across all visible GPUs."""
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO) # initialise the 'global' logger; logging.INFO is the level

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')  # fetch the logger initialised above
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)  # returns the merged/overridden config object

    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))  # json.dumps serialises the config dict to a string for logging

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()  # no-op writer when no log dir is given

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

    # NOTE(review): presumably re-applies the backbone unfreezing schedule
    # when resuming mid-run — confirm against features.unfix().
    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(model, optimizer, args.resume)
        # Re-wrap after restore so DataParallel sees the restored weights.
        dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)
Пример #6
0
def main():
    """Run RPN validation: set up logging, build the data loaders and
    model (pretrained weights required), then validate.

    Fix vs. original: removed a redundant second
    ``logger = logging.getLogger('global')`` before validation —
    ``getLogger`` returns the same instance for the same name, so the
    reassignment was a no-op.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    # Propagate dataset-derived settings into args for the model.
    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True,
                       opts=args,
                       anchors=train_loader.dataset.anchors)
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)
    else:
        raise Exception("Pretrained weights must be loaded!")

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    logger.info('model prepare done')

    val_avg = AverageMeter()

    validation(val_loader, dist_model, cfg, val_avg)
Пример #7
0
def main():
    """Meta-training entry point: set up logging, load the pretrained
    Siamese model, initialise meta-training, then run meta_train."""
    cfg.merge_from_file(args.cfg)
    log_dir = cfg.META.LOG_DIR
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    init_log("global", logging.INFO)
    if log_dir:
        add_file_handler("global", os.path.join(log_dir, "logs.txt"),
                         logging.INFO)
    logger.info(f"Version Information: \n{commit()}\n")
    logger.info(f"config \n{json.dumps(cfg, indent=4)}")

    model = MetaSiamModel().cuda()
    model = load_pretrain(model, cfg.META.PRETRAIN_PATH)
    # Prepare model internals for meta-training before building the optimizer.
    model.meta_train_init()
    optimizer = build_optimizer(model)
    dataloader = build_dataloader()
    meta_train(dataloader, optimizer, model)
Пример #8
0
def main():
    """Train the gradient-branch Siamese model: set up logging, load
    pretrained weights, optionally resume, freeze the base, then train."""
    cfg.merge_from_file(args.cfg)
    log_dir = cfg.GRAD.LOG_DIR
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    init_log("global", logging.INFO)
    if log_dir:
        add_file_handler("global", os.path.join(log_dir, "logs.txt"),
                         logging.INFO)
    logger.info(f"Version Information: \n{commit()}\n")
    logger.info(f"config \n{json.dumps(cfg, indent=4)}")

    model = get_model('GradSiamModel').cuda()
    model = load_pretrain(model, cfg.GRAD.PRETRAIN_PATH)
    optimizer = build_optimizer(model)
    dataloader = build_dataloader()
    if cfg.GRAD.RESUME:
        logger.info(f'resume from {cfg.GRAD.RESUME_PATH}')
        model, optimizer, cfg.GRAD.START_EPOCH = restore_from(
            model, optimizer, cfg.GRAD.RESUME_PATH)
        logger.info('resume done!')
    # Freeze the parts of the model that should stay fixed during this stage.
    model.freeze_model()
    train(dataloader, optimizer, model)
Пример #9
0
def main():
    """Set up logging and TensorBoard, load the config, and construct the
    train/validation data loaders."""
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info(f"config \n{json.dumps(cfg, indent=4)}")

    # TensorBoard writer, or a no-op stand-in when no log dir is configured.
    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)
Пример #10
0
def main():
    """Hyper-parameter grid search for the published model: iterate over
    penalty_k / window_influence / lr / instance_size combinations per
    video and call tune() for each.

    Fix vs. original: ``build_tracker(model)`` was called twice and the
    first tracker immediately discarded; the tracker is now built once,
    after the model is in eval mode on the target device.
    """
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    params = {'penalty_k': args.penalty_k,
              'window_influence': args.window_influence,
              'lr': args.lr,
              'instance_size': args.search_region}

    # Total size of the search grid, for progress reporting.
    num_search = len(params['penalty_k']) * len(params['window_influence']) * \
        len(params['lr']) * len(params['instance_size'])

    print(params)
    print(num_search)
    cfg.merge_from_file(args.config)

    # NOTE(review): no checkpoint is loaded into ModelPublish here (the
    # load_pretrain call was commented out upstream); args.resume is only
    # used to build the run name below — confirm that is intended.
    model = ModelPublish()
    model.eval()
    model = model.to(device)
    tracker = build_tracker(model)

    default_hp =  {
        "seg_thr": 0.30,
        "penalty_k": 0.04,
        "window_influence": 0.42,
        "lr": 0.25
    }
    p = dict()

    p['network'] = tracker
    p['network_name'] = args.arch+'_'+args.resume.split('/')[-1].split('.')[0]
    p['dataset'] = args.dataset
    p['hp'] = default_hp.copy()
    # NOTE(review): `s` is a view of THIS dict; the loop below rebinds
    # p['hp'] to new dicts, so the prints of `s` always show the defaults.
    s = p['hp'].values()
    print([float(x) for x in s])

    global ims, gt, image_files

    dataset_info = load_dataset(args.dataset)
    videos = list(dataset_info.keys())
    np.random.shuffle(videos)

    for video in videos:
        print(video)
        # An external 'finish.flag' file aborts the search gracefully.
        if isfile('finish.flag'):
            return

        p['video'] = video
        ims = None
        image_files = dataset_info[video]['image_files']
        gt = dataset_info[video]['gt']

        # Randomise the search order so partial runs still cover the space.
        np.random.shuffle(params['penalty_k'])
        np.random.shuffle(params['window_influence'])
        np.random.shuffle(params['lr'])
        for penalty_k in params['penalty_k']:
            for window_influence in params['window_influence']:
                for lr in params['lr']:
                    for instance_size in params['instance_size']:
                        p['hp'] = default_hp.copy()
                        p['hp'].update({'penalty_k':penalty_k,
                                'window_influence':window_influence,
                                'lr':lr,
                                'instance_size': instance_size,
                                })
                        tune(p)
        print([float(x) for x in s])
Пример #11
0

if __name__ == '__main__':
    # Command-line interface for the model-pruning script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg',
                        default='',
                        type=str,
                        help='which config file to use')
    parser.add_argument('--snapshot',
                        default='',
                        type=str,
                        help='which model to pruning')
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    log_dir = cfg.PRUNING.LOG_DIR
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    init_log('global', logging.INFO)
    if log_dir:
        add_file_handler('global', os.path.join(log_dir, 'logs.txt'),
                         logging.INFO)
    logger.info(f"Version Information: \n{commit()}\n")
    logger.info(f"config \n{json.dumps(cfg, indent=4)}")

    model = load_pretrain(PruningSiamModel(), args.snapshot)

    # Dump the pruning masks for inspection before pruning.
    for name, mask in model.mask.items():
        print(name, mask)
    model = prune_model(model)

    # torch.save(model.state_dict(), './snapshot/mobilenetv2_gdp/model_pruning.pth')
Пример #12
0
def main():
    """Train the bag re-identification model: build the data pipeline,
    model and loss terms, then run the epoch loop with periodic
    checkpointing."""
    global xent_criterion, triplet_criterion, ment_criterion

    logger.info("init done")

    # Wipe any previous log directory so each run starts from a clean slate.
    if os.path.exists(cfg.TRAIN.LOG_DIR):
        shutil.rmtree(cfg.TRAIN.LOG_DIR)
    os.makedirs(cfg.TRAIN.LOG_DIR)
    init_log('global', logging.INFO)
    if cfg.TRAIN.LOG_DIR:
        add_file_handler('global', os.path.join(cfg.TRAIN.LOG_DIR, 'logs.txt'),
                         logging.INFO)

    dataset, train_loader, _, _ = build_data_loader()
    model = BagReID_IBN(dataset.num_train_pids, dataset.num_train_mates)
    # Three loss terms: label-smoothed cross-entropy, triplet, and mate loss.
    xent_criterion = CrossEntropyLabelSmooth(dataset.num_train_pids)
    triplet_criterion = TripletLoss(margin=cfg.TRAIN.TRI_MARGIN)
    ment_criterion = CrossEntropyMate(cfg.TRAIN.MATE_LOSS_WEIGHT)
    if cfg.TRAIN.OPTIM == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cfg.SOLVER.LEARNING_RATE,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    else:
        # Any OPTIM value other than "sgd" falls through to Adam.
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=cfg.SOLVER.LEARNING_RATE,
                                     weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    optimizers = [optimizer]
    schedulers = build_lr_schedulers(optimizers)

    if cfg.CUDA:
        model.cuda()
        if torch.cuda.device_count() > 1:
            model = DataParallel(model)

    if cfg.TRAIN.LOG_DIR:
        summary_writer = SummaryWriter(cfg.TRAIN.LOG_DIR)
    else:
        summary_writer = None

    logger.info("model prepare done")
    start_epoch = cfg.TRAIN.START_EPOCH
    # start training
    for epoch in range(start_epoch, cfg.TRAIN.NUM_EPOCHS):
        # NOTE(review): `criterion` is not defined in this function (the
        # globals above are xent/triplet/ment_criterion); unless a
        # module-level `criterion` exists this raises NameError — confirm.
        train(epoch, train_loader, model, criterion, optimizers,
              summary_writer)
        for scheduler in schedulers:
            scheduler.step()
        # skip if not save model
        if cfg.TRAIN.EVAL_STEP > 0 and (epoch + 1) % cfg.TRAIN.EVAL_STEP == 0 \
                or (epoch + 1) == cfg.TRAIN.NUM_EPOCHS:

            # DataParallel wraps the real model in .module; unwrap so the
            # checkpoint is loadable without DataParallel.
            if cfg.CUDA and torch.cuda.device_count() > 1:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint({
                'state_dict': state_dict,
                'epoch': epoch + 1
            },
                            is_best=False,
                            save_dir=cfg.TRAIN.SNAPSHOT_DIR,
                            filename='checkpoint_ep' + str(epoch + 1) +
                            '.pth.tar')
Пример #13
0
def main():
    """Evaluate a tracking model on a dataset, using the VOS (mask)
    protocol for mask-capable datasets and the VOT protocol otherwise."""
    # Parse command-line arguments.
    global args, logger, v_id
    args = parser.parse_args()
    # Load configuration: network architecture, hyper-parameters, etc.
    cfg = load_config(args)
    # Initialise logging and optionally mirror log output to a disk file.
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    # Record the run configuration in the log.
    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    # Build the network architecture.
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))
    # Load the network weights.
    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model = load_pretrain(model, args.resume)
    # Switch to eval mode (fixes dropout / batch-norm behaviour).
    model.eval()
    # Select the hardware device.
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset)

    # These three datasets carry mask annotations — VOS or VOT?
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    total_lost = 0  # VOT
    iou_lists = []  # VOS
    speed_list = []
    # Iterate over every video in the dataset (optionally a single one).
    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue
        # VOS path: segmentation-style tracking.
        if vos_enable:
            # Multi-object mode is enabled for DAVIS2017 / ytb_vos.
            iou_list, speed = track_vos(
                model,
                dataset[video],
                cfg['hp'] if 'hp' in cfg.keys() else None,
                args.mask,
                args.refine,
                args.dataset in ['DAVIS2017', 'ytb_vos'],
                device=device)
            iou_lists.append(iou_list)
        # VOT path: bounding-box tracking with the lost-count protocol.
        else:
            lost, speed = track_vot(model,
                                    dataset[video],
                                    cfg['hp'] if 'hp' in cfg.keys() else None,
                                    args.mask,
                                    args.refine,
                                    device=device)
            total_lost += lost
        speed_list.append(speed)

    # report final result
    if vos_enable:
        # NOTE(review): `thrs` is not defined in this function — presumably
        # a module-level array of segmentation thresholds; confirm.
        for thr, iou in zip(thrs, np.mean(np.concatenate(iou_lists), axis=0)):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(
                thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
Пример #14
0
def main():
    """Train DeepMask/SharpMask: parse args, set up logging, seeds and
    device, build model/data/criterion/optimizer, then run the epoch loop
    with validation and checkpointing every other epoch."""
    global args, device, max_acc, writer

    max_acc = -1
    args = parser.parse_args()
    if args.arch == 'SharpMask':
        trainSm = True
        args.hfreq = 1
        args.gSz = args.iSz
    else:
        trainSm = False

    # Setup experiments results path
    pathsv = 'sharpmask/train' if trainSm else 'deepmask/train'
    args.rundir = join(args.rundir, pathsv)
    try:
        if not isdir(args.rundir):
            makedirs(args.rundir)
    except OSError as err:
        # Non-fatal: continue even if the run dir could not be created.
        print(err)

    # Setup logger
    init_log('global', logging.INFO)
    add_file_handler('global', join(args.rundir, 'train.log'), logging.INFO)
    logger = logging.getLogger('global')
    logger.info('running in directory %s' % args.rundir)
    logger.info(args)
    writer = SummaryWriter(log_dir=join(args.rundir, 'tb'))

    # Get argument defaults (hastag #thisisahack)
    # Re-parsing with only a dummy flag yields the parser's default values.
    parser.add_argument('--IGNORE', action='store_true')
    defaults = vars(parser.parse_args(['--IGNORE']))

    # Print all arguments, color the non-defaults
    for argument, value in sorted(vars(args).items()):
        reset = colorama.Style.RESET_ALL
        color = reset if value == defaults[argument] else colorama.Fore.MAGENTA
        logger.info('{}{}: {}{}'.format(color, argument, value, reset))

    # Setup seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Setup Model
    model = (models.__dict__[args.arch](args)).to(device)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    logger.info(model)

    # Setup data loader
    train_dataset = get_loader(args.dataset)(args, split='train')
    val_dataset = get_loader(args.dataset)(args, split='val')
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=None)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch,
                                 num_workers=args.workers,
                                 pin_memory=True,
                                 sampler=None)

    # Setup Metrics
    criterion = nn.SoftMarginLoss().to(device)

    # Setup optimizer, lr_scheduler and loss function
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer, milestones=[50, 120], gamma=0.3)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            max_acc = checkpoint['max_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.warning("no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    for epoch in range(args.start_epoch, args.maxepoch):
        # Legacy (pre-1.1) PyTorch pattern: step the LR scheduler at the
        # start of the epoch with an explicit epoch index.
        scheduler.step(epoch=epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        # Validation and checkpointing run only on odd epochs.
        if epoch % 2 == 1:
            acc = validate(val_loader, model, criterion, epoch)

            is_best = acc > max_acc
            max_acc = max(acc, max_acc)
            # remember best mean loss and save checkpoint
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'max_acc': max_acc,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.rundir)
Пример #15
0
def main():
    """Hyper-parameter grid search: for every video, iterate over all
    penalty_k / window_influence / lr / instance_size combinations and
    call tune() for each."""
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    params = {
        'penalty_k': args.penalty_k,
        'window_influence': args.window_influence,
        'lr': args.lr,
        'instance_size': args.search_region,
    }

    # Total size of the search grid, for progress reporting.
    num_search = (len(params['penalty_k'])
                  * len(params['window_influence'])
                  * len(params['lr'])
                  * len(params['instance_size']))

    print(params)
    print(num_search)

    cfg = load_config(args)
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        model = models.__dict__[args.arch](anchors=cfg['anchors'])

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    model = model.to(device)

    default_hp = cfg.get('hp', {})

    p = {
        'network': model,
        'network_name': args.arch + '_' + args.resume.split('/')[-1].split('.')[0],
        'dataset': args.dataset,
    }

    global ims, gt, image_files

    dataset_info = load_dataset(args.dataset)
    videos = list(dataset_info.keys())
    np.random.shuffle(videos)

    for video in videos:
        print(video)
        # An external 'finish.flag' file aborts the search gracefully.
        if isfile('finish.flag'):
            return

        p['video'] = video
        ims = None
        image_files = dataset_info[video]['image_files']
        gt = dataset_info[video]['gt']

        # Randomise the search order so partial runs still cover the space.
        np.random.shuffle(params['penalty_k'])
        np.random.shuffle(params['window_influence'])
        np.random.shuffle(params['lr'])
        for pk in params['penalty_k']:
            for wi in params['window_influence']:
                for lr in params['lr']:
                    for size in params['instance_size']:
                        hp = default_hp.copy()
                        hp.update({'penalty_k': pk,
                                   'window_influence': wi,
                                   'lr': lr,
                                   'instance_size': size})
                        p['hp'] = hp
                        tune(p)
Пример #16
0
def main():
    """Evaluate a tracking model over a dataset, using the VOS (mask)
    protocol for mask-capable datasets and the VOT protocol otherwise.

    Fix vs. original: ``total_lost = 0`` had been commented out, but the
    VOT branch below does ``total_lost += lost``, which raised
    UnboundLocalError on the first non-VOS video. The initialization is
    restored.
    """
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset, args.dir_type)

    # VOS or VOT?
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    # BUG FIX: must be initialised before the loop — the VOT branch
    # accumulates into it.
    total_lost = 0  # VOT

    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            # Multi-object mode is enabled for DAVIS2017 / ytb_vos.
            iou_list, speed = track_vos(
                model,
                dataset[video],
                cfg['hp'] if 'hp' in cfg.keys() else None,
                args.mask,
                args.refine,
                args.dataset in ['DAVIS2017', 'ytb_vos'],
                device=device)
        else:
            lost, speed = track_vot(model,
                                    dataset[video],
                                    cfg['hp'] if 'hp' in cfg.keys() else None,
                                    args.mask,
                                    args.refine,
                                    device=device)
            total_lost += lost
Пример #17
0
def main():
    """Validate a sequence of training snapshots (epochs 1-20) against the
    validation set. The original author's debug prints are left in place."""
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    
    print("Init logger")

    logger = logging.getLogger('global')

    print(44)  # debug marker (original)
    #logger.info("\n" + collect_env_info())
    print(99)  # debug marker (original)
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    print(2)

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    print(3)

    # NOTE(review): hard-coded, user-specific snapshot path — parameterise
    # before reuse.
    path = "/usr4/alg504/cliao25/siammask/experiments/siammask_base/snapshot/checkpoint_e{}.pth"

    # Rebuild the model from scratch for each snapshot, restore it, and run
    # validation.
    for epoch in range(1,21):

        if args.arch == 'Custom':
            from custom import Custom
            model = Custom(pretrain=True, anchors=cfg['anchors'])
        else:
            exit()

        print(4)

        if args.pretrained:
            model = load_pretrain(model, args.pretrained)

        model = model.cuda()


        #model.features.unfix((epoch - 1) / 20)
        # The optimizer is only needed so restore_from can load the
        # checkpoint; its state is discarded (the `_` returns below).
        optimizer, lr_scheduler = build_opt_lr(model, cfg, args, epoch)
        filepath = path.format(epoch)
        assert os.path.isfile(filepath)

        model, _, _, _, _ = restore_from(model, optimizer, filepath)
        #model = load_pretrain(model, filepath)
        model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

        # NOTE(review): model.train() before validation keeps batch-norm /
        # dropout in training mode — confirm that `valid` expects this.
        model.train()
        device = torch.device('cuda')
        model = model.to(device)

        valid(val_loader, model, cfg)

    print("Done")
Пример #18
0
# Experiment metadata.
experiment_path = cfg.meta["experiment_path"]
experiment_name = cfg.meta["experiment_name"]
arch = cfg.meta["arch"]
# Training hyper-parameters.
batch_size = cfg.train['batch_size']
epoches = cfg.train['epoches']
lr = cfg.train['lr']
# Number of future frames (per the original comment) — taken from the
# model's input_num setting; confirm the exact semantics in the model code.
num_frame = cfg.model['input_num']
# print freq
print_freq = cfg.train['print_freq']

# Initialise the logger: console at INFO, per-experiment file at DEBUG.
global_logger = init_log('global', level=logging.INFO)
add_file_handler("global",
                 os.path.join(os.getcwd(), 'logs',
                              '{}.log'.format(experiment_name)),
                 level=logging.DEBUG)

# Log the configuration.
cfg.log_dict()

# Running-average metric tracker.
avg = AverageMeter()

# Device selection; cudnn autotuner enabled (fixed-size inputs assumed).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# 准备数据集
train_set = MovingMNIST(root='./data/mnist',
Пример #19
0
def main(args):
    """Train and validate the tracker described by ``args.config``.

    Copies the CLI options onto the global ``cfg``, builds train/val
    loaders, optionally resumes from ``args.resume_file``, then runs the
    epoch loop: shuffle -> (maybe rebuild optimizer) -> train -> validate
    -> log metrics -> save checkpoint.  Relies on module-level ``cfg``,
    ``logger``, ``train``, ``validate`` and ``save_checkpoint``.
    """
    # Mirror the CLI arguments into the global config object so the rest
    # of the code can read everything from ``cfg``.
    cfg_from_file(args.config)
    cfg.save_name = args.save_name
    cfg.save_path = args.save_path
    cfg.resume_file = args.resume_file
    cfg.config = args.config
    cfg.batch_size = args.batch_size
    cfg.num_workers = args.num_workers
    save_path = join(args.save_path, args.save_name)
    if not exists(save_path):
        makedirs(save_path)
    resume_file = args.resume_file
    # Console + file logging for this run.
    init_log('global', logging.INFO)
    add_file_handler('global', os.path.join(save_path, 'logs.txt'),
                     logging.INFO)
    logger.info("Version Information: \n{}\n".format(commit()))
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))
    start_epoch = 0

    model = ModelBuilder().cuda()
    if cfg.backbone.pretrained:
        # Initialise only the backbone from the pretrained weights.
        load_pretrain(model.backbone,
                      join('pretrained_net', cfg.backbone.pretrained))

    train_dataset = Datasets()
    val_dataset = Datasets(is_train=False)

    # shuffle=False: ordering is controlled explicitly through
    # ``train_loader.dataset.shuffle()`` at the start of each epoch below.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers,
                                               pin_memory=False,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=False,
                                             drop_last=True)

    if resume_file:
        if isfile(resume_file):
            logger.info("=> loading checkpoint '{}'".format(resume_file))
            model, start_epoch = restore_from(model, resume_file)
            start_epoch = start_epoch + 1
            # Replay the per-epoch shuffles so the data order matches what
            # an uninterrupted run would have seen at this epoch.
            for i in range(start_epoch):
                train_loader.dataset.shuffle()
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                resume_file, start_epoch - 1))
        else:
            logger.info("=> no checkpoint found at '{}'".format(resume_file))

    ngpus = torch.cuda.device_count()
    is_dataparallel = False
    if ngpus > 1:
        model = torch.nn.DataParallel(model, list(range(ngpus))).cuda()
        is_dataparallel = True

    # build_opt_lr needs the bare module, not the DataParallel wrapper.
    if is_dataparallel:
        optimizer, lr_scheduler = build_opt_lr(model.module, start_epoch)
    else:
        optimizer, lr_scheduler = build_opt_lr(model, start_epoch)

    logger.info(lr_scheduler)
    logger.info("model prepare done")

    if args.log:
        writer = SummaryWriter(comment=args.save_name)

    for epoch in range(start_epoch, cfg.train.epoch):
        train_loader.dataset.shuffle()
        # Rebuild optimizer/scheduler when this epoch is one of the
        # backbone unfix steps or the end of the pretraining phase (the
        # set of trainable parameters changes at those points).
        if (epoch == np.array(cfg.backbone.unfix_steps)
            ).sum() > 0 or epoch == cfg.train.pretrain_epoch:
            if is_dataparallel:
                optimizer, lr_scheduler = build_opt_lr(model.module, epoch)
            else:
                optimizer, lr_scheduler = build_opt_lr(model, epoch)
        lr_scheduler.step(epoch)
        record_dict_train = train(train_loader, model, optimizer, epoch)
        record_dict_val = validate(val_loader, model, epoch)
        message = 'Train Epoch: [{0}]\t'.format(epoch)
        for k, v in record_dict_train.items():
            message += '{name:s} {loss:.4f}\t'.format(name=k, loss=v)
        logger.info(message)
        message = 'Val Epoch: [{0}]\t'.format(epoch)
        for k, v in record_dict_val.items():
            message += '{name:s} {loss:.4f}\t'.format(name=k, loss=v)
        logger.info(message)

        if args.log:
            for k, v in record_dict_train.items():
                writer.add_scalar('train/' + k, v, epoch)
            for k, v in record_dict_val.items():
                writer.add_scalar('val/' + k, v, epoch)
        # Save from .module when wrapped so checkpoint keys carry no
        # 'module.' prefix.
        if is_dataparallel:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'cfg': cfg
                }, epoch, save_path)
        else:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'cfg': cfg
                }, epoch, save_path)
Пример #20
0
    # Create experiment folder structure
    curernt_file_path = os.path.dirname(os.path.abspath(__file__))
    experiment_folder = os.path.join(curernt_file_path, "all_experiment",
                                     experiment_name)
    experiment_snap_folder = os.path.join(experiment_folder, "snap")
    experiment_board_folder = os.path.join(experiment_folder, "board_validate")
    os.makedirs(experiment_folder, exist_ok=True)
    os.makedirs(experiment_snap_folder, exist_ok=True)
    os.makedirs(experiment_board_folder, exist_ok=True)

    # init board writer
    writer = SummaryWriter(experiment_board_folder)

    # get log
    add_file_handler("global",
                     os.path.join(experiment_folder, 'validate.log'),
                     level=logging.INFO)

    # get dataset
    dataset_pre_processing = []
    train_dataloader, test_dataloader = get_train_dataloader(
        cfg=cfg, use_cuda=cuda, pre_process_transform=dataset_pre_processing)

    # load model
    model, start_epoch = load_model_test(cfg, cuda)

    # get evaluation loss
    loss = get_loss(cfg)

    # start validte model
    validation(start_epoch, log_interval, test_dataloader, model, loss, writer,
Пример #21
0
def main():
    """CLI entry point: build data/model/optimizer and run the epoch loop.

    Mutates the module globals ``args``, ``best_acc``, ``tb_writer``,
    ``logger`` and ``cur_lr``; relies on module-level ``parser``,
    ``args_process``, ``load_config``, ``build_*`` and the ``train`` /
    ``validation`` helpers.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # TensorBoard writer, or a no-op stand-in when no log dir is given.
    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])
    # Only the 'Custom' architecture is supported by this script.
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True,
                       opts=args,
                       anchors=train_loader.dataset.anchors)
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        # Re-apply the backbone unfreezing ratio reached before the restart.
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')
    global cur_lr

    if not os.path.exists(args.save_dir):  # makedir/save model
        os.makedirs(args.save_dir)
    num_per_epoch = len(train_loader.dataset) // args.batch
    num_per_epoch_val = len(val_loader.dataset) // args.batch

    for epoch in range(args.start_epoch, args.epochs):
        lr_scheduler.step(epoch)
        cur_lr = lr_scheduler.get_cur_lr()
        logger = logging.getLogger('global')
        train_avg = AverageMeter()
        val_avg = AverageMeter()

        # Progressively unfreeze backbone layers; when the trainable
        # parameter set changes, rebuild the optimizer and scheduler.
        if dist_model.module.features.unfix(epoch / args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg,
                                                   args, epoch)

        train(train_loader, dist_model, optimizer, lr_scheduler, epoch, cfg,
              train_avg, num_per_epoch)

        # NOTE(review): this repeats the unfix/rebuild from before train()
        # with the same epoch value, discarding optimizer state right after
        # training — confirm the duplication is intentional.
        if dist_model.module.features.unfix(epoch / args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg,
                                                   args, epoch)

        # NOTE(review): validation only runs on checkpoint epochs because
        # it sits inside the save_freq branch — verify that is the intent.
        if (epoch + 1) % args.save_freq == 0:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': dist_model.module.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'anchor_cfg': cfg['anchors']
                }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))

            validation(val_loader, dist_model, epoch, cfg, val_avg,
                       num_per_epoch_val)
Пример #22
0
def main():
    """Evaluate a tracking/segmentation model over a dataset.

    Parses CLI args, loads config and weights, then runs either VOS
    (segmentation, ``track_vos``) or VOT (tracking, ``track_vot``) over
    every video and logs the aggregate metrics.  Uses the module globals
    ``args``, ``logger``, ``v_id`` and the module-level ``thrs``.
    """
    global args, logger, v_id  # shared with the tracking helpers
    args = parser.parse_args()  # CLI options passed when running test.py
    cfg = load_config(args)  # load the JSON config and set args.arch
    print(cfg)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log,
                         logging.INFO)  # bind a file handler to the logger

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model: 'Custom' is the network from the paper; any other
    # architecture name is rejected here.
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors']
                       )  # anchors come from the JSON config (e.g. config_vot.json)
    else:
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:  # --resume must point at an existing checkpoint file
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model = load_pretrain(
            model, args.resume)  # load_pretrain tolerates small mismatches
    model.eval()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)

    # setup dataset (a dict of videos)
    dataset = load_dataset(
        args.dataset)  # load_dataset supports VOT, DAVIS and ytb_vos
    # only those sources support mask output

    # VOS or VOT?
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    total_lost = 0  # VOT: accumulated tracking failures
    iou_lists = []  # VOS: per-video IoU results
    speed_list = []

    # v_id starts at 1; video is the video name
    for v_id, video in enumerate(dataset.keys(), start=1):
        # NOTE(review): debug short-circuit — the process exits after the
        # first video; remove this for a full evaluation run.
        if v_id == 2:
            exit()
        if args.video != '' and video != args.video:  # optional single-video filter
            continue

        if vos_enable:  # segmentation task (mutually exclusive with tracking)
            iou_list, speed = track_vos(
                model,
                dataset[video],
                cfg['hp'] if 'hp' in cfg.keys() else None,
                args.mask,
                args.refine,
                args.dataset in ['DAVIS2017', 'ytb_vos'],
                device=device)
            iou_lists.append(iou_list)  # per-threshold IoUs for this video
        else:  # tracking task
            lost, speed = track_vot(model,
                                    dataset[video],
                                    cfg['hp'] if 'hp' in cfg.keys() else None,
                                    args.mask,
                                    args.refine,
                                    device=device)
            total_lost += lost
        speed_list.append(speed)

    # report final result
    if vos_enable:  # segmentation summary
        # ``thrs`` is expected at module level: the mask thresholds that
        # match the columns of each iou_list — TODO confirm.
        for thr, iou in zip(thrs, np.mean(np.concatenate(iou_lists), axis=0)):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(
                thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
Пример #23
0
def main():
    """Entry point: prepare logging, data, model and optimizer, then train.

    Mutates the module globals ``args``, ``best_acc``, ``tb_writer`` and
    ``logger``; depends on module-level ``parser``, ``load_config``,
    ``build_data_loader``, ``build_opt_lr``, ``restore_from``, ``models``
    and ``train``.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    cfg = load_config(args)

    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    logger.info("\n" + collect_env_info())

    # TensorBoard writer, or a no-op stand-in when no log dir is given.
    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        # fall back to an architecture from the models registry
        model = models.__dict__[args.arch](anchors=cfg['anchors'])

    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        # Re-apply the backbone unfreezing ratio reached before the restart.
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    logger.info(lr_scheduler)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
        epoch = args.start_epoch
        # After resuming, redo the unfix/optimizer rebuild and advance the
        # scheduler so the learning rate matches the resumed epoch.
        if dist_model.module.features.unfix(epoch/args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg, args, epoch)
        lr_scheduler.step(epoch)
        cur_lr = lr_scheduler.get_cur_lr()
        logger.info('epoch:{} resume lr {}'.format(epoch, cur_lr))

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)
Пример #24
0
def main():
    """Parse args, build the data/model/optimizer stack, then train.

    Side effects: rebinds the module globals ``args``, ``best_acc``,
    ``tb_writer`` and ``logger``.  Depends on module-level ``parser``,
    ``args_process``, ``load_config``, ``build_data_loader``,
    ``build_opt_lr``, ``restore_from`` and ``train``.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    # Console logging always; file logging only when requested.
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # TensorBoard writer, or a no-op object when no log dir was requested.
    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])

    # Only the 'Custom' architecture is supported by this script.
    if args.arch != 'Custom':
        exit()
    from custom import Custom
    model = Custom(pretrain=True,
                   opts=args,
                   anchors=train_loader.dataset.anchors)
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    def _parallelize(net):
        # DataParallel over every visible GPU.
        return torch.nn.DataParallel(
            net, list(range(torch.cuda.device_count()))).cuda()

    model = model.cuda()
    dist_model = _parallelize(model)

    if args.resume and args.start_epoch != 0:
        # Re-apply the backbone unfreezing ratio reached before the restart.
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, _arch = restore_from(
            model, optimizer, args.resume)
        dist_model = _parallelize(model)

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Пример #25
0
def main():
    """Train the BagReID_RESNET re-identification model.

    Recreates the log/summary directories, builds the data, model, three
    loss criteria and two optimizers (model + center loss), then runs the
    training loop with periodic checkpointing.
    """
    logger = logging.getLogger('global')
    global criterion_xent, criterion_triplet, criterion_center
    # Start from a clean log directory on every run.
    if os.path.exists(cfg.TRAIN.LOG_DIR):
        shutil.rmtree(cfg.TRAIN.LOG_DIR)
    os.makedirs(cfg.TRAIN.LOG_DIR)
    init_log('global', logging.INFO)  # log
    add_file_handler('global', os.path.join(cfg.TRAIN.LOG_DIR, 'logs.txt'),
                     logging.INFO)
    summary_writer = SummaryWriter(cfg.TRAIN.LOG_DIR)  # visualise

    dataset, train_loader, _, _ = build_data_loader()
    model = BagReID_RESNET(dataset.num_train_bags)
    # Classification (label-smoothed), metric and center losses.
    criterion_xent = CrossEntropyLabelSmooth(dataset.num_train_bags,
                                             use_gpu=cfg.CUDA)
    criterion_triplet = TripletLoss(margin=cfg.TRAIN.MARGIN)
    criterion_center = CenterLoss(dataset.num_train_bags,
                                  cfg.MODEL.GLOBAL_FEATS +
                                  cfg.MODEL.PART_FEATS,
                                  use_gpu=cfg.CUDA)
    if cfg.TRAIN.OPTIM == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cfg.SOLVER.LEARNING_RATE,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=cfg.SOLVER.LEARNING_RATE,
                                     weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # The center loss has its own learnable parameters, hence a second
    # optimizer with its own learning rate.
    center_optimizer = torch.optim.SGD(criterion_center.parameters(),
                                       lr=cfg.SOLVER.LEARNING_RATE_CENTER)

    optimizers = [optimizer, center_optimizer]
    schedulers = build_lr_schedulers(optimizers)

    if cfg.CUDA:
        model.cuda()
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model, device_ids=cfg.DEVICES)

    logger.info("model prepare done")
    # start training
    for epoch in range(cfg.TRAIN.NUM_EPOCHS):
        # NOTE(review): ``criterion`` is not defined in this function (only
        # criterion_xent / criterion_triplet / criterion_center are) — this
        # raises NameError unless a module-level ``criterion`` exists;
        # verify and pass the intended criteria.
        train(epoch, train_loader, model, criterion, optimizers,
              summary_writer)
        for scheduler in schedulers:
            scheduler.step()

        # skip if not save model
        if cfg.TRAIN.EVAL_STEP > 0 and (epoch + 1) % cfg.TRAIN.EVAL_STEP == 0 \
                or (epoch + 1) == cfg.TRAIN.NUM_EPOCHS:

            # Unwrap DataParallel so checkpoint keys have no 'module.' prefix.
            if cfg.CUDA and torch.cuda.device_count() > 1:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint({
                'state_dict': state_dict,
                'epoch': epoch + 1
            },
                            is_best=False,
                            save_dir=cfg.TRAIN.SNAPSHOT_DIR,
                            filename='checkpoint_ep' + str(epoch + 1) + '.pth')
Пример #26
0
def main():
    """Train the base network (CPU variant).

    Same pipeline as the GPU script, but the model is wrapped in
    DataParallel without the ``.cuda()`` moves (those calls are kept
    commented out below).  Mutates the module globals ``args``,
    ``best_acc``, ``tb_writer`` and ``logger``.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    # initialise logging
    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    # grab the configured logger
    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)
    # load the configuration
    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build the datasets
    train_loader, val_loader = build_data_loader(cfg)
    # build the training network (only 'Custom' is supported)
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)
    # load pretrained weights
    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    # GPU version:
    # model = model.cuda()
    # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
    # CPU: plain DataParallel wrapper
    dist_model = torch.nn.DataParallel(model)
    # re-apply the backbone unfreezing ratio when resuming mid-run
    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)
    # build the optimizer and learning-rate schedule
    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        # GPU
        # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
        dist_model = torch.nn.DataParallel(model)

    logger.info(lr_scheduler)

    logger.info('model prepare done')
    # start training
    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Пример #27
0
def main():
    """Set up logging, data and the model, optionally resume, then train.

    Side effects: rebinds the module globals ``args``, ``best_acc``,
    ``tb_writer`` and ``logger``.  Depends on module-level ``parser``,
    ``load_config``, ``Custom``/``Custom_Sky``, ``build_opt_lr``,
    ``restore_from`` and ``train``.
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    # Console logging always; file logging only when requested.
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # TensorBoard writer, or a no-op object when no log dir was requested.
    tb_writer = SummaryWriter(args.log_dir) if args.log_dir else Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    # Architecture dispatch: only the two custom variants are supported.
    if args.arch not in ('Custom', 'Custom_Sky'):
        exit()
    builder = Custom if args.arch == 'Custom' else Custom_Sky
    model = builder(anchors=cfg['anchors'])
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    def _wrap(net):
        # DataParallel over every visible GPU.
        return torch.nn.DataParallel(
            net, list(range(torch.cuda.device_count()))).cuda()

    model = model.cuda()
    dist_model = _wrap(model)

    if args.resume and args.start_epoch != 0:
        # Re-apply the backbone unfreezing ratio reached before the restart.
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        print(args.resume)
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, _arch = restore_from(
            model, optimizer, args.resume)
        dist_model = _wrap(model)

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Пример #28
0
import argparse
import logging
# Fix: `os` is used below (os.path.join / os.getcwd) but was never imported.
import os

from utils.log_helper import init_log, add_file_handler, print_speed
from utils.config_helper import Configs
from utils.average_meter_helper import AverageMeter

# Build the command-line parser.
parser = argparse.ArgumentParser(description='Train moving mnist video prediction algorithm')
parser.add_argument('-c', '--cfg', default=os.path.join(os.getcwd(), "tools", "train_config.json"), type=str, required=False, help='training config file path')

args = parser.parse_args()

# Initialise logging: INFO on the console, DEBUG into logs/ex_1.log.
global_logger = init_log('global', level=logging.INFO)
add_file_handler("global", os.path.join(os.getcwd(), 'logs', 'ex_1.log'), level=logging.DEBUG)

# Load the configuration and unpack frequently used settings.
cfg = Configs(args.cfg)
# TensorBoard output directory.
board_path = cfg.meta["board_path"]
experiment_path = cfg.meta["experiment_path"]
arch = cfg.meta["arch"]
# Training hyper-parameters.
batch_size = cfg.train['batch_size']
epoches = cfg.train['epoches']
lr = cfg.train['lr']
# Number of input frames fed to the model.
input_num = cfg.model['input_num']
# How often progress is printed.
print_freq = cfg.train['print_freq']