def main():
    """Evaluate a checkpointed model on the test split and log the results.

    Parses the command-line arguments through the module-level ``parser``,
    builds the data/model/run configurations, loads the ``"test"`` dataset,
    restores the checkpoint named by ``args["checkpoint"]``, runs
    ``trainer.evaluate`` on every test batch, then saves the outputs under
    ``save_prefix`` and writes the log.

    The log stream returned by ``set_output`` is closed when the function
    exits (unless it is ``sys.stdout``), including when a step raises.
    """
    args = vars(parser.parse_args())
    check_args(args)
    set_seeds(2021)
    data_cfg = config.DataConfig(args["data_config"])
    model_cfg = config.ModelConfig(args["model_config"])
    run_cfg = config.RunConfig(args["run_config"], eval=True,
                               sanity_check=args["sanity_check"])
    output, save_prefix = set_output(args, "evaluate_model_log")

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config.print_configs(args, [data_cfg, model_cfg, run_cfg], device, output)
        # NOTE(review): presumably touches the device early so that device
        # initialisation happens before any timing below -- confirm.
        torch.zeros((1)).to(device)

        ## Loading a dataset
        # Print's return values are subtracted to report the elapsed time.
        start = Print('start loading a dataset', output)
        dataset_test = get_dataset_from_configs(
            data_cfg, "test", model_cfg.embedder,
            sanity_check=args["sanity_check"])
        iterator_test = torch.utils.data.DataLoader(
            dataset_test, run_cfg.batch_size, shuffle=False,
            pin_memory=True, num_workers=4)
        end = Print(
            " ".join(['loaded', str(len(dataset_test)), 'dataset_test samples']),
            output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)

        ## initialize a model
        start = Print('start initializing a model', output)
        # ``params`` is returned alongside the model but is not needed here.
        model, params = get_model(model_cfg, run_cfg)
        get_profile(model, dataset_test, output)
        end = Print('end initializing a model', output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)

        ## setup trainer configurations
        start = Print('start setting trainer configurations', output)
        trainer = Trainer(model)
        trainer.load_model(args["checkpoint"], output)
        trainer.set_device(device)
        end = Print('end setting trainer configurations', output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)

        ## evaluate a model
        start = Print('start evaluating a model', output)
        trainer.headline(output)

        ### validation
        for B, batch in enumerate(iterator_test):
            trainer.evaluate(batch, device)
            if B % 5 == 0:
                # Progress indicator on stderr, overwritten in place via '\r'.
                print('# {:.1%}'.format(B / len(iterator_test)),
                      end='\r', file=sys.stderr)
        # Blank out the progress line once all batches are done.
        print(' ' * 150, end='\r', file=sys.stderr)

        ### print log
        trainer.save_outputs(save_prefix)
        trainer.log(data_cfg.data_idx, output)

        end = Print('end evaluating a model', output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)
    finally:
        # Close the log stream even if evaluation failed part-way, so that
        # whatever was logged is flushed; never close the shared stdout.
        if output is not sys.stdout:
            output.close()
def main():
    """Train a model on the training split, checkpointing after every epoch.

    Parses the command-line arguments through the module-level ``parser``,
    builds the data/model/run configurations, loads the ``"train"`` dataset,
    sets up a ``Trainer`` (checkpoint, class weights, device, optimizer and
    scheduler) and then trains from ``trainer.epoch`` up to
    ``run_cfg.num_epochs``. After each epoch the epoch counter is advanced,
    the model is saved under ``save_prefix`` and a log entry is written.

    The log stream returned by ``set_output`` is closed when the function
    exits (unless it is ``sys.stdout``), including when a step raises.
    """
    args = vars(parser.parse_args())
    check_args(args)
    set_seeds(2021)
    data_cfg = config.DataConfig(args["data_config"])
    model_cfg = config.ModelConfig(args["model_config"])
    run_cfg = config.RunConfig(args["run_config"], eval=False,
                               sanity_check=args["sanity_check"])
    output, save_prefix = set_output(args, "train_model_log")

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config.print_configs(args, [data_cfg, model_cfg, run_cfg], device, output)
        # NOTE(review): presumably touches the device early so that device
        # initialisation happens before any timing below -- confirm.
        torch.zeros((1)).to(device)

        ## Loading a dataset
        # Print's return values are subtracted to report the elapsed time.
        start = Print('start loading a dataset', output)
        dataset_train = get_dataset_from_configs(
            data_cfg, "train", model_cfg.embedder,
            sanity_check=args["sanity_check"])
        iterator_train = torch.utils.data.DataLoader(
            dataset_train, run_cfg.batch_size, shuffle=True,
            pin_memory=True, num_workers=4)
        end = Print(
            " ".join(['loaded', str(len(dataset_train)), 'dataset_train samples']),
            output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)

        ## initialize a model
        start = Print('start initializing a model', output)
        model, params = get_model(model_cfg, run_cfg)
        get_profile(model, dataset_train, output)
        end = Print('end initializing a model', output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)

        ## setup trainer configurations
        start = Print('start setting trainer configurations', output)
        trainer = Trainer(model)
        trainer.load_model(args["checkpoint"], output)
        trainer.set_class_weight(dataset_train.labels, run_cfg)
        trainer.set_device(device)
        trainer.set_optim_scheduler(run_cfg, params)
        end = Print('end setting trainer configurations', output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)

        ## train a model
        start = Print('start training a model', output)
        trainer.headline(output)
        # Start from trainer.epoch rather than 0, so a trainer whose epoch
        # counter is already advanced continues from that epoch.
        for epoch in range(int(trainer.epoch), run_cfg.num_epochs):
            ### train
            for B, batch in enumerate(iterator_train):
                trainer.train(batch, device)
                if B % 5 == 0:
                    # Progress indicator on stderr, overwritten in place via '\r'.
                    print('# epoch [{}/{}] train {:.1%}'.format(
                        epoch + 1, run_cfg.num_epochs, B / len(iterator_train)),
                        end='\r', file=sys.stderr)
            # Blank out the progress line once the epoch's batches are done.
            print(' ' * 150, end='\r', file=sys.stderr)

            ### print log and save models
            trainer.epoch += 1
            trainer.save_model(save_prefix)
            trainer.log(data_cfg.data_idx, output)

        end = Print('end training a model', output)
        Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)
    finally:
        # Close the log stream even if training failed part-way, so that
        # whatever was logged is flushed; never close the shared stdout.
        if output is not sys.stdout:
            output.close()