def test_load_checkpoint_with_different_num_classes(dirname):
    model = DummyPretrainedModel()
    to_save_single_object = {"model": model}

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, to_save_single_object)

    fname = handler.last_checkpoint
    loaded_checkpoint = torch.load(fname)

    to_load_single_object = {"pretrained_features": model.features}

    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint, strict=False, blah="blah")

    loaded_weights = to_load_single_object["pretrained_features"].state_dict()["weight"]
    assert torch.all(model.state_dict()["features.weight"].eq(loaded_weights))
def _resume_training(resume_from: Union[str, Path], to_save: Dict[str, Any]):
    if resume_from:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f'Checkpoint "{checkpoint_fp}" is not found'
        print(f"Resuming from a checkpoint: {checkpoint_fp}")
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
def _load_model(self):
    model = get_model(self.config)
    model.to(self.device)
    checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
    Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
    model.eval()
    return model
def __call__(self, engine):
    checkpoint = torch.load(self.load_path)
    if len(self.load_dict) == 1:
        key = list(self.load_dict.keys())[0]
        if key not in checkpoint:
            # The file stores a bare state_dict; wrap it under the expected key.
            checkpoint = {key: checkpoint}
    Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
    self.logger.info(f"Restored all variables from {self.load_path}")
def resume(self):
    d = Path(self.save_path)
    pattern = "checkpoint_*.pth"
    saves = list(d.glob(pattern))
    if len(saves) == 0:
        raise FileNotFoundError("No checkpoint to load in %s" % self.save_path)
    fp = max(saves, key=lambda f: f.stat().st_mtime)
    checkpoint = torch.load(fp)
    Checkpoint.load_objects(self.to_save(), checkpoint)
    print("Load trainer from %s" % fp)
def get_model(self, model_name, device, prefix='', path=None):
    if path is None:
        path = join(Constants.MODELS_PATH, model_name)
    best_file, best_loss = get_model_filename(path, prefix)
    model = copy.deepcopy(self.my_models[model_name])
    to_load = {'model': model}
    checkpoint = torch.load(join(path, best_file), map_location=device)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
    return model, best_loss, self.get_thresholds(model_name)
def extract_model(ckp_file, device='cuda' if torch.cuda.is_available() else 'cpu'):
    tokenizer = BertTokenizer.from_pretrained(base_model)
    model = BertClassificationModel(cls=tokenizer.vocab_size, model_file=base_model)
    to_load = {'BertClassificationModel': model}
    checkpoint = torch.load(ckp_file, map_location=device)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
    model.to(device)
    model.bert.save_pretrained('./extract_bert/')
def create_trainer_and_evaluators(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    data_loaders: Dict[str, DataLoader],
    metrics: Dict[str, Metric],
    config: ConfigSchema,
    logger: Logger,
) -> Tuple[Engine, Dict[str, Engine]]:
    trainer = get_trainer(model, criterion, optimizer)
    trainer.logger = logger

    evaluators = get_evaluators(model, metrics)
    setup_evaluation(trainer, evaluators, data_loaders, logger)

    lr_scheduler = get_lr_scheduler(config, optimizer, trainer, evaluators["val"])

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    common.setup_common_training_handlers(
        trainer=trainer,
        to_save=to_save,
        save_every_iters=config.checkpoint_every,
        save_handler=get_save_handler(config),
        with_pbars=False,
        train_sampler=data_loaders["train"].sampler,
    )
    trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)

    ProgressBar(persist=False).attach(
        trainer,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=config.log_every_iters),
    )

    resume_from = config.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer, evaluators
def load_trainer_from_checkpoint(self):
    if self.hparams.checkpoint_dir is not None:
        if not self.hparams.load_model_only:
            objects_to_checkpoint = {
                "trainer": self.trainer,
                "model": self.model,
                "optimizer": self.optimizer,
                "scheduler": self.scheduler,
            }
            if USE_AMP:
                objects_to_checkpoint["amp"] = amp
        else:
            objects_to_checkpoint = {"model": self.model}
        objects_to_checkpoint = {k: v for k, v in objects_to_checkpoint.items() if v is not None}
        checkpoint = torch.load(self.hparams.checkpoint_dir, map_location="cpu")
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)
def _test_checkpoint_load_objects_ddp(device):
    model = DummyModel().to(device)
    device_ids = None if "cpu" in device.type else [device]
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # single object:
    to_load = {"model": ddp_model}
    checkpoint = ddp_model.module.state_dict()
    Checkpoint.load_objects(to_load, checkpoint)

    # multiple objects:
    to_load = {"model": ddp_model, "opt": opt}
    checkpoint = {"model": ddp_model.module.state_dict(), "opt": opt.state_dict()}
    Checkpoint.load_objects(to_load, checkpoint)
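# The DDP test above relies on Checkpoint.load_objects resolving a
# DataParallel/DistributedDataParallel wrapper before calling load_state_dict,
# so a checkpoint produced from the bare module can be loaded into the wrapped
# model. A minimal sketch of the same round trip without a distributed
# launcher; the model here is a placeholder, not taken from the snippets above:
import torch
import torch.nn as nn
from ignite.handlers import Checkpoint

# Hypothetical two-layer model, used only for illustration.
net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# Simulate a checkpoint saved from the bare (unwrapped) module ...
checkpoint = {"model": net.state_dict()}

# ... and load it back. With a DDP-wrapped model, load_objects loads into
# the underlying module, matching the test above.
Checkpoint.load_objects(to_load={"model": net}, checkpoint=checkpoint)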
def resume_from_checkpoint(to_save, conf, device=None):
    # type: (Dict[str, Any], DictConfig, Device) -> None
    to_load = {k: v for k, v in to_save.items() if v is not None}
    if conf.drop_state:
        # we might want to swap the optimizer or to reset its state
        drop_keys = set(conf.drop_state)
        to_load = {k: v for k, v in to_load.items() if k not in drop_keys}
    checkpoint = torch.load(conf.load, map_location=device)
    ema_key = "model_ema"
    if ema_key in to_load and ema_key not in checkpoint:
        checkpoint[ema_key] = checkpoint["model"]
        logging.warning(
            "There are no EMA weights in the checkpoint. "
            "Using saved model weights as a starting point for the EMA."
        )
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
def test_checkpoint_load_objects_from_saved_file(dirname):
    def _get_single_obj_to_save():
        model = DummyModel()
        to_save = {"model": model}
        return to_save

    def _get_multiple_objs_to_save():
        model = DummyModel()
        optim = torch.optim.SGD(model.parameters(), lr=0.001)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)
        to_save = {"model": model, "optimizer": optim, "lr_scheduler": lr_scheduler}
        return to_save

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    # case: multiple objects
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    to_save = _get_multiple_objs_to_save()
    handler(trainer, to_save)
    fname = handler.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)
    loaded_objects = torch.load(fname)
    Checkpoint.load_objects(to_save, loaded_objects)
    os.remove(fname)

    # case: saved multiple objects, loaded single object
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    to_save = _get_multiple_objs_to_save()
    handler(trainer, to_save)
    fname = handler.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)
    loaded_objects = torch.load(fname)
    to_load = {'model': to_save['model']}
    Checkpoint.load_objects(to_load, loaded_objects)
    os.remove(fname)

    # case: single object
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    to_save = _get_single_obj_to_save()
    handler(trainer, to_save)
    fname = handler.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)
    loaded_objects = torch.load(fname)
    Checkpoint.load_objects(to_save, loaded_objects)
def resume(self, fp=None):
    assert self._trainer_state == TrainerState.INIT
    if fp is None:
        d = Path(self.save_path)
        pattern = "checkpoint_*.pt*"
        saves = list(d.glob(pattern))
        if len(saves) == 0:
            raise FileNotFoundError("No checkpoint to load in %s" % self.save_path)
        fp = max(saves, key=lambda f: f.stat().st_mtime)
    checkpoint = torch.load(fp)
    if not self.fp16 and 'amp' in checkpoint:
        del checkpoint['amp']
    Checkpoint.load_objects(self.to_save(), checkpoint)
    self._train_engine_state = checkpoint['train_engine']
    self._eval_engine_state = checkpoint['eval_engine']
    self._trainer_state = TrainerState.FITTING
    print("Load trainer from %s" % fp)
def inference(test_loader, metrics):
    if os.listdir(opt.checkpoint_dir):
        for root, dirs, files in os.walk(opt.checkpoint_dir):
            # keeps the last checkpoint file of the last directory visited
            checkpoint_file = os.path.join(root, files[-1])
        checkpoint = torch.load(checkpoint_file)
        model = tv.models.resnet50()
        model.fc = nn.Linear(2048, 2)
        object_to_checkpoint = {'model': model}
        Checkpoint.load_objects(to_load=object_to_checkpoint, checkpoint=checkpoint)
        test = create_supervised_evaluator(model, metrics)

        @test.on(Events.COMPLETED)
        def get_result(engine):
            preds, labels = test.state.metrics['test_result']
            preds = preds.reshape((-1, 1))
            labels = labels.reshape((-1, 1))
            preds = np.clip(preds, 0.005, 0.995)
            result = np.concatenate((labels, preds), axis=1)
            result = pd.DataFrame(result, columns=['id', 'label'])
            print(result)
            result.to_csv(os.path.join("..", "result.csv"), index=False)
            return result

        test.run(test_loader)
def test_checkpoint_load_objects():
    with pytest.raises(TypeError, match=r"Argument checkpoint should be a dictionary"):
        Checkpoint.load_objects({}, [])

    with pytest.raises(TypeError, match=r"should have `load_state_dict` method"):
        Checkpoint.load_objects({"a": None}, {"a": None})

    model = DummyModel()
    to_load = {'model': model}

    with pytest.raises(ValueError, match=r"from `to_load` is not found in the checkpoint"):
        Checkpoint.load_objects(to_load, {})

    model = DummyModel()
    to_load = {'model': model}
    model2 = DummyModel()

    chkpt = {'model': model2.state_dict()}
    Checkpoint.load_objects(to_load, chkpt)
    assert model.state_dict() == model2.state_dict()
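# Many of the training scripts below follow the same idiom the tests above
# exercise: checkpoint a dict of stateful objects, then restore it with
# Checkpoint.load_objects using the same keys. A minimal, self-contained
# sketch of that round trip; the file name and objects are illustrative,
# not taken from any snippet here:
import torch
import torch.nn as nn
from ignite.handlers import Checkpoint

model = nn.Linear(10, 2)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save: a checkpoint is just a dict of state_dicts keyed like to_save.
to_save = {"model": model, "optimizer": optimizer}
torch.save({k: v.state_dict() for k, v in to_save.items()}, "ckpt.pt")

# Load: every key in to_load must be present in the checkpoint dict,
# otherwise load_objects raises a ValueError (as the test above shows).
checkpoint = torch.load("ckpt.pt", map_location="cpu")
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)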
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        x, y = batch[0], batch[1]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {"batch loss": loss.item()}

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = ["batch loss"]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
train_timer.attach(trainer,
                   start=Events.EPOCH_STARTED,
                   resume=Events.EPOCH_STARTED,
                   pause=Events.EPOCH_COMPLETED,
                   step=Events.EPOCH_COMPLETED)

if len(args.load_model) > 0:
    load_model_path = args.load_model
    print("load model " + load_model_path)
    to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
    checkpoint = torch.load(load_model_path)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
    print("load model complete")
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr
    print("change lr to ", args.lr)
else:
    print("do not load, keep training")

@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training_loss(trainer):
    timestamp = get_readable_time()
    print(timestamp + " Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=1000,
):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": Accuracy(), "nll": Loss(criterion)}, device=device
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    desc = "ITERATION - loss: {:.4f} - lr: {:.4f}"
    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0, lr))

    if log_interval is None:
        e = Events.ITERATION_COMPLETED
        log_interval = 1
    else:
        e = Events.ITERATION_COMPLETED(every=log_interval)

    @trainer.on(e)
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = desc.format(engine.state.output, lr)
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if resume_from is None:

        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception("STOP at {}".format(engine.state.iteration))

    else:

        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            "Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_accuracy, engine.state.epoch)

    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
    )
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=checkpoint_every), training_checkpoint)

    if resume_from is not None:
        tqdm.write("Resume from a checkpoint: {}".format(resume_from))
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = ["supervised batch loss"]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def main():
    # region Setup
    conf = parse_args()
    setup_seeds(conf.session.seed)
    tb_logger, tb_img_logger, json_logger = setup_all_loggers(conf)
    logger.info(
        "Parsed configuration:\n"
        + pyaml.dump(OmegaConf.to_container(conf), safe=True, sort_dicts=False, force_embed=True)
    )

    # region Predicate classification engines
    datasets, dataset_metadata = build_datasets(conf.dataset)
    dataloaders = build_dataloaders(conf, datasets)

    model = build_model(conf.model, dataset_metadata["train"]).to(conf.session.device)
    criterion = PredicateClassificationCriterion(conf.losses)

    pred_class_trainer = Trainer(pred_class_training_step, conf)
    pred_class_trainer.model = model
    pred_class_trainer.criterion = criterion
    pred_class_trainer.optimizer, scheduler = build_optimizer_and_scheduler(
        conf.optimizer, pred_class_trainer.model
    )

    pred_class_validator = Validator(pred_class_validation_step, conf)
    pred_class_validator.model = model
    pred_class_validator.criterion = criterion

    pred_class_tester = Validator(pred_class_validation_step, conf)
    pred_class_tester.model = model
    pred_class_tester.criterion = criterion
    # endregion

    if "resume" in conf:
        checkpoint = Path(conf.resume.checkpoint).expanduser().resolve()
        logger.debug(f"Resuming checkpoint from {checkpoint}")
        Checkpoint.load_objects(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            checkpoint=torch.load(checkpoint, map_location=conf.session.device),
        )
        logger.info(
            f"Resumed from {checkpoint}, "
            f"epoch {pred_class_trainer.state.epoch}, "
            f"samples {pred_class_trainer.global_step()}"
        )
    # endregion

    # region Predicate classification training callbacks
    def increment_samples(trainer: Trainer):
        images = trainer.state.batch[0]
        trainer.state.samples += len(images)

    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED, increment_samples)

    ProgressBar(persist=True, desc="Pred class train").attach(
        pred_class_trainer, output_transform=itemgetter("losses")
    )

    tb_logger.attach(
        pred_class_trainer,
        OptimizerParamsHandler(
            pred_class_trainer.optimizer,
            param_name="lr",
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_STARTED,
    )

    pred_class_trainer.add_event_handler(
        Events.ITERATION_COMPLETED, PredicateClassificationMeanAveragePrecisionBatch()
    )
    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED, RecallAtBatch(sizes=(5, 10)))

    tb_logger.attach(
        pred_class_trainer,
        OutputHandler(
            "train",
            output_transform=lambda o: {
                **o["losses"],
                "pc/mAP": o["pc/mAP"].mean().item(),
                **{k: r.mean().item() for k, r in o["recalls"].items()},
            },
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.ITERATION_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification training",
        "train",
        json_logger=None,
        tb_logger=tb_logger,
        global_step_fn=pred_class_trainer.global_step,
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="train",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["train"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    tb_logger.attach(
        pred_class_trainer,
        EpochHandler(
            pred_class_trainer,
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED, lambda _: pred_class_validator.run(dataloaders["val"])
    )
    # endregion

    # region Predicate classification validation callbacks
    ProgressBar(persist=True, desc="Pred class val").attach(pred_class_validator)

    if conf.losses["bce"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/bce"]).attach(
            pred_class_validator, "loss/bce"
        )
    if conf.losses["rank"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/rank"]).attach(
            pred_class_validator, "loss/rank"
        )
    Average(output_transform=lambda o: o["losses"]["loss/total"]).attach(
        pred_class_validator, "loss/total"
    )
    PredicateClassificationMeanAveragePrecisionEpoch(itemgetter("target", "output")).attach(
        pred_class_validator, "pc/mAP"
    )
    RecallAtEpoch((5, 10), itemgetter("target", "output")).attach(
        pred_class_validator, "pc/recall_at"
    )

    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda val_engine: scheduler.step(val_engine.state.metrics["loss/total"]),
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification validation",
        "val",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="val",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["val"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(
            patience=conf.session.early_stopping.patience,
            score_function=lambda val_engine: -val_engine.state.metrics["loss/total"],
            trainer=pred_class_trainer,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        Checkpoint(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            DiskSaver(Path(conf.checkpoint.folder).expanduser().resolve() / conf.fullname),
            score_function=lambda val_engine: val_engine.state.metrics["pc/recall_at_5"],
            score_name="pc_recall_at_5",
            n_saved=conf.checkpoint.keep,
            global_step_transform=pred_class_trainer.global_step,
        ),
    )
    # endregion

    if "test" in conf.dataset:
        # region Predicate classification testing callbacks
        if conf.losses["bce"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/bce"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/bce")
        if conf.losses["rank"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/rank"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/rank")
        Average(
            output_transform=lambda o: o["losses"]["loss/total"],
            device=conf.session.device,
        ).attach(pred_class_tester, "loss/total")
        PredicateClassificationMeanAveragePrecisionEpoch(itemgetter("target", "output")).attach(
            pred_class_tester, "pc/mAP"
        )
        RecallAtEpoch((5, 10), itemgetter("target", "output")).attach(
            pred_class_tester, "pc/recall_at"
        )

        ProgressBar(persist=True, desc="Pred class test").attach(pred_class_tester)

        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Predicate classification test",
            "test",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            PredicateClassificationLogger(
                grid=(2, 3),
                tag="test",
                logger=tb_img_logger.writer,
                metadata=dataset_metadata["test"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        # endregion

    # region Run
    log_effective_config(conf, pred_class_trainer, tb_logger)
    if not ("resume" in conf and conf.resume.test_only):
        max_epochs = conf.session.max_epochs
        if "resume" in conf:
            max_epochs += pred_class_trainer.state.epoch
        pred_class_trainer.run(
            dataloaders["train"],
            max_epochs=max_epochs,
            seed=conf.session.seed,
            epoch_length=len(dataloaders["train"]),
        )

    if "test" in conf.dataset:
        pred_class_tester.run(dataloaders["test"])

    add_session_end(tb_logger.writer, "SUCCESS")
    tb_logger.close()
    tb_img_logger.close()
    # endregion
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    def train_step(engine, batch):
        x, y = batch[0], batch[1]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # This can be helpful for XLA to avoid a performance slowdown
        # caused by fetching loss.item() on every iteration
        if config["log_every_iters"] > 0 and (engine.state.iteration - 1) % config["log_every_iters"] == 0:
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {"batch loss": batch_loss}

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = ["batch loss"]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=-1,
    deterministic=False,
):
    # Setup seed to have same model's initialization:
    manual_seed(75)

    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

    # Setup trainer and evaluator
    if deterministic:
        tqdm.write("Setup deterministic trainer")
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device, deterministic=deterministic)

    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": Accuracy(), "nll": Loss(criterion)}, device=device
    )

    # Apply learning rate scheduling
    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=f"Epoch {0} - loss: {0:.4f} - lr: {lr:.4f}")

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = f"Epoch {engine.state.epoch} - loss: {engine.state.output:.4f} - lr: {lr:.4f}"
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if crash_iteration > 0:

        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception(f"STOP at {engine.state.iteration}")

    if resume_from is not None:

        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration % engine.state.epoch_length

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Training Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    # Compute and log validation metrics
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # Setup object to checkpoint
    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_every), training_checkpoint)

    # Setup logger to print and dump into file: model weights, model grads and data stats
    # - first 3 iterations
    # - 4 iterations after checkpointing
    # This helps to compare resumed training with checkpointed training
    def log_event_filter(e, event):
        if event in [1, 2, 3]:
            return True
        elif 0 <= (event % (checkpoint_every * e.state.epoch_length)) < 5:
            return True
        return False

    fp = Path(log_dir) / ("run.log" if resume_from is None else "resume_run.log")
    fp = fp.as_posix()
    for h in [log_data_stats, log_model_weights, log_model_grads]:
        trainer.add_event_handler(Events.ITERATION_COMPLETED(event_filter=log_event_filter), h, model=model, fp=fp)

    if resume_from is not None:
        tqdm.write(f"Resume from the checkpoint: {resume_from}")
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)

    try:
        # Synchronize random states
        manual_seed(15)
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
def setup(self):
    self.trainer = Engine(self.train_step)
    self.evaluator = Engine(self.eval_step)
    self.logger._init_logger(self.trainer, self.evaluator)

    # TODO: Multi-gpu support
    self.model.to(self.device)
    if self.use_amp:
        self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")

    if self.checkpoint_dir is not None:
        if not self.load_model_only:
            objects_to_checkpoint = {
                "trainer": self.trainer,
                "model": self.model,
                "optimizer": self.optimizer,
                "scheduler": self.scheduler,
            }
            if self.use_amp:
                objects_to_checkpoint["amp"] = amp
        else:
            objects_to_checkpoint = {"model": self.model}
        objects_to_checkpoint = {k: v for k, v in objects_to_checkpoint.items() if v is not None}
        checkpoint = torch.load(self.checkpoint_dir)
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)

    train_handler_params = {
        "model": self.model,
        "optimizer": self.optimizer,
        "scheduler": self.scheduler,
        "metrics": self.train_metrics,
        "add_pbar": self.add_pbar,
    }
    self.logger._add_train_events(**train_handler_params)

    to_save = {
        "model": self.model,
        "trainer": self.trainer,
        "optimizer": self.optimizer,
        "scheduler": self.scheduler,
    }
    if self.use_amp:
        to_save["amp"] = amp

    eval_handler_params = {
        "metrics": self.validation_metrics,
        "validloader": self.val_loader,
        "to_save": to_save,
        "add_pbar": self.add_pbar,
    }
    eval_handler_params["to_save"] = {k: v for k, v in eval_handler_params["to_save"].items() if v is not None}
    self.logger._add_eval_events(**eval_handler_params)

    if self.scheduler:
        self.trainer.add_event_handler(Events.ITERATION_STARTED, self.scheduler)

    self.trainer.logger = setup_logger("trainer")
    self.evaluator.logger = setup_logger("evaluator")
def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--model", type=str, default='ffn', help="model's name")
    parser.add_argument("--mode", type=int, choices=[0, 1, 2], default=None)
    parser.add_argument("--SNRdb", type=float, default=None)
    parser.add_argument("--pilot_version", type=int, choices=[1, 2], default=1)
    parser.add_argument("--loss_type", type=str, default="BCELoss")
    parser.add_argument("--train_batch_size", type=int, default=128)
    parser.add_argument("--valid_batch_size", type=int, default=128)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_norm", type=float, default=-1)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--noise_lambda", type=float, default=1.0)
    parser.add_argument("--lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default="linear")
    parser.add_argument("--reset_lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default=None)
    parser.add_argument("--reset_trainer", action='store_true')
    parser.add_argument("--modify_model", action='store_true')
    parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
    parser.add_argument("--eval_iter", type=int, default=10)
    parser.add_argument("--save_iter", type=int, default=10)
    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--flush_dataset", type=int, default=0)
    parser.add_argument("--no_cache", action='store_true')
    parser.add_argument("--with_pure_y", action='store_true')
    parser.add_argument("--with_h", action='store_true')
    parser.add_argument("--only_l1", action='store_true', help="Only loss 1")
    parser.add_argument("--interpolation", action='store_true', help="if interpolate between pure and reconstruction.")
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--cache_dir", type=str, default="train_cache")
    parser.add_argument("--output_path", type=str, default="runs", help="model save")
    parser.add_argument("--resume_from", type=str, default=None, help="resume training.")
    parser.add_argument("--first_cache_index", type=int, default=0)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--seed", type=int, default=43)
    parser.add_argument("--debug", action='store_true')
    args = parser.parse_args()

    args.output_path = os.path.join(args.output_path, f'pilot_{args.pilot_version}')
    args.cache_dir = os.path.join(args.data_dir, args.cache_dir)

    # Setup CUDA, GPU & distributed training
    args.distributed = (args.local_rank != -1)
    if not args.distributed:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')
    args.n_gpu = torch.cuda.device_count() if not args.distributed else 1
    args.device = device

    # Set seed
    set_seed(args)
    logger = setup_logger("trainer", distributed_rank=args.local_rank)

    # Model construction
    model = getattr(models, args.model)(args)
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    if args.loss_type == "MSELoss":
        criterion = nn.MSELoss(reduction='sum').to(device)
    else:
        criterion = getattr(nn, args.loss_type, getattr(auxiliary, args.loss_type, None))().to(device)
    criterion2 = nn.MSELoss(reduction='sum').to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    train_dataset = SIGDataset(args, data_type="train")
    valid_dataset = SIGDataset(args, data_type="valid")
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,
                              pin_memory=True, shuffle=(not args.distributed))
    valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size,
                              pin_memory=True, shuffle=False)

    lr_scheduler = None
    if args.lr_scheduler == "linear":
        lr_scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    elif args.lr_scheduler == "cycle":
        lr_scheduler = LinearCyclicalScheduler(optimizer, 'lr', 0.0, args.lr, args.eval_iter * len(train_loader))
    elif args.lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, args.eval_iter * len(train_loader))

    # Training function and trainer
    def update(engine, batch):
        model.train()
        y, x_label, y_pure, H = train_dataset.prepare_batch(batch, device=args.device)
        if args.with_pure_y and args.with_h:
            x_pred, y_pure_pred, H_pred = model(y, pure=y_pure, H=H, opp=True)
            loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
            if args.loss_type == "MSELoss":
                loss_1 = loss_1 / x_pred.size(0)
            loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
            loss_noise_h = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
            if args.only_l1:
                loss = loss_1
            else:
                loss = loss_1 + loss_noise * args.noise_lambda + loss_noise_h
            output = (loss.item(), loss_1.item(), loss_noise.item(), loss_noise_h.item())
        elif args.with_pure_y:
            x_pred, y_pure_pred = model(y, pure=y_pure if args.interpolation else None, opp=True)
            loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
            loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
            loss = loss_1 + loss_noise * args.noise_lambda
            output = (loss.item(), loss_1.item(), loss_noise.item())
        elif args.with_h:
            x_pred, H_pred = model(y, opp=True)
            loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
            loss_noise = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
            loss = loss_1 + loss_noise * args.noise_lambda
            output = (loss.item(), loss_1.item(), loss_noise.item())
        else:
            x_pred = model(y)
            loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
            loss = loss_1
            output = (loss.item(), loss_1.item(), torch.zeros_like(loss_1).item())

        loss.backward()
        if args.max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return output

    trainer = Engine(update)

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = ["loss", "l1", "ln"]
    if args.with_pure_y and args.with_h:
        metric_names.append("lnH")

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_loader.sampler,
        to_save=to_save,
        save_every_iters=len(train_loader) * args.save_iter,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbars=False,
        clear_cuda_cache=False,
        output_path=args.output_path,
        n_saved=2,
    )

    resume_from = args.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        if args.reset_trainer:
            to_save.pop("trainer")
        checkpoint_to_load = to_save if 'validation' not in resume_from else {"model": model}
        Checkpoint.load_objects(to_load=checkpoint_to_load, checkpoint=checkpoint)
        if args.reset_lr_scheduler is not None:
            if args.reset_lr_scheduler == "linear":
                lr_scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
            elif args.reset_lr_scheduler == "cycle":
                lr_scheduler = LinearCyclicalScheduler(optimizer, 'lr', 0.0, args.lr, args.eval_iter * len(train_loader))
            elif args.reset_lr_scheduler == "cosine":
                lr_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, args.eval_iter * len(train_loader))

    metrics = {
        "accuracy": Accuracy(lambda output: (torch.round(output[0][0]), output[1][0])),
        "loss_1": Loss(criterion, output_transform=lambda output: (output[0][0], output[1][0])),
        "loss_noise": Loss(criterion2, output_transform=lambda output: (output[0][1], output[1][1])),
    }
    if args.with_pure_y and args.with_h:
        metrics["loss_noise_h"] = Loss(criterion2, output_transform=lambda output: (output[0][2], output[1][2]))

    def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.eval()
        with torch.no_grad():
            x, y, x_pure, H = valid_dataset.prepare_batch(batch, device=args.device, non_blocking=True)
            if args.with_pure_y and args.with_h:
                y_pred, x_pure_pred, h_pred = model(x, opp=True)
                outputs = (y_pred, x_pure_pred, h_pred), (y, x_pure, H)
            elif args.with_pure_y:
                y_pred, x_pure_pred = model(x, opp=True)
                outputs = (y_pred, x_pure_pred), (y, x_pure)
            elif args.with_h:
                y_pred, h_pred = model(x, opp=True)
                outputs = (y_pred, h_pred), (y, H)
            else:
                y_pred = model(x)
                x_pure_pred = x_pure
                outputs = (y_pred, x_pure_pred), (y, x_pure)
        return outputs

    evaluator = Engine(_inference)
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.eval_iter), lambda _: evaluator.run(valid_loader))

    if args.flush_dataset > 0:
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=args.n_epochs // args.flush_dataset),
            lambda _: train_loader.dataset.reset() if args.no_cache else train_loader.dataset.reload(),
        )

    # On the main process: add progress bar, tensorboard, checkpoints
    # and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=metric_names,
                    output_transform=lambda _: {"lr": f"{optimizer.param_groups[0]['lr']:.2e}"})
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics))
        )
        tb_logger = common.setup_tb_logging(args.output_path, trainer, optimizer,
                                            evaluators={'validation': evaluator}, log_every_iters=1)

        # Store 3 best models by validation accuracy:
        common.gen_save_best_models_by_val_score(
            save_handler=DiskSaver(args.output_path, require_empty=False),
            evaluator=evaluator,
            models={"model": model},
            metric_name="accuracy",
            n_saved=3,
            trainer=trainer,
            tag="validation",
        )

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    if args.local_rank in [-1, 0]:
        tb_logger.close()
def main(args):
    fix_seeds()
    # if os.path.exists('./logs'):
    #     shutil.rmtree('./logs')
    # os.mkdir('./logs')
    # writer = SummaryWriter(log_dir='./logs')
    vis = visdom.Visdom()
    val_avg_loss_window = create_plot_window(vis, '#Epochs', 'Loss', 'Average Loss', legend=['Train', 'Val'])
    val_avg_accuracy_window = create_plot_window(vis, '#Epochs', 'Accuracy', 'Average Accuracy', legend=['Val'])

    size = (args.height, args.width)
    train_transform = transforms.Compose([
        transforms.Resize(size),
        # transforms.RandomResizedCrop(size=size, scale=(0.5, 1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(10, translate=(0.1, 0.1), scale=(0.8, 1.2), resample=PIL.Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_dataset = TextDataset(args.data_path, 'train.txt', size=args.train_size, transform=train_transform)
    val_dataset = TextDataset(args.data_path, 'val.txt', size=args.val_size, transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.workers, shuffle=False)

    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(512, 16)
    model.load_state_dict(torch.load(args.resume_from)['model'])

    device = 'cpu'
    if args.cuda:
        device = 'cuda'
    print(device)

    metrics = {'accuracy': Accuracy(), 'loss': Loss(criterion)}
    evaluator = create_supervised_evaluator(model, metrics, device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def lr_step(engine):
        if model.training:
            scheduler.step()

    global pbar, desc
    pbar, desc = None, None

    @trainer.on(Events.EPOCH_STARTED)
    def create_train_pbar(engine):
        global desc, pbar
        if pbar is not None:
            pbar.close()
        desc = 'Train iteration - loss: {:.4f} - lr: {:.4f}'
        pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0, lr))

    @trainer.on(Events.EPOCH_COMPLETED)
    def create_val_pbar(engine):
        global desc, pbar
        if pbar is not None:
            pbar.close()
        desc = 'Validation iteration - loss: {:.4f}'
        pbar = tqdm(initial=0, leave=False, total=len(val_loader), desc=desc.format(0))

    # desc_val = 'Validation iteration - loss: {:.4f}'
    # pbar_val = tqdm(initial=0, leave=False, total=len(val_loader), desc=desc_val.format(0))

    log_interval = 1
    e = Events.ITERATION_COMPLETED(every=log_interval)
    train_losses = []

    @trainer.on(e)
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]['lr']
        train_losses.append(engine.state.output)
        pbar.desc = desc.format(engine.state.output, lr)
        pbar.update(log_interval)
        # writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
        # writer.add_scalar("lr", lr, engine.state.iteration)

    @evaluator.on(e)
    def log_validation_loss(engine):
        label = engine.state.batch[1].to(device)
        output = engine.state.output[0]
        pbar.desc = desc.format(criterion(output, label))
        pbar.update(log_interval)

    # if args.resume_from is not None:
    #     @trainer.on(Events.STARTED)
    #     def _(engine):
    #         pbar.n = engine.state.iteration

    # @trainer.on(Events.EPOCH_COMPLETED(every=1))
    # def log_train_results(engine):
    #     evaluator.run(train_loader)  # eval on train set to check for overfitting
    #     metrics = evaluator.state.metrics
    #     avg_accuracy = metrics['accuracy']
    #     avg_nll = metrics['loss']
    #     tqdm.write(
    #         "Train Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
    #         .format(engine.state.epoch, avg_accuracy, avg_nll))
    #     pbar.n = pbar.last_print_n = 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.refresh()
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['loss']
        tqdm.write(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        # pbar.n = pbar.last_print_n = 0
        # writer.add_scalars("avg losses", {"train": statistics.mean(train_losses),
        #                                   "valid": avg_nll}, engine.state.epoch)
        # writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        # writer.add_scalar("avg_accuracy", avg_accuracy, engine.state.epoch)
        vis.line(X=np.array([engine.state.epoch]), Y=np.array([avg_accuracy]),
                 win=val_avg_accuracy_window, update='append')
        vis.line(X=np.column_stack((np.array([engine.state.epoch]), np.array([engine.state.epoch]))),
                 Y=np.column_stack((np.array([statistics.mean(train_losses)]), np.array([avg_nll]))),
                 win=val_avg_loss_window, update='append', opts=dict(legend=['Train', 'Val']))
        del train_losses[:]

    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
    }
    training_checkpoint = Checkpoint(to_save=objects_to_checkpoint,
                                     save_handler=DiskSaver(args.snapshot_dir, require_empty=False))
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), training_checkpoint)

    if args.resume_from not in [None, '']:
        tqdm.write("Resume from a checkpoint: {}".format(args.resume_from))
        checkpoint = torch.load(args.resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=args.epochs)
        pbar.close()
    except Exception as e:
        import traceback

        print(traceback.format_exc())
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    cutmix_beta = config["cutmix_beta"]
    cutmix_prob = config["cutmix_prob"]
    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        x, y = batch[0], batch[1]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            r = torch.rand(1).item()
            if cutmix_beta > 0 and r < cutmix_prob:
                output, loss = utils.cutmix_forward(model, x, criterion, y, cutmix_beta)
            else:
                output = model(x)
                loss = criterion(output, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if idist.backend() == "horovod":
            optimizer.synchronize()
            with optimizer.skip_synchronize():
                scaler.step(optimizer)
                scaler.update()
        else:
            scaler.step(optimizer)
            scaler.update()

        return {"batch loss": loss.item()}

    trainer = Engine(train_step)
    trainer.logger = logger
    if config["with_pbar"] and idist.get_rank() == 0:
        ProgressBar().attach(trainer)

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = ["batch loss"]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):
    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED)

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {"sup_batch": sup_batch}
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch
        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [deserialize(p) for p in cta_probe_batch["policy"]]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(to_list_str(torch.tensor(wn[:10] + wn[-10:])))
            msg += "\n- EMA model: {}\n".format(to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:])))
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {"accuracy": Accuracy()}

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()])
        msg2 = "\n".join(["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()])
        logger.info("\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2))
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every) | Events.STARTED | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define the training step
    # - add other common handlers:
    #    - TerminateOnNan
    #    - handler to setup learning rate scheduling
    #    - ModelCheckpoint
    #    - RunningAverage on `train_step` output
    #    - two progress bars: on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        labels = batch["label"].view(-1, 1)

        if input_ids.device != device:
            input_ids = input_ids.to(device, non_blocking=True, dtype=torch.long)
            attention_mask = attention_mask.to(device, non_blocking=True, dtype=torch.long)
            token_type_ids = token_type_ids.to(device, non_blocking=True, dtype=torch.long)
            labels = labels.to(device, non_blocking=True, dtype=torch.float)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(y_pred, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]

    if config["log_every_iters"] == 0:
        # Disable logging training metrics:
        metric_names = None
        config["log_every_iters"] = 15

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=utils.get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        log_every_iters=config["log_every_iters"],
        with_pbars=not config["with_clearml"],
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
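The save/resume pairing that all of these snippets rely on can be reproduced in isolation: whatever mapping is passed as `to_save` when checkpointing must be passed (with the same keys) as `to_load` when resuming. A minimal, self-contained sketch; the directory, model, and no-op training step are illustrative only:

import torch
import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trainer = Engine(lambda engine, batch: None)  # no-op training step for the demo

# Save trainer, model and optimizer state every 50 iterations
to_save = {"trainer": trainer, "model": model, "optimizer": optimizer}
handler = Checkpoint(to_save, DiskSaver("/tmp/ckpts", require_empty=False), n_saved=2)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=50), handler)
trainer.run(range(100), max_epochs=1)

# Resume later: the keys of `to_load` must match the keys used when saving
checkpoint = torch.load(f"/tmp/ckpts/{handler.last_checkpoint}", map_location="cpu")
Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)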
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model", type=str, default='ffn', help="model's name")
    parser.add_argument("--checkpoint", type=str, required=True, help="checkpoint file path")
    parser.add_argument("--pilot_version", type=int, choices=[1, 2], default=1)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--seed", type=int, default=43)
    parser.add_argument("--debug", action='store_true')
    args = parser.parse_args()

    # Setup CUDA, GPU & distributed training
    args.distributed = (args.local_rank != -1)
    if not args.distributed:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        # Initialize the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method='env://')
    args.n_gpu = torch.cuda.device_count() if not args.distributed else 1
    args.device = device

    # Set seed
    set_seed(args)
    logger = setup_logger("Testing", distributed_rank=args.local_rank)

    # Model construction
    model = getattr(models, args.model)(args)
    checkpoint_fp = Path(args.checkpoint)
    assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
    logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
    checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
    Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
    model = model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    datapath = f'data/Y_{args.pilot_version}.csv'
    dataY = pd.read_csv(datapath, header=None).values
    test_dataset = torch.tensor(dataY, dtype=torch.float32)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, pin_memory=True, shuffle=False)

    pred = []
    model.eval()
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for batch in tqdm(test_loader, desc="Running Testing"):
            batch = batch.to(device)
            x_pred = model(batch)
            x_pred = x_pred > 0.5
            pred.append(x_pred.cpu().numpy())

    np.concatenate(pred).tofile(f'{os.path.split(args.checkpoint)[0]}/X_pre_{args.pilot_version}.bin')
    if args.debug:
        np.ones_like(np.concatenate(pred)).tofile(f'{os.path.split(args.checkpoint)[0]}/X_pre_2.bin')
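The `set_seed` helper called above is not shown in the snippet. A plausible sketch, assuming the common pattern of seeding the Python, NumPy and PyTorch RNGs from `args.seed` (the `args` attributes match those set in `main`):

import random
import numpy as np
import torch

def set_seed(args):
    # Seed all RNGs so that test runs are reproducible
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)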
def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.general.gpu)

    device = torch.device('cuda')
    loader_args = dict()
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)

    train_dl = create_train_loader(conf.data, **loader_args)
    if epoch_length < 1:
        epoch_length = len(train_dl)

    metric_names = list(conf.logging.stats)
    metrics = create_metrics(metric_names, device if distributed else None)

    G = instantiate(conf.model.G).to(device)
    D = instantiate(conf.model.D).to(device)
    G_loss = instantiate(conf.loss.G).to(device)
    D_loss = instantiate(conf.loss.D).to(device)
    G_opt = instantiate(conf.optim.G, G.parameters())
    D_opt = instantiate(conf.optim.D, D.parameters())

    G_ema = None
    if master_node and conf.G_smoothing.enabled:
        G_ema = instantiate(conf.model.G)
        if not conf.G_smoothing.use_cpu:
            G_ema = G_ema.to(device)
        G_ema.load_state_dict(G.state_dict())
        G_ema.requires_grad_(False)

    to_save = {
        'G': G,
        'D': D,
        'G_loss': G_loss,
        'D_loss': D_loss,
        'G_opt': G_opt,
        'D_opt': D_opt,
        'G_ema': G_ema,
    }

    if master_node and conf.logging.model:
        logging.info(G)
        logging.info(D)

    if distributed:
        ddp_kwargs = dict(device_ids=[local_rank], output_device=local_rank)
        G = torch.nn.parallel.DistributedDataParallel(G, **ddp_kwargs)
        D = torch.nn.parallel.DistributedDataParallel(D, **ddp_kwargs)

    train_options = {
        'train': dict(conf.train),
        'snapshot': dict(conf.snapshots),
        'smoothing': dict(conf.G_smoothing),
        'distributed': distributed,
    }
    bs_dl = int(conf.data.loader.batch_size) * num_replicas
    bs_eff = conf.train.batch_size
    if bs_eff % bs_dl:
        raise AttributeError(
            "Effective batch size should be divisible by the data-loader batch size "
            "multiplied by the number of devices in use"
        )
    # ...until a master-node-specific batch size is introduced
    upd_interval = max(bs_eff // bs_dl, 1)
    train_options['train']['update_interval'] = upd_interval
    if epoch_length < len(train_dl):
        # ideally, epoch_length should be tied to the effective batch size only,
        # but the Ignite trainer counts data-loader iterations
        epoch_length *= upd_interval

    train_loop, sample_images = create_train_closures(
        G, D, G_loss, D_loss, G_opt, D_opt, G_ema=G_ema, device=device, options=train_options
    )
    trainer = create_trainer(train_loop, metrics, device, num_replicas)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.checkpoints
    pbar = None

    if master_node:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)
        trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
        trainer.add_event_handler(log_event, log_iter, pbar, log_freq)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)
        pbar.attach(trainer, metric_names=metric_names)

    setup_checkpoints(trainer, to_save, epoch_length, conf)
    setup_snapshots(trainer, sample_images, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp, pbar)
        Checkpoint.load_objects(to_load=to_save, checkpoint=torch.load(cp.load, map_location=device))

    try:
        trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        logging.error(traceback.format_exc())

    if pbar is not None:
        pbar.close()
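The `setup_checkpoints` helper referenced above is not shown. A minimal sketch of what it might do with Ignite's `Checkpoint` and `DiskSaver`; the config keys `save_dir`, `max_checkpoints`, and `interval_it` are hypothetical:

import os
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

def setup_checkpoints(trainer, to_save, epoch_length, conf):
    cp = conf.checkpoints
    # Checkpoint cannot serialize None entries (e.g. G_ema on worker nodes)
    to_save = {k: v for k, v in to_save.items() if v is not None}
    saver = DiskSaver(cp.get('save_dir', os.getcwd()), create_dir=True, require_empty=False)
    handler = Checkpoint(
        to_save,
        saver,
        n_saved=cp.get('max_checkpoints', 3),
        global_step_transform=lambda engine, _: engine.state.iteration,
    )
    interval = cp.get('interval_it', epoch_length)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=interval), handler)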