Example #1
def main():
    args = parse_args()
    seed_everything(args.seed)

    tb_logger = loggers.TensorBoardLogger("logs/")
    wandb_logger = loggers.WandbLogger(save_dir="logs/", project="xldst")
    assert wandb_logger.experiment.id
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join("ckpts", wandb_logger.experiment.id,
                              "{epoch}-{val_loss:.4f}"),
        verbose=True,
    )
    early_stop_callback = EarlyStopping(patience=2, verbose=True)
    trainer = Trainer.from_argparse_args(
        args,
        logger=[tb_logger, wandb_logger],
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stop_callback,
    )
    dm = CldstMBartDataModule(args)
    dm.prepare_data()

    dm.setup("fit")
    model = MBartDST(args)
    trainer.fit(model, datamodule=dm)

    dm.setup("test")
    trainer.test(datamodule=dm)
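
Example #1 omits its imports; a minimal set it appears to assume is sketched below (project-specific names such as parse_args, CldstMBartDataModule and MBartDST are not covered):

import os
from pytorch_lightning import Trainer, loggers, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint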
Example #2
def common_train(args, metric, model_class, build_method, task: str,
                 **model_kwargs):
    pl.seed_everything(args.seed)

    early_stop_callback = EarlyStopping(monitor=metric,
                                        min_delta=1e-5,
                                        patience=3,
                                        verbose=False,
                                        mode='max')
    checkpoint_callback = ModelCheckpoint(monitor=metric,
                                          save_top_k=1,
                                          verbose=True,
                                          mode='max',
                                          save_last=True)
    model = model_class(args, **model_kwargs)
    build_method(model)
    this_time = time.strftime("%m-%d_%H-%M-%S", time.localtime())
    try:
        import wandb
        logger = loggers.WandbLogger(save_dir='lightning_logs',
                                     name=f'{task}_{this_time}',
                                     project='ltp')
    except Exception as e:
        logger = loggers.TensorBoardLogger(save_dir='lightning_logs',
                                           name=f'{task}_{this_time}')
    trainer: Trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=[early_stop_callback],
        checkpoint_callback=checkpoint_callback)
    # Ready to train with new learning rate
    trainer.fit(model)
    trainer.test()
Example #3
    def get_logger(self, cfg: DictConfig,
                   save_dir: Path) -> pl_loggers.WandbLogger:
        """Returns the Weights and Biases (wandb) logger object (really an wandb Run object)
        The run object corresponds to a single execution of the script and is returned from `wandb.init()`.

        Args:
            run_id: Unique run id. If run id exists, will continue logging to that run.
            cfg: The entire config got from hydra, for purposes of logging the config of each run in wandb.
            save_dir: Root dir to save wandb log files

        Returns:
            wandb.wandb_sdk.wandb_run.Run: wandb run object. Can be used for logging.
        """
        # Some argument names to wandb are different from the attribute names of the class.
        # Pop the offending attributes before passing to init func.
        args_dict = asdict_filtered(self)
        run_name = args_dict.pop("run_name")
        run_id = args_dict.pop("run_id")

        # If `self.save_hyperparams()` is called in LightningModule, it will save the cfg passed as argument
        # cfg_dict = OmegaConf.to_container(cfg, resolve=True)

        wb_logger = pl_loggers.WandbLogger(name=run_name,
                                           id=run_id,
                                           save_dir=str(save_dir),
                                           **args_dict)

        return wb_logger
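
A hypothetical call site for this method, assuming pytorch_lightning is imported as pl and Path comes from pathlib (as in the signature); `logger_conf` and the save directory are placeholders:

# Sketch only: `logger_conf` is an instance of the class defining get_logger.
wb_logger = logger_conf.get_logger(cfg, save_dir=Path("logs"))
trainer = pl.Trainer(logger=wb_logger)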
Example #4
def main(hparams):
    # clean up
    gc.collect()
    torch.cuda.empty_cache()

    logger = loggers.WandbLogger(name=hparams.log_name, project="ml4cg")

    model = Net(hparams)

    checkpoint_callback = ModelCheckpoint(
        filepath="checkpoints/{epoch}",
        save_top_k=20,
        verbose=True,
        monitor="val_loss",
        save_weights_only=True,
        period=1,
        mode="min",
        prefix="",
    )

    trainer = Trainer(
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        gpus=hparams.gpus,
        max_epochs=hparams.max_epochs,
        num_sanity_val_steps=hparams.num_sanity_val_steps,
    )

    trainer.fit(model)
    trainer.test(model)
Example #5
def main(args):

    # create experiment dir
    experiment_dir = CKPT_DIR / args.experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # logger
    logger = pl_loggers.WandbLogger(
        name=args.experiment_name,
        save_dir=LOG_DIR,
    )

    # early stop call back
    early_stop = EarlyStopping(monitor='val_loss',
                               patience=5,
                               strict=False,
                               verbose=False,
                               mode='min')

    # checkpoint
    checkpoint_callback = ModelCheckpoint(filepath=experiment_dir,
                                          save_top_k=3,
                                          verbose=True,
                                          monitor='val_loss',
                                          mode='min',
                                          prefix='')

    model = LightningModel(args)
    trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        early_stop_callback=early_stop,
        checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
Example #6
def main(params: Optional[SlotAttentionParams] = None):
    if params is None:
        params = SlotAttentionParams()

    assert params.num_slots > 1, "Must have at least 2 slots."

    if params.is_verbose:
        print(f"INFO: limiting the dataset to only images with `num_slots - 1` ({params.num_slots - 1}) objects.")
        if params.num_train_images:
            print(f"INFO: restricting the train dataset size to `num_train_images`: {params.num_train_images}")
        if params.num_val_images:
            print(f"INFO: restricting the validation dataset size to `num_val_images`: {params.num_val_images}")

    clevr_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(rescale),  # rescale between -1 and 1
            transforms.Resize(params.resolution),
        ]
    )

    clevr_datamodule = CLEVRDataModule(
        data_root=params.data_root,
        max_n_objects=params.num_slots - 1,
        train_batch_size=params.batch_size,
        val_batch_size=params.val_batch_size,
        clevr_transforms=clevr_transforms,
        num_train_images=params.num_train_images,
        num_val_images=params.num_val_images,
        num_workers=params.num_workers,
    )

    print(f"Training set size (images must have {params.num_slots - 1} objects):", len(clevr_datamodule.train_dataset))

    model = SlotAttentionModel(
        resolution=params.resolution,
        num_slots=params.num_slots,
        num_iterations=params.num_iterations,
        empty_cache=params.empty_cache,
    )

    method = SlotAttentionMethod(model=model, datamodule=clevr_datamodule, params=params)

    logger_name = "slot-attention-clevr6"
    logger = pl_loggers.WandbLogger(project="slot-attention-clevr6", name=logger_name)

    trainer = Trainer(
        logger=logger if params.is_logger_enabled else False,
        accelerator="ddp" if params.gpus > 1 else None,
        num_sanity_val_steps=params.num_sanity_val_steps,
        gpus=params.gpus,
        max_epochs=params.max_epochs,
        log_every_n_steps=50,
        callbacks=[LearningRateMonitor("step"), ImageLogCallback(),] if params.is_logger_enabled else [],
    )
    trainer.fit(method)
Example #7
    def get_logger(self, save_dir: Path) -> pl_loggers.WandbLogger:
        args_dict = asdict_filtered(self)
        run_name = args_dict.pop("run_name")
        run_id = args_dict.pop("run_id")

        wb_logger = pl_loggers.WandbLogger(name=run_name,
                                           id=run_id,
                                           save_dir=str(save_dir),
                                           **args_dict)

        return wb_logger
Example #8
def main(args):
    logger = pl_loggers.WandbLogger(experiment=None, save_dir=None)
    checkpoint_callback = ModelCheckpoint(
        dirpath=f"ckpts/{args.reduction}_reduction/", monitor="val_loss")
    model = AutoencoderModel(args)
    lr_logger = LearningRateMonitor()
    trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=[lr_logger],
        checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
Example #9
def main():
    ''' Main '''

    # Update the static mapping before initializing Config
    encoder = update_static_mapping()
    C = Config(lite=True)
    init_seed(get_seed())

    ##################
    # INPUT PIPELINE #
    ##################

    # Read and format the csv
    df = pd.read_csv("gestop/data/static_gestures_data.csv")
    train, test = train_test_split(df, test_size=0.1, random_state=get_seed())
    train_X, train_Y = split_dataframe(train)
    test_X, test_Y = split_dataframe(test)

    # One Hot Encoding of the target classes
    train_Y = np.array(train_Y)
    test_Y = np.array(test_Y)

    train_Y = encoder.transform(train_Y)
    test_Y = encoder.transform(test_Y)

    train_loader = format_and_load(train_X, train_Y, C.static_batch_size)
    test_loader = format_and_load(test_X, test_Y, C.static_batch_size)

    static_net = StaticNet(C.static_input_dim, C.static_output_classes,
                           C.static_gesture_mapping)

    early_stopping = EarlyStopping(
        patience=3,
        verbose=True,
    )

    wandb_logger = pl_loggers.WandbLogger(save_dir='gestop/logs/',
                                          name='static_net',
                                          project='gestop')

    trainer = Trainer(gpus=1,
                      deterministic=True,
                      logger=wandb_logger,
                      min_epochs=C.min_epochs,
                      early_stop_callback=early_stopping)
    trainer.fit(static_net, train_loader, test_loader)
    trainer.test(static_net, test_dataloaders=test_loader)

    ################
    # SAVING MODEL #
    ################

    torch.save(static_net.state_dict(), C.static_path)
Example #10
def main(args):
    logger = pl_loggers.WandbLogger(name="example", save_dir=None)
    early_stop = EarlyStopping(monitor="val_loss")
    checkpoint_callback = ModelCheckpoint(dirpath="ckpts/", monitor="val_loss")
    model = ExampleModel(args)
    lr_logger = LearningRateLogger()
    trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=[early_stop, lr_logger],
        checkpoint_callback=checkpoint_callback)
    trainer.fit(model)
Example #11
    def get_logger(self):
        trainer_opt = self.trainer_opt
        if trainer_opt['logger'] == 'tensorboard':
            logger = loggers.TensorBoardLogger(self.minfo['logs_dir'])
            version = logger.experiment.log_dir.split('_')[-1]
        elif trainer_opt['logger'] == 'wandb':
            logger = loggers.WandbLogger(name=self.opt.expr_name,
                                         project=self.project_name,
                                         **trainer_opt['logger_kwargs'])
            version = 0
        else:
            logger = False
        return logger
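
A brief sketch of how the returned value might be consumed; passing logger=False to the Trainer disables logging entirely, which is what the fall-through branch relies on (the call site below is an assumption):

# Hypothetical call site inside the same class.
logger = self.get_logger()
trainer = Trainer(logger=logger)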
Example #12
def main(args):

    # logger
    logger = pl_loggers.WandbLogger(name=None, save_dir=None, experiment=None)

    # early stop call back
    early_stop = EarlyStopping(monitor='val_loss',
                               patience=5,
                               strict=False,
                               verbose=False,
                               mode='min')

    model = LightningModel(args)
    trainer = Trainer.from_argparse_args(args,
                                         logger=logger,
                                         early_stop_callback=early_stop)
    trainer.fit(model)
Example #13
def common_train(args,
                 model_class,
                 task_info: TaskInfo,
                 build_method=default_build_method,
                 model_kwargs: dict = None):
    if model_kwargs is None:
        model_kwargs = {}

    pl.seed_everything(args.seed)

    early_stop_callback = EarlyStopping(monitor=f'val_{task_info.metric_name}',
                                        min_delta=1e-5,
                                        patience=args.patience,
                                        verbose=False,
                                        mode='max')
    checkpoint_callback = ModelCheckpoint(
        monitor=f'val_{task_info.metric_name}',
        save_top_k=1,
        verbose=True,
        mode='max',
        save_last=True)
    model = model_class(args, **model_kwargs)
    build_method(model, task_info)
    this_time = time.strftime("%m-%d_%H:%M:%S", time.localtime())

    try:
        import wandb
        logger = loggers.WandbLogger(project=args.project,
                                     save_dir='lightning_logs',
                                     name=f'{task_info.task_name}_{this_time}',
                                     offline=args.offline)
    except Exception as e:
        logger = loggers.TensorBoardLogger(
            save_dir='lightning_logs',
            name=f'{task_info.task_name}_{this_time}',
            default_hp_metric=False)
    trainer: Trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=[early_stop_callback],
        checkpoint_callback=checkpoint_callback)
    # Ready to train with new learning rate
    trainer.fit(model)
    trainer.test()
Example #14
def get_loggers_callbacks(args, model=None):

    try:
        # Setup logger(s) params
        csv_logger_params = dict(
            save_dir="./experiments",
            name=os.path.join(*args.save_dir.split("/")[1:-1]),
            version=args.save_dir.split("/")[-1],
        )
        wandb_logger_params = dict(
            log_model=False,
            name=os.path.join(*args.save_dir.split("/")[1:]),
            offline=args.debug,
            project="utime",
            save_dir=args.save_dir,
        )
        loggers = [
            pl_loggers.CSVLogger(**csv_logger_params),
            pl_loggers.WandbLogger(**wandb_logger_params),
        ]
        if model:
            loggers[-1].watch(model)

        # Setup callback(s) params
        checkpoint_monitor_params = dict(
            filepath=os.path.join(args.save_dir,
                                  "{epoch:03d}-{eval_loss:.2f}"),
            monitor=args.checkpoint_monitor,
            save_last=True,
            save_top_k=1,
        )
        earlystopping_parameters = dict(
            monitor=args.earlystopping_monitor,
            patience=args.earlystopping_patience,
        )
        callbacks = [
            pl_callbacks.ModelCheckpoint(**checkpoint_monitor_params),
            pl_callbacks.EarlyStopping(**earlystopping_parameters),
            pl_callbacks.LearningRateMonitor(),
        ]

        return loggers, callbacks
    except AttributeError:
        return None, None
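
A hypothetical way to wire this helper into a Trainer, assuming Trainer is imported from pytorch_lightning and args carries the same fields the function itself expects:

# Sketch only: handle the (None, None) fallback from the except branch.
exp_loggers, exp_callbacks = get_loggers_callbacks(args, model)
if exp_loggers is not None:
    trainer = Trainer(logger=exp_loggers, callbacks=exp_callbacks)
    trainer.fit(model)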
Example #15
def main():
    logger.remove()
    logger.add(sys.stdout,
               colorize=True,
               format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> " +
               "| <level>{level}</level> " +
               "| <light-black>{file.path}:{line}</light-black> | {message}")
    hparams = parse_args()
    if hparams.restore:
        wandb.init(project=hparams.project, tags=hparams.tags)
        model = LevelClassification.load_from_checkpoint(hparams.restore)
        logger.info("Restored model")
    else:
        # wandb.init is called in LevelClassification
        model = LevelClassification(hparams)
        experiment_logger = loggers.WandbLogger(project=hparams.project,
                                                tags=hparams.tags)
        hparams.checkpoint_dir = os.path.join(experiment_logger.experiment.dir,
                                              "checkpoints")
        checkpoint_cb = callbacks.ModelCheckpoint(hparams.checkpoint_dir,
                                                  save_top_k=1)
        trainer = pl.Trainer(logger=experiment_logger,
                             gpus=1 if hparams.device == "cuda" else 0,
                             checkpoint_callback=checkpoint_cb,
                             callbacks=[EmbeddingsCallback()],
                             early_stop_callback=callbacks.EarlyStopping(),
                             fast_dev_run=hparams.debug)
        trainer.fit(model)
    model.freeze()
    baseline_datasets = []
    logger.info("Baselines {}", os.listdir(hparams.baseline_level_dir))
    for i, baseline_level_dir in enumerate(
            sorted(os.listdir(hparams.baseline_level_dir))):
        baseline_dataset = LevelSnippetDataset(
            level_dir=os.path.join(os.getcwd(), hparams.baseline_level_dir,
                                   baseline_level_dir),
            slice_width=model.dataset.slice_width,
            token_list=model.dataset.token_list)
        baseline_datasets.append(baseline_dataset)
    visualize_embeddings(model.dataset, model, "test", hparams, None,
                         baseline_datasets)
Example #16
    def tune_train_once(config,
                        checkpoint_dir=None,
                        args: argparse.Namespace = None,
                        model_class: type = None,
                        build_method=None,
                        task_info: TaskInfo = None,
                        model_kwargs: dict = None,
                        resume: str = None,
                        group: str = None,
                        log_dir: str = None,
                        **kwargs):
        if resume is None:
            resume = 'all'
        args_vars = vars(args)
        args_vars.update(config)

        pl.seed_everything(args.seed)
        pl_loggers = [
            loggers.CSVLogger(save_dir=tune.get_trial_dir(),
                              name="",
                              version="."),
            loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(),
                                      name="",
                                      version=".",
                                      default_hp_metric=False),
        ]

        try:
            import wandb
            pl_loggers.append(
                loggers.WandbLogger(save_dir=log_dir or 'tune_lightning_logs',
                                    project=args.project,
                                    name=tune.get_trial_name(),
                                    id=tune.get_trial_id(),
                                    offline=args.offline,
                                    group=group))
        except Exception:
            pass

        trainer_args = dict(
            logger=pl_loggers,
            progress_bar_refresh_rate=0,
            callbacks=[
                TuneReportCheckpointCallback(metrics={
                    f'tune_{task_info.metric_name}':
                    f'{task_info.task_name}/val_{task_info.metric_name}'
                },
                                             filename="tune.ckpt",
                                             on="validation_end")
            ])
        if checkpoint_dir and resume == 'all':
            trainer_args['resume_from_checkpoint'] = os.path.join(
                checkpoint_dir, "tune.ckpt")

        # fix slurm trainer
        os.environ["SLURM_JOB_NAME"] = "bash"
        model = model_class(args, **model_kwargs)
        build_method(model, task_info)
        trainer: Trainer = Trainer.from_argparse_args(args, **trainer_args)
        if checkpoint_dir and resume == 'model':
            ckpt = pl_load(os.path.join(checkpoint_dir, "tune.ckpt"),
                           map_location=lambda storage, loc: storage)
            model = model._load_model_state(ckpt)
            trainer.current_epoch = ckpt["epoch"]

        trainer.fit(model)
Example #17
    def tune_train(args,
                   model_class,
                   task_info: TaskInfo,
                   build_method=default_build_method,
                   model_kwargs: dict = None,
                   tune_config=None):
        if model_kwargs is None:
            model_kwargs = {}
        this_time = time.strftime("%m-%d_%H:%M:%S", time.localtime())
        experiment_name = f'{task_info.task_name}_{this_time}'

        if tune_config is None:
            config = {
                # 3e-4 for Small, 1e-4 for Base, 5e-5 for Large
                "lr":
                tune.loguniform(args.tune_min_lr, args.tune_max_lr),

                # -1 for disable, 0.8 for Base/Small, 0.9 for Large
                "layerwise_lr_decay_power":
                tune.choice([0.8, 0.9]),

                # lr scheduler
                "lr_scheduler":
                tune.choice([
                    'linear_schedule_with_warmup',
                    'polynomial_decay_schedule_with_warmup'
                ]),
            }
        else:
            config = tune_config
        if torch.cuda.is_available():
            resources_per_trial = {
                "cpu": args.tune_cpus_per_trial,
                "gpu": args.tune_gpus_per_trial
            }
        else:
            resources_per_trial = {"cpu": args.tune_cpus_per_trial}
        print("resources_per_trial", resources_per_trial)

        tune_dir = os.path.abspath('tune_lightning_logs')

        analysis = tune.run(
            tune.with_parameters(
                tune_train_once,
                args=args,
                task_info=task_info,
                model_class=model_class,
                build_method=build_method,
                model_kwargs=model_kwargs,
                resume=args.tune_resume,
                group=experiment_name,
                log_dir=tune_dir,
            ),
            mode="max",
            config=config,
            num_samples=args.tune_num_samples,
            metric=f'tune_{task_info.metric_name}',
            name=experiment_name,
            progress_reporter=CLIReporter(
                parameter_columns=list(config.keys()),
                metric_columns=[
                    "loss", f'tune_{task_info.metric_name}',
                    "training_iteration"
                ]),
            callbacks=[TBXLoggerCallback(),
                       CSVLoggerCallback()],
            resources_per_trial=resources_per_trial,
            scheduler=ASHAScheduler(
                max_t=args.max_epochs + 1,  # for test
                grace_period=args.min_epochs),
            queue_trials=True,
            keep_checkpoints_num=args.tune_keep_checkpoints_num,
            checkpoint_score_attr=f'tune_{task_info.metric_name}',
            local_dir=tune_dir,
        )
        print("Best hyperparameters found were: ", analysis.best_config)
        print("Best checkpoint: ", analysis.best_checkpoint)

        args_vars = vars(args)
        args_vars.update(analysis.best_config)
        model = model_class.load_from_checkpoint(os.path.join(
            analysis.best_checkpoint, "tune.ckpt"),
                                                 hparams=args,
                                                 **model_kwargs)

        pl_loggers = [
            loggers.CSVLogger(save_dir=tune.get_trial_dir(),
                              name="",
                              version="."),
            loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(),
                                      name="",
                                      version=".",
                                      default_hp_metric=False),
        ]

        try:
            import wandb
            pl_loggers.append(
                loggers.WandbLogger(save_dir=tune_dir,
                                    project=args.project,
                                    name=tune.get_trial_name(),
                                    id=tune.get_trial_id(),
                                    offline=args.offline,
                                    group=experiment_name))
        except Exception:
            pass

        trainer: Trainer = Trainer.from_argparse_args(args, logger=pl_loggers)
        build_method(model, task_info)
        trainer.test(model)
Example #18
def main():
    ''' Main '''

    C = Config(lite=True)
    init_seed(C.seed_val)

    ##################
    # INPUT PIPELINE #
    ##################

    # Read and format the csv
    df = pd.read_csv("data/static_gestures_data.csv")
    train, test = train_test_split(df, test_size=0.1, random_state=C.seed_val)
    train_X, train_Y = split_dataframe(train)
    test_X, test_Y = split_dataframe(test)

    # One Hot Encoding of the target classes
    train_Y = np.array(train_Y)
    test_Y = np.array(test_Y)

    le = LabelEncoder()
    le.fit(train_Y)

    # Store encoding to disk
    le_name_mapping = dict(
        zip([int(i) for i in le.transform(le.classes_)], le.classes_))
    logging.info(le_name_mapping)
    with open('data/static_gesture_mapping.json', 'w') as f:
        f.write(json.dumps(le_name_mapping))

    train_Y = le.transform(train_Y)
    test_Y = le.transform(test_Y)

    train_loader = format_and_load(train_X, train_Y, C.static_batch_size)
    test_loader = format_and_load(test_X, test_Y, C.static_batch_size)

    gesture_net = GestureNet(C.static_input_dim, C.static_output_classes,
                             C.static_gesture_mapping)

    early_stopping = EarlyStopping(
        patience=3,
        verbose=True,
    )

    wandb_logger = pl_loggers.WandbLogger(save_dir='logs/',
                                          name='gesture_net',
                                          project='gestop')

    trainer = Trainer(gpus=1,
                      deterministic=True,
                      logger=wandb_logger,
                      min_epochs=C.min_epochs,
                      early_stop_callback=early_stopping)
    trainer.fit(gesture_net, train_loader, test_loader)
    # gesture_net.load_state_dict(torch.load(PATH))
    trainer.test(gesture_net, test_dataloaders=test_loader)

    ################
    # SAVING MODEL #
    ################

    torch.save(gesture_net.state_dict(), C.static_path)
Example #19
def main():
    ''' Main '''

    parser = argparse.ArgumentParser(
        description='A program to train a neural network '
                    'to recognize dynamic hand gestures.')
    parser.add_argument("--exp-name", help="The name with which to log the run.", type=str)

    args = parser.parse_args()

    C = Config(lite=True, pretrained=False)
    init_seed(C.seed_val)

    ##################
    # INPUT PIPELINE #
    ##################

    train_x, test_x, train_y, test_y, gesture_mapping = read_data(C.seed_val)
    with open('gestop/data/dynamic_gesture_mapping.json', 'w') as f:
        f.write(json.dumps(gesture_mapping))

    # Higher order function to pass configuration as argument
    shrec_to_mediapipe = partial(format_shrec, C)
    user_to_mediapipe = partial(format_user, C)

    # Custom transforms to prepare data.
    shrec_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(normalize),
        transforms.Lambda(resample_and_jitter),
        transforms.Lambda(shrec_to_mediapipe),
    ])
    user_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(torch.squeeze),
        transforms.Lambda(resample_and_jitter),
        transforms.Lambda(user_to_mediapipe),
    ])

    train_loader = DataLoader(ShrecDataset(train_x, train_y, shrec_transform, user_transform),
                              num_workers=10, batch_size=C.dynamic_batch_size,
                              collate_fn=choose_collate(variable_length_collate, C))
    val_loader = DataLoader(ShrecDataset(test_x, test_y, shrec_transform, user_transform),
                            num_workers=10, batch_size=C.dynamic_batch_size,
                            collate_fn=choose_collate(variable_length_collate, C))

    ############
    # TRAINING #
    ############

    # Use pretrained SHREC model
    if C.pretrained:
        model = ShrecNet(C.dynamic_input_dim, C.shrec_output_classes, gesture_mapping)
        model.load_state_dict(torch.load(C.shrec_path))
        model.replace_layers(C.dynamic_output_classes)
    else:
        model = ShrecNet(C.dynamic_input_dim, C.dynamic_output_classes, gesture_mapping)
        model.apply(init_weights)

    early_stopping = EarlyStopping(
        patience=5,
        verbose=True,
    )

    # Fall back to a default name if none was given on the command line.
    if args.exp_name is None:
        args.exp_name = "default"

    wandb_logger = pl_loggers.WandbLogger(save_dir='gestop/logs/',
                                          name=args.exp_name,
                                          project='gestop')

    trainer = Trainer(gpus=1,
                      deterministic=True,
                      logger=wandb_logger,
                      min_epochs=20,
                      accumulate_grad_batches=C.grad_accum,
                      early_stop_callback=early_stopping)

    trainer.fit(model, train_loader, val_loader)

    torch.save(model.state_dict(), C.dynamic_path)

    trainer.test(model, test_dataloaders=val_loader)
Example #20
    "h",
    "ʊ",
    "ʧ",
    "l",
    "w",
    "ʤ",
    "o",
    "X",
]

model.change_vocab(labels)
model.hparams.learning_rate = config.lr_frozen
model.hparams.betas = config.betas
callback_finetune = FinetuneEncoderDecoder(decoder_lr=config.lr_unfrozen,
                                           encoder_initial_lr_div=10)
wb_logger = pl_loggers.WandbLogger()
trainer = pl.Trainer(
    gpus=1,
    max_epochs=3,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    num_sanity_val_steps=2,
    deterministic=True,
    logger=wb_logger,
    callbacks=[
        LogSpectrogramsCallback(),
        LogResultsCallback(), callback_finetune
    ],
)
trainer.fit(model=model, datamodule=full_dm)