# NOTE(review): collapsed multi-line fragment — original newlines were lost, so
# everything from the embedded "# Set hooks" onward reads as a comment on this
# one physical line, and the trailing HookCollection(...) is truncated mid-call
# (unbalanced parens). Restore the original line breaks before running.
#
# What this fragment sets up (as far as visible):
#   * model_saver: a RollingSaver wrapping a ModelCheckpointSaver over
#     "<train_path>/model.ckpt", keeping args.num_rolling_checkpoints files.
#   * save_ckpt(saver, epoch): @action helper that saves via saver with the
#     epoch number as filename suffix; later bound with
#     Action(save_ckpt, saver=highest_gap_saver).
#   * EPOCH_START hook: stops training once trainer.epochs >= args.max_epochs
#     (only registered when args.max_epochs is a positive number).
#   * model_saver_when: condition for periodic saving — every
#     args.save_checkpoint_interval epochs, plus (when max_epochs is set) the
#     final args.num_rolling_checkpoints epochs.
#   * EPOCH_END hook (truncated): tracks the highest validation AP at key=0
#     ("Highest gAP") — presumably to checkpoint the best-gAP model; the rest
#     of the HookCollection is cut off here — recover it from the full file.
model_saver = RollingSaver( ModelCheckpointSaver( CheckpointSaver(os.path.join(args.train_path, "model.ckpt")), model), keep=args.num_rolling_checkpoints, ) @action def save_ckpt(saver, epoch): saver.save(suffix=epoch) # Set hooks if args.max_epochs and args.max_epochs > 0: trainer.add_hook( EPOCH_START, Hook(GEqThan(trainer.epochs, args.max_epochs), trainer.stop)) model_saver_when = Any( GEqThan(trainer.epochs, args.max_epochs - args.num_rolling_checkpoints), MultipleOf(trainer.epochs, args.save_checkpoint_interval), ) else: model_saver_when = MultipleOf(trainer.epochs, args.save_checkpoint_interval) trainer.add_hook( EPOCH_END, HookCollection( Hook( Highest(engine_wrapper.valid_ap(), key=0, name="Highest gAP"), Action(save_ckpt, saver=highest_gap_saver),
# NOTE(review): collapsed multi-line fragment — original newlines were lost and
# the head of the enclosing call is missing: the line opens with keyword
# arguments (criterion=, optimizer=, data_loader=, ...) of what is presumably
# the Trainer(...) construction — confirm against the full file. The embedded
# "# Launch training" comment swallows the rest of this physical line; restore
# line breaks before running.
#
# Visible behavior of this fragment:
#   * Builds the trainer with a DortmundBCELoss criterion, feeding images
#     ('img' items, moved to args.gpu, padding dropped) as input and PHOC
#     vectors built from 'txt' items (syms, args.phoc_levels) as target;
#     progress bar labeled 'Train' only when args.show_progress_bar.
#   * Sets gradient-accumulation via trainer.iterations_per_update.
#   * EPOCH_END hook: Evaluate(model, qr_ds_loader, wd_ds_loader, args.gpu).
#   * EPOCH_START hook: stop once trainer.epochs >= args.max_epochs (only when
#     args.max_epochs is positive).
#   * Resume: writes trainer._epochs directly — reaches into a private
#     attribute; prefer a public resume API if the trainer offers one.
#   * After trainer.run(), saves final weights to "<train_path>/model.ckpt".
criterion=DortmundBCELoss(), optimizer=optimizer, data_loader=tr_ds_loader, batch_input_fn=ImageFeeder(device=args.gpu, keep_padded_tensors=False, parent_feeder=ItemFeeder('img')), batch_target_fn=VariableFeeder(device=args.gpu, parent_feeder=PHOCFeeder( syms=syms, levels=args.phoc_levels, parent_feeder=ItemFeeder('txt'))), progress_bar='Train' if args.show_progress_bar else False) trainer.iterations_per_update = args.iterations_per_update trainer.add_hook(EPOCH_END, Evaluate(model, qr_ds_loader, wd_ds_loader, args.gpu)) if args.max_epochs and args.max_epochs > 0: trainer.add_hook( EPOCH_START, Hook(GEqThan(trainer.epochs, args.max_epochs), trainer.stop)) if args.continue_epoch: trainer._epochs = args.continue_epoch # Launch training trainer.run() # Save model parameters after training torch.save(model.state_dict(), os.path.join(args.train_path, 'model.ckpt'))
# NOTE(review): collapsed multi-line fragment — original newlines were lost and
# the head is missing: "model, os.path.join(...))" closes a saver constructor
# (presumably the highest-mAP saver, given the 'model.ckpt-highest-valid-map'
# path) whose opening is outside this view. The embedded comments ("# Set
# hooks", "# Save last 10 epochs", "# Launch training") swallow the rest of
# this physical line; restore line breaks before running.
#
# Visible behavior of this fragment:
#   * save_ckpt(epoch): @action that snapshots model.state_dict() to
#     "<train_path>/model.ckpt-<epoch>".
#   * EPOCH_END HookList: checkpoints on a new highest valid AP at key=0
#     ("Highest gAP" -> highest_gap_saver) and at key=1 ("Highest mAP" ->
#     highest_map_saver).
#   * EPOCH_START hook: stop once trainer.epochs >= args.max_epochs; the
#     per-epoch save_ckpt hook fires from epoch args.max_epochs - 10 on — the
#     10 is a hard-coded window (sibling fragment on another line uses a
#     configurable args.num_rolling_checkpoints instead; consider aligning).
#     If the "last 10 epochs" hook is registered outside the max_epochs guard,
#     args.max_epochs - 10 misbehaves for unset max_epochs — verify the
#     original indentation placed it inside the if-block.
#   * Resume: writes private trainer._epochs directly; then
#     engine_wrapper.run() launches training (trainer.run() in the sibling
#     variant — confirm which entry point is intended).
model, os.path.join(args.train_path, 'model.ckpt-highest-valid-map')) @action def save_ckpt(epoch): prefix = os.path.join(args.train_path, 'model.ckpt') torch.save(model.state_dict(), '{}-{}'.format(prefix, epoch)) # Set hooks trainer.add_hook( EPOCH_END, HookList( Hook(Highest(engine_wrapper.valid_ap(), key=0, name='Highest gAP'), highest_gap_saver), Hook(Highest(engine_wrapper.valid_ap(), key=1, name='Highest mAP'), highest_map_saver))) if args.max_epochs and args.max_epochs > 0: trainer.add_hook( EPOCH_START, Hook(GEqThan(trainer.epochs, args.max_epochs), trainer.stop)) # Save last 10 epochs trainer.add_hook( EPOCH_END, Hook(GEqThan(trainer.epochs, args.max_epochs - 10), save_ckpt)) if args.continue_epoch: trainer._epochs = args.continue_epoch # Launch training engine_wrapper.run()
# NOTE(review): collapsed multi-line fragment — original newlines were lost and
# the head is missing: "model, os.path.join(...))" closes a saver constructor
# (presumably the lowest-WER saver, given the 'model.ckpt-lowest-valid-wer'
# path) whose opening is outside this view. Embedded comments swallow the rest
# of this physical line; restore line breaks before running.
#
# Visible behavior of this fragment (CER/WER twin of the gAP/mAP variant on
# the previous line):
#   * save_ckpt(epoch): @action that snapshots model.state_dict() to
#     "<train_path>/model.ckpt-<epoch>".
#   * EPOCH_END HookList: checkpoints on a new lowest validation CER
#     (lowest_cer_saver) and lowest validation WER (lowest_wer_saver).
#   * EPOCH_START hook: stop once trainer.epochs >= args.max_epochs; the
#     per-epoch save_ckpt hook fires from epoch args.max_epochs - 10 on —
#     hard-coded 10, same caveat as the sibling fragment: if registered
#     outside the max_epochs guard, args.max_epochs - 10 misbehaves when
#     max_epochs is unset — verify original indentation.
#   * Resume: writes private trainer._epochs directly, then
#     engine_wrapper.run() launches training.
model, os.path.join(args.train_path, 'model.ckpt-lowest-valid-wer')) @action def save_ckpt(epoch): prefix = os.path.join(args.train_path, 'model.ckpt') torch.save(model.state_dict(), '{}-{}'.format(prefix, epoch)) # Set hooks trainer.add_hook(EPOCH_END, HookList( Hook(Lowest(engine_wrapper.valid_cer(), name='Lowest CER'), lowest_cer_saver), Hook(Lowest(engine_wrapper.valid_wer(), name='Lowest WER'), lowest_wer_saver))) if args.max_epochs and args.max_epochs > 0: trainer.add_hook(EPOCH_START, Hook(GEqThan(trainer.epochs, args.max_epochs), trainer.stop)) # Save last 10 epochs trainer.add_hook(EPOCH_END, Hook(GEqThan(trainer.epochs, args.max_epochs - 10), save_ckpt)) if args.continue_epoch: trainer._epochs = args.continue_epoch # Launch training engine_wrapper.run()