예제 #1
0
def main():
    """Train the module selected by ``args.stage`` ('GMM' or 'TOM').

    Loads the config named on the command line, applies the command-line
    overrides (work dir, resume checkpoint), sets up the optional
    distributed environment, the root logger and the random seed, then
    builds the stage's model and dataset and hands them to the matching
    training routine.

    Raises:
        ValueError: If ``args.stage`` is neither ``'GMM'`` nor ``'TOM'``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Command-line values win over what the config file says.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    # Every stage gets its own sub-directory below the work dir.
    cfg.work_dir = os.path.join(cfg.work_dir, args.stage)
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from

    # The distributed environment is initialised before logger and seed.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    logger = get_root_logger(cfg.log_level)
    dist_message = 'Distributed training: {}'.format(distributed)
    logger.info(dist_message)

    # Seeding is optional; without --seed nothing is fixed.
    seed = args.seed
    if seed is not None:
        logger.info('Set random seed to {}'.format(seed))
        set_random_seed(seed)

    # Reject unknown stages before building anything stage-specific.
    stage = args.stage
    if stage not in ('GMM', 'TOM'):
        raise ValueError('stage should be GMM or TOM')

    if stage == 'GMM':
        module = build_geometric_matching(cfg.GMM)
        print('Geometric Matching Module built')
        train_set = get_dataset(cfg.data.train.GMM)
        print('GMM dataset loaded')
        run_training = train_geometric_matching
    else:
        module = build_tryon(cfg.TOM)
        print('Try-On Module built')
        train_set = get_dataset(cfg.data.train.TOM)
        print('TOM dataset loaded')
        run_training = train_tryon

    # Both stages share the same training call signature.
    run_training(
        module,
        train_set,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)
예제 #2
0
def main():
    """Train a predictor model from a config file and command-line options.

    Loads the config, applies the command-line overrides (work dir, resume
    checkpoint), sets up the optional distributed environment, the root
    logger and the random seed, builds the model (optionally initialising
    it from ``cfg.init_weights_from``), loads the training dataset and
    starts training.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Command-line values win over what the config file says.
    override_work_dir = args.work_dir
    if override_work_dir is not None:
        cfg.work_dir = override_work_dir
    override_resume = args.resume_from
    if override_resume is not None:
        cfg.resume_from = override_resume

    # The distributed environment is initialised before logger and seed.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    logger = get_root_logger(cfg.log_level)
    dist_message = 'Distributed training: {}'.format(distributed)
    logger.info(dist_message)

    # Seeding is optional; without --seed nothing is fixed.
    seed = args.seed
    if seed is not None:
        logger.info('Set random seed to {}'.format(seed))
        set_random_seed(seed)

    predictor = build_predictor(cfg.model)
    print('model built')

    # A truthy ``init_weights_from`` entry triggers weight initialisation.
    weights_source = cfg.init_weights_from
    if weights_source:
        predictor = init_weights_from(weights_source, predictor)

    train_set = get_dataset(cfg.data.train)
    print('dataset loaded')

    train_predictor(
        predictor,
        train_set,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger,
    )
예제 #3
0
def main():
    """Test a landmark detector from a config file and command-line options.

    Loads the config, applies the command-line overrides (work dir,
    checkpoint path), sets up the optional distributed environment and the
    root logger, loads the test dataset, builds the model, loads the
    checkpoint named by ``cfg.load_from`` into it and runs the test routine.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # init distributed env first
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # A checkpoint given on the command line replaces the config's load_from.
    if args.checkpoint is not None:
        cfg.load_from = args.checkpoint

    # init logger
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed test: {}'.format(distributed))

    # data loader
    test_dataset = get_dataset(cfg.data.test)
    print('dataset loaded')

    # build model and load checkpoint
    model = build_landmark_detector(cfg.model)
    print('model built')

    # Only the call's effect on ``model`` is needed here. The return value
    # was previously bound to an unused local, which kept it alive for the
    # whole test run; it is now discarded.
    load_checkpoint(model, cfg.load_from, map_location='cpu')
    print('load checkpoint from: {}'.format(cfg.load_from))

    # test
    test_landmark_detector(
        model,
        test_dataset,
        cfg,
        distributed=distributed,
        validate=args.validate,
        logger=logger)