Code example #1
File: train.py Project: whatseven/python
def main(v_cfg: DictConfig):
    print(OmegaConf.to_yaml(v_cfg))
    seed_everything(0)
    # set_start_method('spawn')

    early_stop_callback = EarlyStopping(
        patience=100,
        monitor="Validation Loss"
    )

    model_check_point = ModelCheckpoint(
        monitor='Validation Loss',
        save_top_k=3,
        save_last=True
    )

    trainer = Trainer(gpus=v_cfg["trainer"].gpu, weights_summary=None,
                      accelerator="ddp" if v_cfg["trainer"].gpu > 1 else None,
                      # early_stop_callback=early_stop_callback,
                      callbacks=[model_check_point],
                      auto_lr_find="learning_rate" if v_cfg["trainer"].auto_lr_find else False,
                      max_epochs=500,
                      gradient_clip_val=0.1,
                      check_val_every_n_epoch=3
                      )

    model = Mono_det_3d(v_cfg)
    if v_cfg["trainer"].resume_from_checkpoint is not None:
        model.load_state_dict(torch.load(v_cfg["trainer"].resume_from_checkpoint)["state_dict"], strict=True)
    if v_cfg["trainer"].auto_lr_find:
        trainer.tune(model)
        print(model.learning_rate)
    if v_cfg["trainer"].evaluate:
        trainer.test(model)
    else:
        trainer.fit(model)
Code example #2
def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""
    class CurrentModel(EvalModelTemplate):
        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            output.update({'my_train_metric':
                           output['loss']})  # could be anything else
            return output

    model = CurrentModel()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=stopping,
        overfit_batches=0.20,
        max_epochs=2,
    )
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'
    assert trainer.current_epoch <= trainer.max_epochs
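Note: several snippets in this listing pass the callback through the legacy early_stop_callback= Trainer argument, which later PyTorch Lightning releases removed in favour of the callbacks= list. A minimal sketch of the callbacks-based form, assuming a hypothetical LightningModule named MyModel that logs a "val_loss" metric:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Minimal sketch, not taken from any of the projects above.
early_stopping = EarlyStopping(
    monitor="val_loss",  # metric name as logged via self.log() in the module
    min_delta=0.0,       # smallest change that still counts as an improvement
    patience=3,          # validation checks without improvement before stopping
    mode="min",          # "min" for losses, "max" for accuracy-style metrics
)

trainer = Trainer(max_epochs=100, callbacks=[early_stopping])
# trainer.fit(MyModel())  # MyModel is a placeholder for your own LightningModule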
Code example #3
def test_early_stopping_cpu_model():
    """
    Test each of the trainer options
    :return:
    """
    reset_seed()

    stopping = EarlyStopping(monitor='val_loss')
    trainer_options = dict(early_stop_callback=stopping,
                           gradient_clip_val=1.0,
                           overfit_pct=0.20,
                           track_grad_norm=2,
                           print_nan_grads=True,
                           show_progress_bar=True,
                           logger=get_test_tube_logger(),
                           train_percent_check=0.1,
                           val_percent_check=0.1)

    model, hparams = get_model()
    run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)

    # test freeze on cpu
    model.freeze()
    model.unfreeze()
Code example #4
def main(hparams):
    # load model
    model = MyModel(hparams)

    # init experiment
    exp = Experiment(name=hparams.experiment_name,
                     save_dir=hparams.test_tube_save_path,
                     autosave=False,
                     description='baseline attn interval')

    exp.argparse(hparams)
    exp.save()

    # define callbacks
    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name,
                                        exp.version)
    early_stop = EarlyStopping(monitor='val_loss',
                               patience=5,
                               verbose=True,
                               mode='min')

    checkpoint = ModelCheckpoint(filepath=model_save_path,
                                 save_best_only=True,
                                 verbose=True,
                                 monitor='pr',
                                 mode='max')

    # init trainer
    trainer = Trainer(experiment=exp,
                      checkpoint_callback=checkpoint,
                      early_stop_callback=early_stop,
                      gpus=hparams.gpus,
                      val_check_interval=1)

    # start training
    trainer.fit(model)
Code example #5
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
    if min_steps:
        assert limit_train_batches < min_steps

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            self.log("foo", batch_idx)
            return super().training_step(batch, batch_idx)

    es_callback = EarlyStopping("foo")
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=es_callback,
        limit_val_batches=0,
        limit_train_batches=limit_train_batches,
        min_epochs=min_epochs,
        min_steps=min_steps,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    model = TestModel()

    expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
    # trigger early stopping directly after the first epoch
    side_effect = [(True, "")] * expected_epochs
    with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
        trainer.fit(model)

    assert trainer.should_stop
    # epochs continue until min steps are reached
    assert trainer.current_epoch == expected_epochs
    # steps continue until min steps are reached AND the epoch is exhausted
    # stopping mid-epoch is not supported
    assert trainer.global_step == limit_train_batches * expected_epochs
Code example #6
def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):

    losses = [4, 3, stop_value, 2, 1]
    expected_stop_epoch = 2

    class CurrentModel(BoringModel):
        def validation_epoch_end(self, outputs):
            val_loss = losses[self.current_epoch]
            self.log('val_loss', val_loss)

    model = CurrentModel()
    early_stopping = EarlyStopping(
        monitor='val_loss',
        check_finite=True,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stopping],
        overfit_batches=0.20,
        max_epochs=10,
    )
    trainer.fit(model)
    assert trainer.current_epoch == expected_stop_epoch
    assert early_stopping.stopped_epoch == expected_stop_epoch
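The test above receives stop_value from a pytest parametrization that the snippet does not show. A hedged sketch of what such a decorator could look like (the values are illustrative, not copied from the original test suite):

import pytest

# Hypothetical parametrization: check_finite=True should stop training as soon
# as the monitored value becomes NaN or infinite.
@pytest.mark.parametrize("stop_value", [float("nan"), float("inf"), -float("inf")])
def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
    ...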
Code example #7
def test_early_stopping_patience(tmpdir, loss_values: list, patience: int,
                                 expected_stop_epoch: int):
    """Test to ensure that early stopping is not triggered before patience is exhausted."""
    class ModelOverrideValidationReturn(BoringModel):
        validation_return_values = torch.tensor(loss_values)

        def validation_epoch_end(self, outputs):
            loss = self.validation_return_values[self.current_epoch]
            self.log("test_val_loss", loss)

    model = ModelOverrideValidationReturn()
    early_stop_callback = EarlyStopping(monitor="test_val_loss",
                                        patience=patience,
                                        verbose=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[early_stop_callback],
        val_check_interval=1.0,
        num_sanity_val_steps=0,
        max_epochs=10,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model)
    assert trainer.current_epoch == expected_stop_epoch
Code example #8
File: dvrl_modules.py Project: isaurabh19/dvrl
    def __init__(self, hparams, dve_model: RLDataValueEstimator,
                 prediction_model: DVRLPredictionModel, val_dataloader,
                 test_dataloader, val_split):
        """
        Implements the DVRL framework.
        :param hparams: this should be a dict, NameSpace or OmegaConf object that implements hyperparameter-storage
        :param prediction_model: this is the core predictor model, this is passed separately since DVRL is agnostic
        to the prediction model.

        ** Note **: In this iteration, the prediction model is constrained to be a torch module and hence shallow models
         won't work
        """
        super().__init__()
        # saving hparams like this is deprecated; if it does not work, follow this:
        # https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html#lightningmodule-hyperparameters

        self.hparams = hparams
        self.dve = dve_model
        self.prediction_model = prediction_model
        self.validation_dataloader = val_dataloader
        self.baseline_delta = 0.0
        self.val_split = val_split
        self.exploration_threshold = self.hparams.exploration_threshold

        self.val_model = copy.deepcopy(self.prediction_model)
        trainer = Trainer(gpus=1,
                          max_epochs=10,
                          callbacks=[EarlyStopping(monitor='loss')])
        trainer.fit(model=self.val_model, train_dataloader=val_dataloader)
        self.val_model.eval()
        self.val_model.requires_grad_(False)
        self.dve.set_val_model(self.val_model)
        self.validation_performance = None

        self.init_test_dataloader = test_dataloader
        self.test_acc = pl.metrics.Accuracy(compute_on_step=False)
Code example #9
def main(args):
    # pick model according to args
    if args.early_stop_callback:
        early_stop_callback = EarlyStopping(monitor='val_loss',
                                            patience=30,
                                            strict=True,
                                            verbose=False,
                                            mode='min')
    else:
        early_stop_callback = False

    checkpoint_callback = ModelCheckpoint(filepath=None,
                                          monitor='F1_score',
                                          save_top_k=1,
                                          mode='max')

    lr_logger = LearningRateLogger()

    if args.test:
        pretrained_model = COVID_Xray_Sys.load_from_checkpoint(args.model_path)
        trainer = Trainer(gpus=args.gpus)
        trainer.test(pretrained_model)
        return 0
        # pretrained_model.freeze()
        # y_hat = pretrained_model(x)

    Sys = COVID_Xray_Sys(hparams=args)
    trainer = Trainer(early_stop_callback=early_stop_callback,
                      checkpoint_callback=checkpoint_callback,
                      callbacks=[lr_logger],
                      gpus=args.gpus,
                      default_save_path='../../results/logs/{}'.format(
                          os.path.basename(__file__)[:-3]),
                      max_epochs=2000)

    trainer.fit(Sys)
Code example #10
def test_early_stopping_cpu_model(tmpdir):
    """Test each of the trainer options."""
    tutils.reset_seed()

    stopping = EarlyStopping(monitor='val_loss', min_delta=0.1)
    trainer_options = dict(
        default_save_path=tmpdir,
        early_stop_callback=stopping,
        gradient_clip_val=1.0,
        overfit_pct=0.20,
        track_grad_norm=2,
        print_nan_grads=True,
        show_progress_bar=True,
        logger=tutils.get_test_tube_logger(tmpdir),
        train_percent_check=0.1,
        val_percent_check=0.1,
    )

    model, hparams = tutils.get_model()
    tutils.run_model_test(trainer_options, model, on_gpu=False)

    # test freeze on cpu
    model.freeze()
    model.unfreeze()
Code example #11
def test_model_tpu_early_stop(tmpdir):
    """Test if single TPU core training works"""

    # todo: Test on 8 cores - hanging.

    class CustomBoringModel(BoringModel):

        def validation_step(self, *args, **kwargs):
            out = super().validation_step(*args, **kwargs)
            self.log('val_loss', out['x'])
            return out

    tutils.reset_seed()
    model = CustomBoringModel()
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor='val_loss')],
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        tpu_cores=[1],
    )
    trainer.fit(model)
Code example #12
def test_trainer_min_steps_and_epochs(tmpdir):
    """Verify model trains according to specified min steps"""
    model, trainer_options, num_train_samples = _init_steps_model()

    # define callback for stopping the model and default epochs
    trainer_options.update(
        default_root_dir=tmpdir,
        early_stop_callback=EarlyStopping(monitor='val_loss', min_delta=1.0),
        val_check_interval=2,
        min_epochs=1,
        max_epochs=5
    )

    # define less min steps than 1 epoch
    trainer_options['min_steps'] = math.floor(num_train_samples / 2)

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1, "Training did not complete"

    # check model ran for at least min_epochs
    assert trainer.global_step >= num_train_samples and \
        trainer.current_epoch > 0, "Model did not train for at least min_epochs"

    # define less epochs than min_steps
    trainer_options['min_steps'] = math.floor(num_train_samples * 1.5)

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1, "Training did not complete"

    # check model ran for at least num_train_samples*1.5
    assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \
        trainer.current_epoch > 0, "Model did not train for at least min_steps"
Code example #13
def tuning(config=None,
           MODEL=None,
           pose_autoencoder=None,
           cost_dim=None,
           phase_dim=None,
           input_slices=None,
           output_slices=None,
           train_set=None,
           val_set=None,
           num_epochs=300,
           model_name="model"):
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=1,
        logger=TensorBoardLogger(save_dir="logs/",
                                 name=model_name,
                                 version="0.0"),
        progress_bar_refresh_rate=5,
        callbacks=[
            TuneReportCallback({
                "loss": "avg_val_loss",
            }, on="validation_end"),
            EarlyStopping(monitor="avg_val_loss")
        ],
    )
    model = MODEL(config=config,
                  pose_autoencoder=pose_autoencoder,
                  cost_input_dimension=cost_dim,
                  phase_dim=phase_dim,
                  input_slicers=input_slices,
                  output_slicers=output_slices,
                  train_set=train_set,
                  val_set=val_set,
                  name=model_name)

    trainer.fit(model)
Code example #14
def train():
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    early_stop_callback = EarlyStopping(monitor='val_acc',
                                        patience=10,
                                        verbose=True,
                                        mode='max')
    checkpoint_callback = ModelCheckpoint(save_top_k=1,
                                          verbose=True,
                                          monitor='val_acc',
                                          mode='max')

    model = TemporalCNN(num_classes=len(CATEGORIES),
                        n_consecutive_frames=n_frames,
                        lr=0.001)
    trainer = Trainer(
        gpus=1,
        max_epochs=1,
        callbacks=[checkpoint_callback, early_stop_callback],
        num_sanity_val_steps=0,
    )
    trainer.fit(model,
                datamodule=_DataModule(data_dir,
                                       n_consecutive_frames=n_frames))
    trainer.test()
Code example #15
def main():
    system = configure_system(
        hyperparameter_defaults["system"])(hyperparameter_defaults)
    logger = TensorBoardLogger(
        'experiments_logs',
        name=str(hyperparameter_defaults['system']) + "_" +
        str(system.model.__class__.__name__) + "_" +
        str(hyperparameter_defaults['criterion']) + "_" +
        str(hyperparameter_defaults['scheduler']))

    early_stop = EarlyStopping(monitor="valid_iou",
                               mode="max",
                               verbose=True,
                               patience=hyperparameter_defaults["patience"])
    model_checkpoint = ModelCheckpoint(
        monitor="valid_iou",
        mode="max",
        verbose=True,
        filename='Model-{epoch:02d}-{valid_iou:.5f}',
        save_top_k=3,
        save_last=True)
    trainer = pl.Trainer(
        gpus=[0, 1],
        plugins=DDPPlugin(find_unused_parameters=True),
        max_epochs=hyperparameter_defaults['epochs'],
        logger=logger,
        check_val_every_n_epoch=1,
        accelerator='ddp',
        callbacks=[early_stop, model_checkpoint],
        num_sanity_val_steps=0,
        limit_train_batches=1.0,
        deterministic=True,
    )

    trainer.fit(system)
    trainer.test(system)
Code example #16
def main(conf):
    train_set = LibriMix(csv_dir=conf['data']['train_dir'],
                         task=conf['data']['task'],
                         sample_rate=conf['data']['sample_rate'],
                         n_src=conf['data']['n_src'],
                         segment=conf['data']['segment'])

    val_set = LibriMix(csv_dir=conf['data']['valid_dir'],
                       task=conf['data']['task'],
                       sample_rate=conf['data']['sample_rate'],
                       n_src=conf['data']['n_src'],
                       segment=conf['data']['segment'])

    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)

    val_loader = DataLoader(val_set, shuffle=True,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    conf['masknet'].update({'n_src': conf['data']['n_src']})

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this: re-instantiating the model and optimizer
    # for retraining and for evaluating is straightforward.
    model, optimizer = make_model_and_optimizer(conf)
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)

    # Don't ask for GPUs if they are not available.
    if not torch.cuda.is_available():
        print('No available GPUs were found, setting gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(max_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='dp',
                         train_percent_check=1.0,  # Useful for fast experiment
                         gradient_clip_val=5.)
    trainer.fit(system)

    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(checkpoint.best_k_models, f, indent=0)
Code example #17
File: train.py Project: zmolikova/asteroid
def main(conf, args):
    # Set seed for random
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # create output dir if not exist
    exp_dir = Path(args.output)
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Load Datasets
    train_dataset, valid_dataset = dataloader.load_datasets(parser, args)
    dataloader_kwargs = ({
        "num_workers": args.num_workers,
        "pin_memory": True
    } if torch.cuda.is_available() else {})
    train_sampler = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                **dataloader_kwargs)
    valid_sampler = torch.utils.data.DataLoader(valid_dataset,
                                                batch_size=1,
                                                **dataloader_kwargs)

    # Define model and optimizer
    if args.pretrained is not None:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = bandwidth_to_max_bin(train_dataset.sample_rate, args.in_chan,
                                   args.bandwidth)

    x_unmix = XUMX(
        window_length=args.window_length,
        input_mean=scaler_mean,
        input_scale=scaler_std,
        nb_channels=args.nb_channels,
        hidden_size=args.hidden_size,
        in_chan=args.in_chan,
        n_hop=args.nhop,
        sources=args.sources,
        max_bin=max_bin,
        bidirectional=args.bidirectional,
        sample_rate=train_dataset.sample_rate,
        spec_power=args.spec_power,
        return_time_signals=True if args.loss_use_multidomain else False,
    )

    optimizer = make_optimizer(x_unmix.parameters(),
                               lr=args.lr,
                               optimizer="adam",
                               weight_decay=args.weight_decay)

    # Define scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10)

    # Save config
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    es = EarlyStopping(monitor="val_loss",
                       mode="min",
                       patience=args.patience,
                       verbose=True)

    # Define Loss function.
    loss_func = MultiDomainLoss(
        window_length=args.window_length,
        in_chan=args.in_chan,
        n_hop=args.nhop,
        spec_power=args.spec_power,
        nb_channels=args.nb_channels,
        loss_combine_sources=args.loss_combine_sources,
        loss_use_multidomain=args.loss_use_multidomain,
        mix_coef=args.mix_coef,
    )
    system = XUMXManager(
        model=x_unmix,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_sampler,
        val_loader=valid_sampler,
        scheduler=scheduler,
        config=conf,
        val_dur=args.val_dur,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    callbacks.append(es)

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=args.epochs,
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        limit_train_batches=1.0,  # Useful for fast experiment
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_dataset.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Code example #18
    seed_reproducer(2020)

    # Init Hyperparameters
    hparams = init_hparams()

    # init logger
    logger = init_logger("kun_out", log_dir=hparams.log_dir)

    # Load data
    data, test_data = load_data(logger)

    # Generate transforms
    transforms = generate_transforms(hparams.image_size)

    early_stop_callback = EarlyStopping(monitor="val_roc_auc",
                                        patience=10,
                                        mode="max",
                                        verbose=True)

    # Instance Model, Trainer and train model
    model = CoolSystem(hparams)
    trainer = pl.Trainer(
        gpus=hparams.gpus,
        min_epochs=70,
        max_epochs=hparams.max_epochs,
        early_stop_callback=early_stop_callback,
        progress_bar_refresh_rate=0,
        precision=hparams.precision,
        num_sanity_val_steps=0,
        profiler=False,
        weights_summary=None,
        use_dp=True,
Code example #19
if __name__ == "__main__":
    # Gets and split data
    x, y = load_breast_cancer(return_X_y=True)
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.1)

    # From numpy to torch tensors
    x_train = torch.from_numpy(x_train).type(torch.FloatTensor)
    y_train = torch.from_numpy(y_train).type(torch.FloatTensor)

    # From numpy to torch tensors
    x_val = torch.from_numpy(x_val).type(torch.FloatTensor)
    y_val = torch.from_numpy(y_val).type(torch.FloatTensor)

    # Implements Dataset and DataLoader
    train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=5)

    # Implements Dataset and DataLoader
    val_dataset = torch.utils.data.TensorDataset(x_val, y_val)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=5)

    # Init Neural Net model
    nn = NeuralNet(learning_rate=0.001)

    early_stopping = EarlyStopping('val_acc_epoch', mode='max')  # accuracy should be maximized

    # Init Trainer
    trainer = pl.Trainer(max_epochs=10, callbacks=[early_stopping])

    # Train
    trainer.fit(nn, train_dataloader, val_dataloader)
Code example #20
def main():
    # Retrieve config and add custom base parameters
    cfg = get_cfg()
    add_custom_param(cfg)
    cfg.merge_from_file("config.yaml")

    logging.getLogger("pytorch_lightning").setLevel(logging.INFO)
    logger = logging.getLogger("pytorch_lightning.core")
    if not os.path.exists(cfg.CALLBACKS.CHECKPOINT_DIR):
        os.makedirs(cfg.CALLBACKS.CHECKPOINT_DIR)
    logger.addHandler(
        logging.FileHandler(
            os.path.join(cfg.CALLBACKS.CHECKPOINT_DIR, "core.log")))
    with open("config.yaml") as file:
        logger.info(file.read())
    # Initialise Custom storage to avoid error when using detectron 2
    _CURRENT_STORAGE_STACK.append(EventStorage())

    # Create transforms
    transform_train = A.Compose(
        [
            A.Resize(height=cfg.TRANSFORM.RESIZE.HEIGHT,
                     width=cfg.TRANSFORM.RESIZE.WIDTH),
            A.RandomCrop(height=cfg.TRANSFORM.RANDOMCROP.HEIGHT,
                         width=cfg.TRANSFORM.RANDOMCROP.WIDTH),
            A.HorizontalFlip(p=cfg.TRANSFORM.HFLIP.PROB),
            A.Normalize(mean=cfg.TRANSFORM.NORMALIZE.MEAN,
                        std=cfg.TRANSFORM.NORMALIZE.STD),
            # A.RandomScale(scale_limit=[0.5, 2]),
        ],
        bbox_params=A.BboxParams(format='coco', label_fields=['class_labels']))

    transform_valid = A.Compose([
        A.Resize(height=512, width=1024),
        A.Normalize(mean=cfg.TRANSFORM.NORMALIZE.MEAN,
                    std=cfg.TRANSFORM.NORMALIZE.STD),
    ],
                                bbox_params=A.BboxParams(
                                    format='coco',
                                    label_fields=['class_labels']))

    # Create Dataset
    train_dataset = PanopticDataset(cfg.TRAIN_JSON,
                                    cfg.DATASET_PATH,
                                    'train',
                                    transform=transform_train)

    valid_dataset = PanopticDataset(cfg.VALID_JSON,
                                    cfg.DATASET_PATH,
                                    'val',
                                    transform=transform_valid)

    # Create Data Loader
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.BATCH_SIZE,
                              shuffle=True,
                              collate_fn=collate_fn,
                              pin_memory=False,
                              num_workers=4)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg.BATCH_SIZE,
                              shuffle=False,
                              collate_fn=collate_fn,
                              pin_memory=False,
                              num_workers=4)

    # Create model or load a checkpoint
    if os.path.exists(cfg.CHECKPOINT_PATH):
        print('""""""""""""""""""""""""""""""""""""""""""""""')
        print("Loading model from {}".format(cfg.CHECKPOINT_PATH))
        print('""""""""""""""""""""""""""""""""""""""""""""""')
        efficientps = EffificientPS.load_from_checkpoint(
            cfg=cfg, checkpoint_path=cfg.CHECKPOINT_PATH)
    else:
        print('""""""""""""""""""""""""""""""""""""""""""""""')
        print("Creating a new model")
        print('""""""""""""""""""""""""""""""""""""""""""""""')
        efficientps = EffificientPS(cfg)
        cfg.CHECKPOINT_PATH = None

    logger.info(efficientps.print)
    # Callbacks / Hooks
    early_stopping = EarlyStopping('PQ', patience=5, mode='max')
    checkpoint = ModelCheckpoint(
        monitor='PQ',
        mode='max',
        dirpath=cfg.CALLBACKS.CHECKPOINT_DIR,
        save_last=True,
        verbose=True,
    )

    # Create a pytorch lighting trainer
    trainer = pl.Trainer(
        # weights_summary='full',
        gpus=1,
        num_sanity_val_steps=0,
        # fast_dev_run=True,
        callbacks=[early_stopping, checkpoint],
        precision=cfg.PRECISION,
        resume_from_checkpoint=cfg.CHECKPOINT_PATH,
        gradient_clip_val=15,
        accumulate_grad_batches=cfg.SOLVER.ACCUMULATE_GRAD)
    logger.addHandler(logging.StreamHandler())
    trainer.fit(efficientps, train_loader, val_dataloaders=valid_loader)
Code example #21
def main(conf):
    # set seed
    pl.seed_everything(conf.seed)

    # load datamodule
    data = HotpotDataModule(conf=conf)

    # load model
    model_name = conf.model.name + ('_dataset' if conf.model.dataset else '')
    if conf.training.train:
        if conf.training.from_checkpoint:
            model = models[model_name].load_from_checkpoint(
                checkpoint_path=os.path.join(
                    os.path.split(hydra.utils.get_original_cwd())[0],
                    'outputs', conf.training.from_checkpoint))
        else:
            model = models[model_name](conf=conf)
    else:
        model = models[model_name].load_from_checkpoint(
            checkpoint_path=os.path.join(
                os.path.split(hydra.utils.get_original_cwd())[0], 'outputs',
                conf.testing.model_path))

    # TRAINER
    callbacks = []

    # checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=conf.training.model_checkpoint.dirpath,
        filename=conf.training.model_checkpoint.filename,
        monitor=conf.training.model_checkpoint.monitor,
        save_last=conf.training.model_checkpoint.save_last,
        save_top_k=conf.training.model_checkpoint.save_top_k)
    callbacks.append(checkpoint_callback)

    # early stop callback
    if conf.training.early_stopping.early_stop:
        early_stop_callback = EarlyStopping(
            monitor=conf.training.early_stopping.monitor,
            patience=conf.training.early_stopping.patience,
            mode=conf.training.early_stopping.mode,
        )
        callbacks.append(early_stop_callback)

    # logger
    wandb_logger = WandbLogger(name=model_name,
                               project='neural-question-generation')

    # trainer
    trainer = pl.Trainer(
        accumulate_grad_batches=conf.training.grad_cum,
        callbacks=callbacks,
        default_root_dir='.',
        deterministic=True,
        fast_dev_run=conf.debug,
        flush_logs_every_n_steps=10,
        gpus=(1 if torch.cuda.is_available() else 0),
        logger=wandb_logger,
        log_every_n_steps=100,
        max_epochs=conf.training.max_epochs,
        num_sanity_val_steps=0,
        reload_dataloaders_every_epoch=True,
        # val_check_interval=0.05,
    )

    # TODO: tune

    # train
    if conf.training.train:
        trainer.fit(model=model, datamodule=data)

    # test
    if conf.testing.test:
        trainer.test(model=model, datamodule=data)
        if model_name != 'bert_clf' and model_name != 'bert_sum' and model_name != 'bert_clf+bart_dataset':
            results = evaluation_metrics(conf)
            wandb_logger.log_metrics(results)
Code example #22
    vae = vae.float()
    vae = vae.cuda()
    # vae.load_state_dict(torch.load(f'{checkpoint_path}/weights.pt'))

    torch.cuda.empty_cache()
    version = datetime.strftime(datetime.fromtimestamp(seed),
                                '%Y-%m-%d..%H.%M.%S')
    logger = TensorBoardLogger(checkpoint_path, version=version)
    checkpoint = ModelCheckpoint(filepath=checkpoint_path,
                                 save_top_k=1,
                                 verbose=True,
                                 monitor='loss',
                                 mode='min')
    early_stop = EarlyStopping(
        monitor='loss',
        patience=stgs.VAE_HPARAMS['early_stop_patience'],
        verbose=True,
        mode='min')
    max_steps = stgs.VAE_HPARAMS['max_steps']
    # kld loss annealing also depends on max length (or max epochs?)
    vae.get_data_generator(max(min_depth, max_depth - 1), max_depth, seed=seed)

    trainer = pl.Trainer(gpus=-1,
                         val_check_interval=9999,
                         early_stop_callback=None,
                         distributed_backend=None,
                         logger=logger,
                         max_steps=max_steps,
                         max_epochs=max_steps,
                         checkpoint_callback=checkpoint,
                         weights_save_path=checkpoint_path)
Code example #23
    #    print(self.retriever.predict('I am beautiful lady?', ['You are a pretty girl',
    #                                               'apple is tasty',
    #                                               'He is a handsome boy'], True))


if __name__ == '__main__':
    encoder_question = BertEncoder(bert_question, max_question_len_global)
    encoder_paragarph = BertEncoder(bert_paragraph, max_paragraph_len_global)
    ret = Retriver(encoder_question, encoder_paragarph, tokenizer)
    os.makedirs('out', exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        filepath='out/{epoch}-{val_loss:.2f}-{val_acc:.2f}',
        save_top_k=1,
        verbose=True,
        monitor='val_acc',
        mode='max')

    early_stopping = EarlyStopping('val_acc', mode='max')

    trainer = pl.Trainer(
        gpus=1,
        # distributed_backend='dp',
        val_check_interval=0.1,
        min_epochs=1,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping)

    ret_trainee = RetriverTrainer(ret)

    trainer.fit(ret_trainee)
Code example #24
File: train.py Project: wyn314/asteroid
def main(conf):
    train_set = WhamDataset(conf['data']['train_dir'],
                            conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'],
                          conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update number of source values (It depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer
    model = ConvTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Define scheduler
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_top_k=5,
                                 verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=10,
                                       verbose=1)

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model (next PL version will make this easier)
    best_path = [b for b, v in best_k.items() if v == min(best_k.values())][0]
    state_dict = torch.load(best_path)
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))
Code example #25
def main(argv):
    if not os.path.exists(FLAGS.logs_dir):
        os.makedirs(FLAGS.logs_dir)

    set_seed(FLAGS.seed)
    id2class, intent_examples = read_nlu_data()

    if FLAGS.do_train:
        if not os.path.exists(FLAGS.output_dir):
            os.makedirs(FLAGS.output_dir)

        model = NluClassifier(id2class, intent_examples)

        early_stop_callback = EarlyStopping(
            monitor=FLAGS.monitor,
            min_delta=0.0,
            patience=FLAGS.patience,
            verbose=True,
            mode=FLAGS.metric_mode,
        )

        checkpoint_callback = ModelCheckpoint(filepath=FLAGS.output_dir,
                                              save_top_k=3,
                                              monitor=FLAGS.monitor,
                                              mode=FLAGS.metric_mode,
                                              prefix='nlu_leam_')

        trainer = pl.Trainer(
            default_root_dir='logs',
            gpus=(FLAGS.gpus if torch.cuda.is_available() else 0),
            distributed_backend='dp',
            max_epochs=FLAGS.epochs,
            fast_dev_run=FLAGS.debug,
            logger=pl.loggers.TensorBoardLogger('logs/',
                                                name='nlu_leam',
                                                version=0),
            checkpoint_callback=checkpoint_callback,
            early_stop_callback=early_stop_callback)

        trainer.fit(model)

    if FLAGS.do_predict:
        from sanic import Sanic, response
        server = Sanic()

        checkpoints = list(
            sorted(
                glob(os.path.join(FLAGS.output_dir, "nlu_leam_*.ckpt"),
                     recursive=True)))
        model = NluClassifier.load_from_checkpoint(
            checkpoint_path=checkpoints[-1],
            id2class=id2class,
            intent_examples=intent_examples)
        model.eval()
        model.freeze()

        @server.route("/parse", methods=['POST'])
        async def parse(request):
            texts = request.json
            prediction = model.predict(texts)
            return response.json(prediction)

        server.run(host="0.0.0.0", port=5000, debug=True)
Code example #26
def get_trainer(dataset,
        version=1,
        save_dir='./checkpoints',
        name='jnkname',
        auto_lr=False,
        batchsize=1000,
        earlystopping=True,
        earlystoppingpatience=10,
        max_epochs=150,
        num_workers=1,
        gradient_clip_val=0,
        seed=None):
    """
    Returns a pytorch lightning trainer and splits the training set into "train" and "valid"
    """
    from torch.utils.data import Dataset, DataLoader, random_split
    from pytorch_lightning import Trainer, seed_everything
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.loggers import TestTubeLogger
    from pathlib import Path

    
    save_dir = Path(save_dir)
    n_val = np.floor(len(dataset)/5).astype(int)
    n_train = (len(dataset)-n_val).astype(int)

    gd_train, gd_val = random_split(dataset, lengths=[n_train, n_val])

    # build dataloaders
    train_dl = DataLoader(gd_train, batch_size=batchsize, num_workers=num_workers, pin_memory=True)
    valid_dl = DataLoader(gd_val, batch_size=batchsize, num_workers=num_workers, pin_memory=True)

    # Train
    early_stop_callback = EarlyStopping(monitor='val_loss', min_delta=0.0, patience=earlystoppingpatience)
    # checkpoint_callback = ModelCheckpoint(monitor='val_loss')

    logger = TestTubeLogger(
        save_dir=save_dir,
        name=name,
        version=version  # fixed to one to ensure checkpoint load
    )

    # ckpt_folder = save_dir / sessid / 'version_{}'.format(version) / 'checkpoints'
    if earlystopping:
        trainer = Trainer(gpus=1, callbacks=[early_stop_callback],
            logger=logger,
            deterministic=False,
            gradient_clip_val=gradient_clip_val,
            accumulate_grad_batches=1,
            progress_bar_refresh_rate=20,
            max_epochs=1000,
            auto_lr_find=auto_lr)
    else:
        trainer = Trainer(gpus=1,
            logger=logger,
            deterministic=False,
            gradient_clip_val=gradient_clip_val,
            accumulate_grad_batches=1,
            progress_bar_refresh_rate=20,
            max_epochs=max_epochs,
            auto_lr_find=auto_lr)

    if seed:
        seed_everything(seed)

    return trainer, train_dl, valid_dl
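A hedged usage sketch for the helper above; my_dataset and model are placeholders (any map-style torch Dataset and a LightningModule that logs val_loss), not names from the original project:

# get_trainer() returns the Trainer plus the train/validation DataLoaders it
# built from the 80/20 random split, so a fit call looks like this:
trainer, train_dl, valid_dl = get_trainer(
    my_dataset,
    save_dir='./checkpoints',
    name='example_run',
    batchsize=1000,
    earlystopping=True,
    earlystoppingpatience=10,
    seed=42,
)
trainer.fit(model, train_dl, valid_dl)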
Code example #27
File: train.py Project: fdroessler/pytorch_tempest
def run(cfg: DictConfig) -> None:
    """
    Run pytorch-lightning model

    Args:
        new_dir:
        cfg: hydra config

    """
    set_seed(cfg.training.seed)
    run_name = os.path.basename(os.getcwd())
    hparams = flatten_omegaconf(cfg)

    cfg.callbacks.model_checkpoint.params.filepath = os.getcwd(
    ) + cfg.callbacks.model_checkpoint.params.filepath
    callbacks = []
    for callback in cfg.callbacks.other_callbacks:
        if callback.params:
            callback_instance = load_obj(
                callback.class_name)(**callback.params)
        else:
            callback_instance = load_obj(callback.class_name)()
        callbacks.append(callback_instance)

    loggers = []
    if cfg.logging.log:
        for logger in cfg.logging.loggers:
            if 'experiment_name' in logger.params.keys():
                logger.params['experiment_name'] = run_name
            loggers.append(load_obj(logger.class_name)(**logger.params))

    callbacks.append(EarlyStopping(**cfg.callbacks.early_stopping.params))

    trainer = pl.Trainer(
        logger=loggers,
        # early_stop_callback=EarlyStopping(**cfg.callbacks.early_stopping.params),
        checkpoint_callback=ModelCheckpoint(
            **cfg.callbacks.model_checkpoint.params),
        callbacks=callbacks,
        **cfg.trainer,
    )

    model = load_obj(cfg.training.lightning_module_name)(hparams=hparams,
                                                         cfg=cfg)
    dm = load_obj(cfg.datamodule.data_module_name)(hparams=hparams, cfg=cfg)
    trainer.fit(model, dm)

    if cfg.general.save_pytorch_model:
        if cfg.general.save_best:
            best_path = trainer.checkpoint_callback.best_model_path  # type: ignore
            # extract file name without folder
            save_name = os.path.basename(os.path.normpath(best_path))
            model = model.load_from_checkpoint(best_path,
                                               hparams=hparams,
                                               cfg=cfg,
                                               strict=False)
            model_name = f'saved_models/{save_name}'.replace('.ckpt', '.pth')
            torch.save(model.model.state_dict(), model_name)
        else:
            os.makedirs('saved_models', exist_ok=True)
            model_name = 'saved_models/last.pth'
            torch.save(model.model.state_dict(), model_name)

    if cfg.general.convert_to_jit:
        convert_to_jit(model, save_name, cfg)
Code example #28
    if args.benchmark:
        model = NNUnet(args)
        batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
        log_dir = os.path.join(args.results, args.logname if args.logname is not None else "perf.json")
        callbacks = [
            LoggingCallback(
                log_dir=log_dir,
                global_batch_size=batch_size * args.gpus,
                mode=args.exec_mode,
                warmup=args.warmup,
                dim=args.dim,
            )
        ]
    elif args.exec_mode == "train":
        model = NNUnet(args)
        early_stopping = EarlyStopping(monitor="dice_mean", patience=args.patience, verbose=True, mode="max")
        callbacks = [early_stopping]
        if args.save_ckpt:
            model_ckpt = ModelCheckpoint(
                filename="{epoch}-{dice_mean:.2f}", monitor="dice_mean", mode="max", save_last=True
            )
            callbacks.append(model_ckpt)
    else:  # Evaluation or inference
        if ckpt_path is not None:
            model = NNUnet.load_from_checkpoint(ckpt_path)
        else:
            model = NNUnet(args)

    trainer = Trainer(
        logger=False,
        gpus=args.gpus,
Code example #29
def main(conf):
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPTNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    from asteroid.engine.schedulers import DPTNetScheduler

    schedulers = {
        "scheduler":
        DPTNetScheduler(optimizer,
                        len(train_loader) // conf["training"]["batch_size"],
                        64),
        "interval":
        "step",
    }

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=schedulers,
        train_loader=train_loader,
        val_loader=val_loader,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=True)

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        gradient_clip_val=conf["training"]["gradient_clipping"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Code example #30
File: trainer.py Project: adobe/NLP-Cube
    def fit(self):
        if self.task not in ["tokenizer", "lemmatizer", "cwe", "tagger", "parser"]:
            raise Exception("Task must be one of: tokenizer, lemmatizer, cwe, tagger or parser.")

        with open(self.args.store + ".yaml", 'w') as f:
            yaml.dump({"language_map": self.language_map, "language_codes": self.language_codes}, f, sort_keys=True)

        enc = Encodings()
        enc.compute(self.doc_train, None)
        enc.save('{0}.encodings'.format(self.store_prefix))

        if self.task == "tokenizer":
            config = TokenizerConfig()
            no_space_lang = Tokenizer._detect_no_space_lang(self.doc_train)
            print("NO_SPACE_LANG = " + str(no_space_lang))
            config.no_space_lang = no_space_lang
        if self.task == "tagger":
            config = TaggerConfig()
        if self.task == "lemmatizer":
            config = LemmatizerConfig()
        if self.task == "parser":
            config = ParserConfig()
        if self.task == "cwe":
            config = CompoundConfig()
        config.lm_model = self.args.lm_model
        if self.args.config_file:
            config.load(self.args.config_file)
            if self.args.lm_model is not None:
                config.lm_model = self.args.lm_model
        config.save('{}.config'.format(self.args.store))

        if self.task != "tokenizer" and self.task != 'lemmatizer' and self.task != 'cwe':
            lm_model = config.lm_model
            parts = lm_model.split(':')
            if parts[0] not in ['transformer', 'fasttext', 'languasito']:
                print("Error: model prefix should be in the form of transformer: fasttext: or languasito:")
                sys.exit(0)
            if parts[0] == 'transformer':
                helper = LMHelperHF(device=self.args.lm_device, model=parts[1])
            elif parts[0] == 'fasttext':
                helper = LMHelperFT(device=self.args.lm_device, model=parts[1])
            elif parts[0] == 'languasito':
                helper = LMHelperLanguasito(device=self.args.lm_device, model=parts[1])
            helper.apply(self.doc_dev)
            helper.apply(self.doc_train)

        if self.task == "tokenizer":
            trainset = TokenizationDataset(self.doc_train)
            devset = TokenizationDataset(self.doc_dev, shuffle=False)
        elif self.task == 'parser' or self.task == 'tagger':
            trainset = MorphoDataset(self.doc_train)
            devset = MorphoDataset(self.doc_dev)
        elif self.task == 'lemmatizer':
            trainset = LemmaDataset(self.doc_train)
            devset = LemmaDataset(self.doc_dev)
        elif self.task == 'cwe':
            trainset = CompoundDataset(self.doc_train)
            devset = CompoundDataset(self.doc_dev)

        collate = MorphoCollate(enc)

        # per task specific settings
        callbacks = []
        if self.task == "tokenizer":
            early_stopping_callback = EarlyStopping(
                monitor='val/early_meta',
                patience=args.patience,
                verbose=True,
                mode='max'
            )
            parts = args.lm_model.split(':')
            if parts[0] == 'transformer':
                collate = TokenCollateHF(enc, lm_device=args.lm_device, lm_model=parts[1],
                                         no_space_lang=config.no_space_lang)
            else:
                collate = TokenCollateFTLanguasito(enc, lm_device=args.lm_device, lm_model=args.lm_model,
                                                   no_space_lang=config.no_space_lang)

            callbacks = [early_stopping_callback, Tokenizer.PrintAndSaveCallback(self.store_prefix)]
            model = Tokenizer(config=config, encodings=enc, language_codes=self.language_codes,
                              ext_word_emb=collate.get_embeddings_size(), max_seq_len=collate.max_seq_len)

        if self.task == "tagger":
            early_stopping_callback = EarlyStopping(
                monitor='val/early_meta',
                patience=args.patience,
                verbose=True,
                mode='max'
            )
            callbacks = [early_stopping_callback, Tagger.PrintAndSaveCallback(self.store_prefix)]
            model = Tagger(config=config, encodings=enc, language_codes=self.language_codes,
                           ext_word_emb=helper.get_embedding_size())

        if self.task == "parser":
            collate = MorphoCollate(enc, add_parsing=True, rhl_win_size=config.rhl_win_size)
            early_stopping_callback = EarlyStopping(
                monitor='val/early_meta',
                patience=args.patience,
                verbose=True,
                mode='max'
            )
            callbacks = [early_stopping_callback, Parser.PrintAndSaveCallback(self.store_prefix)]
            model = Parser(config=config, encodings=enc, language_codes=self.language_codes,
                           ext_word_emb=helper.get_embedding_size())

        if self.task == "lemmatizer":
            collate = Word2TargetCollate(enc)
            early_stopping_callback = EarlyStopping(
                monitor='val/early_meta',
                patience=args.patience,
                verbose=True,
                mode='max'
            )
            callbacks = [early_stopping_callback, Lemmatizer.PrintAndSaveCallback(self.store_prefix)]
            model = Lemmatizer(config=config, encodings=enc, language_codes=self.language_codes)

        if self.task == "cwe":
            collate = Word2TargetCollate(enc)
            early_stopping_callback = EarlyStopping(
                monitor='val/early_meta',
                patience=args.patience,
                verbose=True,
                mode='max'
            )
            callbacks = [early_stopping_callback, Compound.PrintAndSaveCallback(self.store_prefix)]
            model = Compound(config=config, encodings=enc, language_codes=self.language_codes)
            # extra check to see if there is actually any compound in this language
            if len(trainset._examples) == 0 or len(devset._examples) == 0:
                print("\nTrain/dev data for this language does not contain any compound words; there is nothing to train.")
                return

        # dataloaders
        train_loader = DataLoader(trainset, batch_size=self.args.batch_size, collate_fn=collate.collate_fn,
                                  shuffle=True,
                                  num_workers=self.args.num_workers)
        val_loader = DataLoader(devset, batch_size=self.args.batch_size, collate_fn=collate.collate_fn,
                                num_workers=self.args.num_workers)

        # pre-train checks
        resume_from_checkpoint = None
        if self.args.resume is True:
            resume_from_checkpoint = self.store_prefix + ".last"
            if not os.path.exists(resume_from_checkpoint):
                raise Exception("Resume from checkpoint: {} not found!".format(resume_from_checkpoint))

        """if self.args.gpus == 0:
            acc = 'ddp_cpu'
        else:
            acc = 'ddp'
        """

        trainer = pl.Trainer(
            gpus=args.gpus,
            accelerator=args.accelerator,
            #num_nodes=1,
            default_root_dir='data/',
            callbacks=callbacks,
            resume_from_checkpoint=resume_from_checkpoint,
            accumulate_grad_batches=args.accumulate_grad_batches,
            # limit_train_batches=100,
            # limit_val_batches=4,
        )

        # run fit
        print("\nStarting train\n")
        trainer.fit(model, train_loader, val_loader)