import pickle
from pathlib import Path
from typing import Union

import torch

import utils  # project-local helpers: create_model_class, create_output_paths


def restore_model(
    model_dir: Union[Path, str, None] = None,
) -> torch.nn.Module:
    """Restore a trained model from `model_dir` and return it in eval mode."""

    # verify model source
    assert model_dir is not None
    args_file = Path(model_dir) / 'args.pkl'
    assert args_file.exists()

    # restore arguments namespace
    with args_file.open('rb') as f:  # pickle requires binary mode
        args = pickle.load(f)

    # set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # instantiate model
    model_cls = utils.create_model_class(args.model_name)
    model = model_cls(args)

    # load model parameters
    _, checkpoint_file = utils.create_output_paths(args)
    load_obj = torch.load(checkpoint_file.as_posix(), map_location=device)
    model_dict = load_obj['model']
    model.load_state_dict(model_dict)

    # send to device and set to eval mode
    model.to(device)
    model.eval()

    return model
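
A minimal usage sketch (the run-directory path below is hypothetical; any
directory containing the `args.pkl` and checkpoint produced by `train_loop`
in Example #4 should work):

    model = restore_model(model_dir='runs/my_run')  # hypothetical path
    with torch.no_grad():
        preds = model(batch)  # `batch` shaped as the training inputs
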
Example #2
    def __init__(
        self,
        run_dir: Union[Path, str, None] = None,
        device: Union[str, None] = None,
        save: bool = True,
    ):
        assert run_dir is not None
        self.run_dir = Path(run_dir).resolve()
        if self.run_dir.is_file():
            self.run_dir = self.run_dir.parent
        assert self.run_dir.exists() and self.run_dir.is_dir()
        self.args_file = self.run_dir / 'args.pkl'
        assert self.args_file.exists()
        self.run_dir_short = self.run_dir.relative_to(
            self.run_dir.parent.parent)
        self.device = device
        self.save = save

        with self.args_file.open('rb') as f:
            args = pickle.load(f)
        self.args = TestArguments().parse(existing_namespace=args)
        self.is_classification = not self.args.regression

        if self.device is None:
            self.device = self.args.device
        if self.device == 'auto' or self.device.startswith('cuda'):
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.args.device = self.device
        self.device = torch.device(self.device)

        self.output_dir = Path(self.args.output_dir)
        if not self.output_dir.is_absolute():
            self.output_dir = self.args_file.parent
        assert self.output_dir.exists()
        self.analysis_dir = self.output_dir / 'analysis'
        shutil.rmtree(self.analysis_dir, ignore_errors=True)
        self.analysis_dir.mkdir()

        # load paths
        self.test_data_file, checkpoint_file = utils.create_output_paths(
            self.args)

        # instantiate model
        model_cls = utils.create_model_class(self.args.model_name)
        self.model = model_cls(self.args)
        self.model = self.model.to(self.device)
        self.model.eval()

        # load the model checkpoint
        print(f"Model checkpoint: {checkpoint_file.as_posix()}")
        load_obj = torch.load(checkpoint_file.as_posix(),
                              map_location=self.device)
        model_dict = load_obj['model']
        self.model.load_state_dict(model_dict)

        self.training_output = None
        self.test_data = None
        self.valid_indices_data_loader = None
        self.elm_predictions = None
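
If this `__init__` belongs to the `Analysis` class invoked in Example #4 (an
assumption; the class name is not shown in this snippet), usage might look
like:

    run = Analysis(run_dir='runs/my_run', device='cpu', save=False)  # hypothetical path
    run.plot_training_epochs()
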
Example #3
def get_model(args: argparse.Namespace, logger: logging.Logger):
    _, model_cpt_path = utils.create_output_paths(args)
    gen_type_suffix = ('_' + re.split('[_.]', args.input_file)[-2]
                       if args.generated else '')
    model_name = args.model_name + gen_type_suffix
    accepted_preproc = ['wavelet', 'unprocessed']

    model_cpt_file = os.path.join(
        model_cpt_path, f'{args.model_name}_lookahead_{args.label_look_ahead}'
        f'{gen_type_suffix}'
        f'{"_" + args.data_preproc if args.data_preproc in accepted_preproc else ""}'
        f'{"_" + args.balance_data if args.balance_data else ""}.pth')

    raw_model = (multi_features_ds_v2_model.RawFeatureModel(args)
                 if args.raw_num_filters > 0 else None)
    fft_model = (multi_features_ds_v2_model.FFTFeatureModel(args)
                 if args.fft_num_filters > 0 else None)
    cwt_model = (multi_features_ds_v2_model.DWTFeatureModel(args)
                 if args.wt_num_filters > 0 else None)
    features = [
        type(f).__name__ for f in [raw_model, fft_model, cwt_model] if f
    ]

    logger.info(f'Found {model_name} state dict at {model_cpt_file}.')
    model_cls = utils.create_model(args.model_name)
    if 'MULTI' in args.model_name.upper():
        model = model_cls(args, raw_model, fft_model, cwt_model)
    else:
        model = model_cls(args)
    state_dict = torch.load(model_cpt_file,
                            map_location=torch.device(args.device))['model']
    model.load_state_dict(state_dict)
    logger.info(f'Loaded {model_name} state dict.')

    model.layers = OrderedDict([
        child for child in model.named_modules()
        if hasattr(child[1], 'weight')
    ])

    return model.to(args.device)
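
The `model.layers` attribute built above collects every submodule that
carries a `weight` parameter, keyed by qualified module name. A standalone,
runnable sketch of the same pattern (the toy network is illustrative):

    from collections import OrderedDict
    import torch

    net = torch.nn.Sequential(
        torch.nn.Conv2d(1, 4, 3),  # has `weight`
        torch.nn.ReLU(),           # no `weight`; filtered out
        torch.nn.Linear(16, 2),    # has `weight`
    )
    layers = OrderedDict(
        child for child in net.named_modules() if hasattr(child[1], 'weight')
    )
    print(list(layers))  # ['0', '2']
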
Example #4
def train_loop(
    input_args: Union[list, dict, None] = None,
    trial=None,  # optuna `trial` object
    _rank: Union[int, None] = None,  # process rank for distributed data-parallel training; *must* be last arg
) -> dict:
    """Run a training pipeline: parse inputs, prepare data, create model, train over epochs.

    Args:
    -----
        input_args (list|dict): (Optional) Input arguments as dict or list of strings
        trial: (Optional) Optuna trial object to report training progress and enable pruning
        _rank: (Optional) Used for Distributed Data Parallel training by `distributed_train.py`
    """

    # parse input args
    args_obj = TrainArguments()
    if input_args and isinstance(input_args, dict):
        # format dict into list
        arg_list = []
        for key, value in input_args.items():
            if isinstance(value, bool):
                if value is True:
                    arg_list.append(f'--{key}')
            else:
                arg_list.append(f'--{key}={value}')
        input_args = arg_list
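        # illustrative conversion: {'n_epochs': 2, 'dry_run': True, 'lr': 1e-3}
        # becomes ['--n_epochs=2', '--dry_run', '--lr=0.001']; True booleans
        # become bare flags, and False booleans are dropped entirely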
    args = args_obj.parse(arg_list=input_args)

    # output directory and files
    output_dir = Path(args.output_dir).resolve()
    args.output_dir = output_dir.as_posix()
    shutil.rmtree(output_dir.as_posix(), ignore_errors=True)
    output_dir.mkdir(parents=True)

    if args.regression:
        args.data_preproc = 'regression'
        args.label_look_ahead = 0
        args.truncate_buffer = 0
        args.oversample_active_elm = False

    if args.regression != 'log':
        args.inverse_label_weight = False

    output_file = output_dir / args.output_file
    log_file = output_dir / args.log_file
    args_file = output_dir / args.args_file
    test_data_file, checkpoint_file = utils.create_output_paths(args)

    # create LOGGER
    LOGGER = utils.get_logger(script_name=__name__, log_file=log_file)
    LOGGER.info(args_obj.make_args_summary_string())

    LOGGER.info(f"  Output directory: {output_dir.resolve().as_posix()}")

    # save args
    LOGGER.info(f"  Saving argument file: {args_file.as_posix()}")
    with args_file.open('wb') as f:
        pickle.dump(args, f)
    LOGGER.info(f"  File size: {args_file.stat().st_size/1e3:.1f} kB")

    # setup device
    if args.device == 'auto':
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if _rank is not None:
        # override args.device for multi-GPU distributed data parallel training
        args.device = f'cuda:{_rank}'
        LOGGER.info(
            f'  Distributed data parallel: process rank {_rank} on GPU {args.device}'
        )
    device = torch.device(args.device)
    LOGGER.info(f'------>  Target device: {device}')

    # create train, valid and test data
    data_cls = utils.create_data_class(args.data_preproc)
    data_obj = data_cls(args, LOGGER)
    train_data, valid_data, test_data = data_obj.get_data(verbose=False)

    # dump test data into a file
    if not args.dry_run:
        LOGGER.info(
            f"  Test data will be saved to: {test_data_file.as_posix()}")
        with test_data_file.open('wb') as f:
            pickle.dump(
                {
                    "signals": test_data[0],
                    "labels": test_data[1],
                    "sample_indices": test_data[2],
                    "window_start": test_data[3],
                    "elm_indices": test_data[4],
                },
                f,
            )
        LOGGER.info(f"  File size: {test_data_file.stat().st_size/1e6:.1f} MB")

    # create datasets
    train_dataset = dataset.ELMDataset(args, *train_data[0:4], logger=LOGGER)
    valid_dataset = dataset.ELMDataset(args, *valid_data[0:4], logger=LOGGER)

    # training and validation dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # model class and model instance
    model_class = utils.create_model_class(args.model_name)
    model = model_class(args)
    model = model.to(device)

    # distribute model for data-parallel training
    if _rank is not None:
        model = DDP(model, device_ids=[_rank])

    LOGGER.info(f"------>  Model: {args.model_name}       ")

    # display model details
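    # dummy input shape below: (batch, channel=1, signal-window length, 8, 8);
    # the trailing 8x8 is presumably the spatial grid of BES channels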
    input_size = (
        args.batch_size,
        1,
        args.signal_window_size,
        8,
        8,
    )
    x = torch.rand(*input_size)
    x = x.to(device)
    if _rank is None:
        # skip torchinfo.summary if DistributedDataParallel
        tmp_io = io.StringIO()
        sys.stdout = tmp_io
        try:
            torchinfo.summary(model, input_size=input_size, device=device)
        finally:
            # restore stdout even if the summary raises
            sys.stdout = sys.__stdout__
        LOGGER.info("\t\t\t\tMODEL SUMMARY")
        LOGGER.info(tmp_io.getvalue())
    LOGGER.info(f'  Batched input size: {x.shape}')
    LOGGER.info(f"  Batched output size: {model(x).shape}")
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    LOGGER.info(f"  Model contains {n_params} trainable parameters!")

    # optimizer
    if args.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=args.momentum,
            dampening=args.dampening,
        )
    else:
        raise ValueError(f'Unsupported optimizer: {args.optimizer}')

    # get the lr scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=2,
        verbose=True,
    )

    # loss function
    if args.regression:
        criterion = torch.nn.MSELoss(reduction="none")
    else:
        criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
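    # reduction="none" keeps per-sample losses so the trainer can presumably
    # apply focal-loss or inverse-label weighting before averaging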

    # track the best validation score across epochs
    best_score = -np.inf

    # instantiate training object
    use_rnn = args.data_preproc == "rnn"
    engine = trainer.Run(
        model,
        device=device,
        criterion=criterion,
        optimizer=optimizer,
        use_focal_loss=args.focal_loss,
        use_rnn=use_rnn,
        inverse_label_weight=args.inverse_label_weight,
    )

    # containers to hold train and validation losses
    train_loss = np.empty(0)
    valid_loss = np.empty(0)
    scores = np.empty(0)
    if not args.regression:
        roc_scores = np.empty(0)

    outputs = {}

    training_start_time = time.time()
    # iterate through all the epochs
    LOGGER.info(f"  Begin training loop with {args.n_epochs} epochs")
    for epoch in range(args.n_epochs):
        start_time = time.time()

        # train over an epoch
        avg_loss = engine.train(
            train_loader,
            epoch,
            print_every=args.train_print_every,
        )
        if args.regression:
            avg_loss = np.sqrt(avg_loss)
        train_loss = np.append(train_loss, avg_loss)

        # evaluate validation data
        avg_val_loss, preds, valid_labels = engine.evaluate(
            valid_loader, print_every=args.valid_print_every)
        if args.regression:
            avg_val_loss = np.sqrt(avg_val_loss)
        valid_loss = np.append(valid_loss, avg_val_loss)

        # step the learning rate scheduler
        scheduler.step(avg_val_loss)

        if args.regression:
            # R2 score
            score = r2_score(valid_labels, preds)
            scores = np.append(scores, score)
        else:
            # F1 scoring
            score = f1_score(valid_labels,
                             (preds > args.threshold).astype(int))
            scores = np.append(scores, score)
            # ROC scoring
            roc_score = roc_auc_score(valid_labels, preds)
            roc_scores = np.append(roc_scores, roc_score)

        elapsed = time.time() - start_time

        LOGGER.info(
            f"Epoch: {epoch+1:03d} \ttrain loss: {avg_loss:.3f} \tval. loss: {avg_val_loss:.3f} "
            f"\tscore: {score:.3f} \ttime elapsed: {elapsed:.1f} s")

        # update and save outputs
        outputs['train_loss'] = train_loss
        outputs['valid_loss'] = valid_loss
        outputs['scores'] = scores
        if not args.regression:
            outputs['roc_scores'] = roc_scores

        with open(output_file.as_posix(), "w+b") as f:
            pickle.dump(outputs, f)

        # track best f1 score and save model
        if score > best_score or epoch == 0:
            best_score = score
            LOGGER.info(f"Epoch: {epoch+1:03d} \tBest Score: {best_score:.3f}")
            if not args.dry_run:
                LOGGER.info(f"  Saving model to: {checkpoint_file.as_posix()}")
                model_data = {
                    "model": model.state_dict(),
                    "preds": preds,
                }
                torch.save(model_data, checkpoint_file.as_posix())
                LOGGER.info(
                    f"  File size: {checkpoint_file.stat().st_size/1e3:.1f} kB"
                )
                if args.save_onnx:
                    input_name = ['signal_window']
                    output_name = ['micro_prediction']
                    onnx_file = Path(args.output_dir) / 'checkpoint.onnx'
                    LOGGER.info(f"  Saving to ONNX: {onnx_file.as_posix()}")
                    torch.onnx.export(model,
                                      x[0].unsqueeze(0),
                                      onnx_file.as_posix(),
                                      input_names=input_name,
                                      output_names=output_name,
                                      verbose=True,
                                      opset_version=11)
                    LOGGER.info(
                        f"  File size: {onnx_file.stat().st_size/1e3:.1f} kB")

        # optuna hook to monitor training epochs
        if trial is not None and optuna is not None:
            trial.report(score, epoch)
            # save outputs as lists in trial user attributes
            for key, item in outputs.items():
                trial.set_user_attr(key, item.tolist())
            if trial.should_prune():
                LOGGER.info("--------> Trial pruned by Optuna")
                for handler in LOGGER.handlers[:]:
                    handler.close()
                    LOGGER.removeHandler(handler)
                raise optuna.TrialPruned()

            LOGGER.info(scores)
            LOGGER.info(trial.user_attrs['scores'])

    if args.do_analysis:
        run = Analysis(output_dir)
        run.plot_training_epochs()
        run.plot_valid_indices_analysis()

    total_elapsed = time.time() - training_start_time
    LOGGER.info(f'Training complete in {total_elapsed:0.1f} s')

    # shut down logger handlers
    for handler in LOGGER.handlers[:]:
        handler.close()
        LOGGER.removeHandler(handler)

    return outputs
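
A minimal usage sketch (keys and values below are illustrative; valid names
are whatever `TrainArguments` defines):

    outputs = train_loop(input_args={
        'model_name': 'multi_features',  # hypothetical value
        'n_epochs': 2,
        'dry_run': True,
    })
    print(outputs['train_loss'])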