コード例 #1
0
ファイル: train.py プロジェクト: dvschultz/imaginaire
def main():
    """Run a training job configured from the command-line arguments.

    Parses the arguments, loads the config file, optionally initializes
    distributed training, prepares the logging directory and cudnn, builds
    the data loaders, networks, optimizers, schedulers and trainer, asks
    the trainer to load a checkpoint, and then loops over epochs.

    Training stops either when the epoch range up to ``cfg.max_epoch`` is
    exhausted or as soon as the iteration counter reaches ``cfg.max_iter``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # Distributed data parallel is used unless --single_gpu was requested.
    distributed = not args.single_gpu
    if distributed:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # A worker count given on the command line wins over the config value.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Set up the directory that receives the training logs.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # Configure cudnn from the config file.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Build the data loaders, the networks and everything that drives them.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    (net_G, net_D,
     opt_G, opt_D,
     sch_G, sch_D) = get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(
        cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
        train_data_loader, val_data_loader)

    # The trainer reports the epoch / iteration counters to continue from.
    current_epoch, current_iteration = trainer.load_checkpoint(
        cfg, args.checkpoint, resume=args.resume)

    # Training loop. `epoch` and `current_epoch` advance in lockstep: the
    # range starts at `current_epoch`, which grows by one per pass.
    for epoch in range(current_epoch, cfg.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if distributed:
            train_data_loader.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)

        for data in train_data_loader:
            data = trainer.start_of_iteration(data, current_iteration)

            # Discriminator updates run before generator updates.
            for _dis_round in range(cfg.trainer.dis_step):
                trainer.dis_update(data)
            for _gen_round in range(cfg.trainer.gen_step):
                trainer.gen_update(data)

            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)

            # The iteration budget ends training in the middle of an epoch.
            if current_iteration >= cfg.max_iter:
                print('Done with training!!!')
                return

        current_epoch += 1
        # NOTE(review): `data` is the last batch of the inner loop; it is
        # unbound here if the loader yields nothing -- confirm that the
        # loader can never be empty.
        trainer.end_of_epoch(data, current_epoch, current_iteration)

    print('Done with training!!!')
コード例 #2
0
def main():
    """Run inference with a checkpoint on the test data loader.

    Parses the arguments, loads the config file (giving it an
    ``inference_args`` attribute of ``None`` when it has none), optionally
    initializes distributed execution, builds the test data loader,
    networks and trainer, loads ``args.checkpoint`` and finally calls
    ``trainer.test`` with ``args.output_dir`` and ``cfg.inference_args``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)

    cfg = Config(args.config)
    # Make sure the attribute read at the end of this function exists.
    if not hasattr(cfg, 'inference_args'):
        cfg.inference_args = None

    # Distributed data parallel is used unless --single_gpu was requested.
    distributed = not args.single_gpu
    if distributed:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # A worker count given on the command line wins over the config value.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Resolve the log location; unlike training, no directory is created
    # in this function.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)

    # Configure cudnn from the config file.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Build the test loader and the networks; the trainer receives no
    # training loader (None) and the test loader in the validation slot.
    test_data_loader = get_test_dataloader(cfg)
    (net_G, net_D,
     opt_G, opt_D,
     sch_G, sch_D) = get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(
        cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
        None, test_data_loader)

    # NOTE: automatic download of pretrained weights (from
    # cfg.pretrained_weight via get_checkpoint, when args.checkpoint is
    # empty) is disabled here; the checkpoint path is used as given.
    trainer.load_checkpoint(cfg, args.checkpoint)

    # Run inference with both trainer counters set to -1.
    trainer.current_epoch = -1
    trainer.current_iteration = -1
    trainer.test(test_data_loader, args.output_dir, cfg.inference_args)
コード例 #3
0
def main():
    """Write metrics for every ``*.pt`` checkpoint in a directory.

    Performs the same setup as training (arguments, config, optional
    distributed init, logging directory, cudnn, data loaders, networks and
    trainer). It then visits the ``*.pt`` files directly under
    ``args.checkpoint_logdir`` in sorted path order, loads each one into
    the trainer and calls ``trainer.write_metrics()``.
    """
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # Distributed data parallel is used unless --single_gpu was requested.
    distributed = not args.single_gpu
    if distributed:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # A worker count given on the command line wins over the config value.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Set up the directory that receives the logs.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # Configure cudnn from the config file.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Build the data loaders, the networks and the trainer.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    (net_G, net_D,
     opt_G, opt_D,
     sch_G, sch_D) = get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(
        cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
        train_data_loader, val_data_loader)

    # Evaluate the checkpoints one by one, in lexicographic path order.
    pattern = '{}/*.pt'.format(args.checkpoint_logdir)
    for checkpoint_path in sorted(glob.glob(pattern)):
        loaded_epoch, loaded_iteration = trainer.load_checkpoint(
            cfg, checkpoint_path, resume=True)
        # Expose the counters returned by the load on the trainer before
        # the metrics are written.
        trainer.current_epoch = loaded_epoch
        trainer.current_iteration = loaded_iteration
        trainer.write_metrics()

    print('Done with evaluation!!!')