def main():
    """Entry point: train a CIFAR model from a config file.

    Parses CLI arguments, loads the config, initializes distributed
    training, builds the experiment, and runs the training loop with
    both train and val data loaders.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    init_dist(**cfg.dist_params)

    # Multi-stage configs keep per-stage settings under `stages`; select
    # the stage requested on the command line.
    if "stages" in cfg:
        cfg = cfg.stages[args.stage]

    experiment = CIFARExperiment(cfg)
    runner = Runner(
        model=experiment.model,
        optimizers=experiment.optimizers,
        schedulers=experiment.schedulers,
        batch_processor=CIFARBatchProcessor(cfg),
        hooks=experiment.hooks,
        work_dir=experiment.work_dir,
    )

    loaders = {
        split: experiment.dataloader(split)
        for split in ("train", "val")
    }
    runner.run(data_loaders=loaders, max_epochs=cfg.max_epochs)
def main():
    """Entry point: train a GAN from a config file.

    Parses CLI arguments, loads the config, initializes distributed
    training, then runs the training loop. Checkpoint continuation is
    driven by `resume_from` / `load_from` in the config.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    init_dist(**cfg.dist_params)

    experiment = GANExperiment(cfg)
    runner = Runner(
        model=experiment.model,
        optimizers=experiment.optimizers,
        batch_processor=GANBatchProcessor(cfg),
        hooks=experiment.hooks,
        work_dir=experiment.work_dir,
    )

    # GAN training uses only a train split — no validation loader here.
    train_loader = experiment.dataloader("train")
    runner.run(
        data_loaders={"train": train_loader},
        max_epochs=cfg.total_epochs,
        resume_from=cfg.resume_from,
        load_from=cfg.load_from,
    )
def main():
    """Entry point: evaluate a CIFAR checkpoint on the val split.

    Runs distributed (multi-GPU) inference, gathers per-rank outputs
    into one ordered result list, and dumps it to `args.out`.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # cudnn autotuning of conv kernels; opt-in via the config.
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    init_dist(**cfg.dist_params)
    experiment = CIFARExperiment(cfg)

    # Build the validation data loader and the model to evaluate.
    data_loader = experiment.dataloader("val")
    model = experiment.model
    # map_location="cpu" keeps checkpoint tensors off the GPU while loading.
    load_checkpoint(model, args.checkpoint, map_location="cpu")

    batch_processor = CIFARBatchProcessor(cfg)
    results = multi_gpu_test(model, data_loader, batch_processor)
    # Merge the per-rank partial results into one list ordered by dataset index.
    results = collect_results(results, len(data_loader.dataset))
    io.dump(results, args.out)
# NOTE(review): this chunk looks like the tail of a tile-classifier training
# entry point — the enclosing `def <name>(config):` line is not visible here
# (apparently lost in formatting). Because the function is dispatched by NAME
# via `locals()[config.TRAIN_FUNC]` in the __main__ guard below, the def line
# cannot be safely reconstructed without knowing that name, so the statements
# are kept byte-identical. TODO: recover the original def line.
#
# Intended flow (as far as the visible code shows):
#   * choose a DDP builder when torch.distributed is initialized, else a DP
#     builder;
#   * build one data loader per key in config.DATA;
#   * run the Runner for config.MAX_EPOCHS with the builder's model,
#     optimizers, schedulers, and hooks.
# The __main__ guard: attaches a stderr StreamHandler to the root logger,
# parses CLI args, loads the config, calls init_dist only when
# --is_distributed is set, then looks up config.TRAIN_FUNC in `locals()`
# (at module level this is the module namespace; `globals()` would state
# that intent more clearly — TODO confirm) and invokes it with the config.
if dist.is_initialized(): builder = TileClassifierDDPBuilder(config) else: builder = TileClassifierDPBuilder(config) data_loaders = {x: builder.data_loader(x) for x in config.DATA.keys()} batch_processor = TileClassifierBatchProcessor(builder) runner = Runner( model=builder.model, optimizers=builder.optimizers, schedulers=builder.schedulers, hooks=builder.hooks, work_dir=builder.config.WORK_DIR, batch_processor=batch_processor, ) runner.run(data_loaders=data_loaders, max_epochs=builder.config.MAX_EPOCHS) if __name__ == "__main__": logging.getLogger().addHandler(logging.StreamHandler()) args = parse_args() config = Config.fromfile(args.config_path) if args.is_distributed: init_dist(**config.DIST_PARAMS) train_func = locals()[config.TRAIN_FUNC] train_func(config)