Example #1
def main():
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # If args.single_gpu is set to True, disable distributed data parallel.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Override the number of data loading workers if necessary.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Create log directory for storing training results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # Initialize cudnn.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Initialize data loaders and models.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D,
                          opt_G, opt_D,
                          sch_G, sch_D,
                          train_data_loader, val_data_loader)
    current_epoch, current_iteration = trainer.load_checkpoint(
        cfg, args.checkpoint, resume=args.resume)

    # Start training.
    for epoch in range(current_epoch, cfg.max_epoch):
        print('Epoch {} ...'.format(epoch))
        if not args.single_gpu:
            train_data_loader.sampler.set_epoch(current_epoch)
        trainer.start_of_epoch(current_epoch)
        for it, data in enumerate(train_data_loader):
            data = trainer.start_of_iteration(data, current_iteration)

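            # Alternate updates: run the configured number of
            # discriminator steps, then generator steps.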
            for _ in range(cfg.trainer.dis_step):
                trainer.dis_update(data)
            for _ in range(cfg.trainer.gen_step):
                trainer.gen_update(data)

            current_iteration += 1
            trainer.end_of_iteration(data, current_epoch, current_iteration)
            if current_iteration >= cfg.max_iter:
                print('Done with training!!!')
                return

        current_epoch += 1
        trainer.end_of_epoch(data, current_epoch, current_iteration)
    print('Done with training!!!')
    return
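All of these snippets call a parse_args helper that is not shown. Below is a minimal sketch of what it might look like, assuming Python's argparse; it defines only the attributes the examples actually read, and the flag defaults are assumptions rather than the project's real values:

import argparse
import os


def parse_args():
    # Hypothetical reconstruction: flag names match what the snippets
    # read; defaults and help strings are guesses.
    parser = argparse.ArgumentParser(description='Training / inference')
    parser.add_argument('--config', required=True,
                        help='Path to the config file.')
    parser.add_argument('--logdir', default=None,
                        help='Directory for logs and checkpoints.')
    parser.add_argument('--checkpoint', default='',
                        help='Checkpoint to initialize from.')
    parser.add_argument('--resume', action='store_true',
                        help='Also restore optimizer and scheduler state.')
    parser.add_argument('--single_gpu', action='store_true',
                        help='Disable distributed data parallel.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed.')
    parser.add_argument('--num_workers', type=int, default=None,
                        help='Override the number of data loading workers.')
    parser.add_argument('--local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)),
                        help='Process rank on the current node.')
    parser.add_argument('--output_dir', default=None,
                        help='Where inference results are written.')
    parser.add_argument('--checkpoint_logdir', default=None,
                        help='Directory of checkpoints to evaluate.')
    return parser.parse_args()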
Example #2
def main():
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)
    if not hasattr(cfg, 'inference_args'):
        cfg.inference_args = None

    # If args.single_gpu is set to True, disable distributed data parallel.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Override the number of data loading workers if necessary.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Create log directory for storing inference results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)

    # Initialize cudnn.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Initialize data loaders and models.
    test_data_loader = get_test_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
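    # Pass None for the train data loader below: this script only runs
    # inference on the test set.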
    trainer = get_trainer(cfg, net_G, net_D,
                          opt_G, opt_D,
                          sch_G, sch_D,
                          None, test_data_loader)

    # if args.checkpoint == '':
    #     # Download pretrained weights.
    #     pretrained_weight_url = cfg.pretrained_weight
    #     if pretrained_weight_url == '':
    #         raise ValueError(
    #             'Link to the pretrained weight is not specified.')
    #     default_checkpoint_path = args.config.replace('.yaml', '.pt')
    #     args.checkpoint = get_checkpoint(
    #         default_checkpoint_path, pretrained_weight_url)
    #     print('Checkpoint downloaded to', args.checkpoint)

    # Load checkpoint.
    trainer.load_checkpoint(cfg, args.checkpoint)

    # Do inference.
    trainer.current_epoch = -1
    trainer.current_iteration = -1
    trainer.test(test_data_loader, args.output_dir, cfg.inference_args)
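The hasattr check at the top of this example can be collapsed into a single getattr call. A minimal equivalent, assuming Config permits plain attribute assignment (which the snippet itself relies on when it sets cfg.local_rank):

    # Fall back to None when the config has no inference_args section.
    cfg.inference_args = getattr(cfg, 'inference_args', None)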
Example #3
def main():
    args = parse_args()
    set_affinity(args.local_rank)
    set_random_seed(args.seed, by_rank=True)
    cfg = Config(args.config)

    # If args.single_gpu is set to True, disable distributed data parallel.
    if not args.single_gpu:
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank)

    # Override the number of data loading workers if necessary.
    if args.num_workers is not None:
        cfg.data.num_workers = args.num_workers

    # Create log directory for storing training results.
    cfg.date_uid, cfg.logdir = init_logging(args.config, args.logdir)
    make_logging_dir(cfg.logdir)

    # Initialize cudnn.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Initialize data loaders and models.
    train_data_loader, val_data_loader = get_train_and_val_dataloader(cfg)
    net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
        get_model_optimizer_and_scheduler(cfg, seed=args.seed)
    trainer = get_trainer(cfg, net_G, net_D,
                          opt_G, opt_D,
                          sch_G, sch_D,
                          train_data_loader, val_data_loader)

    # Start evaluation.
    checkpoints = \
        sorted(glob.glob('{}/*.pt'.format(args.checkpoint_logdir)))
    for checkpoint in checkpoints:
        current_epoch, current_iteration = \
            trainer.load_checkpoint(cfg, checkpoint, resume=True)
        trainer.current_epoch = current_epoch
        trainer.current_iteration = current_iteration
        trainer.write_metrics()
    print('Done with evaluation!!!')
    return
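One caveat on the checkpoint loop: sorted(glob.glob(...)) (which assumes a module-level import glob) orders files lexicographically, which matches training order only when the epoch or iteration numbers in the filenames are zero-padded. A defensive sketch that sorts by the first number in each filename instead; the helper name is an assumption, not part of the codebase:

import glob
import os
import re


def sorted_checkpoints(checkpoint_logdir):
    # Sort *.pt files by the first digit run in the filename so that
    # 'epoch_10.pt' sorts after 'epoch_2.pt'; files without digits
    # sort first.
    def numeric_key(path):
        match = re.search(r'\d+', os.path.basename(path))
        return int(match.group()) if match else -1
    return sorted(glob.glob('{}/*.pt'.format(checkpoint_logdir)),
                  key=numeric_key)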
Example #4
    def _init_single_image_model(self, load_weights=True):
        r"""Load single image model, if any."""
        if self.single_image_model is None and \
                hasattr(self.gen_cfg, 'single_image_model'):
            print('Using single image model...')
            single_image_cfg = Config(self.gen_cfg.single_image_model.config)

            # Init model.
            net_G, net_D, opt_G, opt_D, sch_G, sch_D = \
                get_model_optimizer_and_scheduler(single_image_cfg)

            # Init trainer and load checkpoint.
            trainer = get_trainer(single_image_cfg, net_G, net_D, opt_G, opt_D,
                                  sch_G, sch_D, None, None)
            if load_weights:
                print('Loading single image model checkpoint')
                single_image_ckpt = self.gen_cfg.single_image_model.checkpoint
                trainer.load_checkpoint(single_image_cfg, single_image_ckpt)
                print('Loaded single image model checkpoint')

            self.single_image_model = net_G.module
            self.single_image_model_z = None
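Note that net_G.module assumes the generator is wrapped, for example in torch.nn.parallel.DistributedDataParallel, which exposes the underlying network as .module. A small defensive sketch that also handles unwrapped models, as in single-GPU runs (the helper name is an assumption):

def unwrap_model(net):
    # DataParallel / DistributedDataParallel expose the real network
    # as .module; plain nn.Module instances are returned unchanged.
    return net.module if hasattr(net, 'module') else net

With this, the final assignment becomes self.single_image_model = unwrap_model(net_G).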