Code example #1
    def _load_model(self):
        config_path = Path(self.args.new_model_path) / "config.{}.pkl".format(self.args.model_type.value)
        self.config = pkl_load(config_path)

        self.model = MixModel(config=self.config, model_type=self.args.model_type)
        model_path = Path(self.args.new_model_path) / "pytorch_model.{}.bin".format(self.args.model_type.value)
        ckpt = torch.load(model_path, map_location=torch.device('cpu'))  # load checkpoint on CPU first, then move the model to the target device
        self.model.load_state_dict(ckpt)
        self.model.to(self.args.device)
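All of these snippets depend on pkl_load/pkl_save helpers that are not shown. A minimal sketch, assuming they are thin wrappers around Python's pickle module; code example #3 calls pkl_load with a separate directory argument, so the sketch accepts an optional base_dir (the name base_dir is an assumption, not from the original project):

import pickle
from pathlib import Path

def pkl_load(path, base_dir=None):
    # read a pickled object from disk; base_dir is an assumed optional directory prefix
    path = Path(base_dir) / path if base_dir is not None else Path(path)
    with open(path, "rb") as f:
        return pickle.load(f)

def pkl_save(obj, path):
    # serialize an object to disk with pickle
    with open(path, "wb") as f:
        pickle.dump(obj, f)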
Code example #2
def main(args):
    # general setup
    random.seed(13)
    np.random.seed(13)
    torch.manual_seed(13)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(13)
    conf = "tlstm.conf"

    # load training data
    if args.do_train:
        train_data = pkl_load("../data/tlstm_sync/data_train.pkl")
        train_elapsed_data = pkl_load("../data/tlstm_sync/elapsed_train.pkl")
        train_labels = pkl_load("../data/tlstm_sync/label_train.pkl")
        # init config
        input_dim = train_data[0].shape[2]
        output_dim = train_labels[0].shape[1]
        config = TLSTMConfig(input_dim, output_dim, args.hidden_dim,
                             args.fc_dim, args.dropout_rate)
        # init TLSTM model
        model = TLSTM(config=config)
        model.to(args.device)
        # training
        train(args, model, train_data, train_elapsed_data, train_labels)
        # save model and config
        torch.save(model.state_dict(),
                   Path(args.model_path) / "pytorch_model.bin")
        pkl_save(config, Path(args.config_path) / conf)

    # load test data
    if args.do_test:
        test_data = pkl_load("../data/tlstm_sync/data_test.pkl")
        test_elapsed_data = pkl_load("../data/tlstm_sync/elapsed_test.pkl")
        test_labels = pkl_load("../data/tlstm_sync/label_test.pkl")
        config = pkl_load(Path(args.config_path) / conf)
        model = TLSTM(config=config)
        model.load_state_dict(
            torch.load(Path(args.model_path) / "pytorch_model.bin"))
        model.to(args.device)
        test(args, model, test_data, test_elapsed_data, test_labels)
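Code example #2 reads a number of attributes off args. A minimal sketch of how such an entry point might be driven with argparse; the flag names mirror the attributes used above, while the defaults are illustrative and not from the original project:

import argparse
import torch

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--do_train", action="store_true")
    parser.add_argument("--do_test", action="store_true")
    parser.add_argument("--hidden_dim", type=int, default=64)
    parser.add_argument("--fc_dim", type=int, default=32)
    parser.add_argument("--dropout_rate", type=float, default=0.1)
    parser.add_argument("--model_path", type=str, default="./model")
    parser.add_argument("--config_path", type=str, default="./model")
    args = parser.parse_args()
    # pick GPU when available, matching the CUDA seeding in main()
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main(args)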
Code example #3
def load_data(features_path, features_filename, data_filename):
    # this pkl_load variant takes a filename plus a directory
    fea2id, features = pkl_load(features_filename, features_path)
    data = pkl_load(data_filename, features_path)
    return fea2id, data
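A hypothetical call, with illustrative file names; note that the features half of the first pickle is unpacked but never returned:

fea2id, data = load_data("../data/features", "fea2id.pkl", "data.pkl")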
Code example #4
    embeddings = np.zeros(emb_dim).reshape(1, -1)  # row 0: all-zero vector, likely reserved for padding

    for idx, code in enumerate(vocab):
        code2index[code] = idx + 1

    np.random.seed(2)
    # append len(vocab) + 1 random rows after the zero row, one per code plus one extra
    embeddings = np.concatenate(
        [embeddings, np.random.rand(len(vocab) + 1, emb_dim)], axis=0)

    index2code = {v: k for k, v in code2index.items()}

    return embeddings, code2index, index2code


if __name__ == '__main__':
    trs = pkl_load("../data/tlstm_sync/data_train.pkl")
    ttrs = pkl_load("../data/tlstm_sync/elapsed_train.pkl")
    trsl = pkl_load("../data/tlstm_sync/label_train.pkl")
    ntrs, s1 = ohe2idx(trs)

    tss = pkl_load("../data/tlstm_sync/data_test.pkl")
    ttss = pkl_load("../data/tlstm_sync/elapsed_test.pkl")
    tssl = pkl_load("../data/tlstm_sync/label_test.pkl")
    ntss, s2 = ohe2idx(tss)

    # create embeddings with dim 10
    emb, c2i, i2c = random_generate_embeddings(s1.union(s2), 10)

    conf = SeqEmbEHRConfig(input_dim=10,
                           output_dim=2,
                           hidden_dim=64,
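The zero row at index 0 plus the 1-based code2index suggests index 0 is reserved for padding. A hedged sketch of how the returned matrix could back an embedding layer, assuming the surrounding function is random_generate_embeddings(vocab, emb_dim) as the __main__ block implies; the code strings are hypothetical:

import torch

emb, c2i, i2c = random_generate_embeddings({"dx1", "dx2", "rx9"}, 10)
# freeze=True by default; row 0 stays the padding vector
layer = torch.nn.Embedding.from_pretrained(
    torch.tensor(emb, dtype=torch.float32), padding_idx=0)
codes = torch.tensor([[c2i["dx1"], c2i["rx9"], 0]])  # 0 = padding
vectors = layer(codes)  # shape (1, 3, 10)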
Code example #5
def main(args):
    # general setup (random seed for reproducibility; we default the seed to 13)
    random.seed(13)
    np.random.seed(13)
    torch.manual_seed(13)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(13)

    try:
        args.model_type = MODEL_TYPE_FLAGS[args.model_type]
    except KeyError:
        # a failed dict lookup raises KeyError
        raise RuntimeError("we support: lstm, tlstm but got {}".format(
            args.model_type))

    try:
        args.loss_mode = MODEL_LOSS_MODES[args.loss_mode]
    except KeyError:
        raise RuntimeError("we support: {} but got {}".format(
            list(MODEL_LOSS_MODES.keys()), args.loss_mode))

    # load data
    # if using TLSTM, the data has 4 components: non-seq, seq, time elapsed, label
    # if using LSTM, the data has 3 components: non-seq, seq, label
    # seq data can have different sequence lengths, but the encoded feature dim must be the same
    # the data should be formatted as a list of tuples of numpy arrays: [(np.array, np.array, np.array, np.array), ...]
    train_data_loader = None
    if args.do_train:
        train_data = pkl_load(args.train_data_path)
        train_data_loader = SeqEHRDataLoader(
            train_data,
            args.model_type,
            args.loss_mode,
            args.batch_size,
            task='train').create_data_loader()
        args.total_step = len(train_data_loader)
        # collect input dim for model init (seq, dim)
        args.nonseq_input_dim = train_data[0][0].shape
        args.seq_input_dim = train_data[0][1].shape

        if args.sampling_weight:
            # the weights should be a 1-D numpy array holding 1/ratio for each class
            # (classes with more samples should have lower weights)
            args.sampling_weight = pkl_load(args.sampling_weight)
            args.logger.info("using sample weights as {}".format(
                args.sampling_weight))

    test_data_loader = None
    if args.do_test:
        test_data = pkl_load(args.test_data_path)
        # create data loader (pin_memory is set to True) -> (B, S, T)
        test_data_loader = SeqEHRDataLoader(test_data,
                                            args.model_type,
                                            args.loss_mode,
                                            args.batch_size,
                                            task='test').create_data_loader()

    # init task runner
    task_runner = SeqEHRTrainer(args)

    # training
    if args.do_train:
        args.logger.info("start training...")
        task_runner.train(train_data_loader)

    # prediction
    if args.do_test:
        args.logger.info("start test...")
        task_runner.predict(test_data_loader, do_eval=args.do_eval)
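MODEL_TYPE_FLAGS and MODEL_LOSS_MODES are not shown. Since code example #1 accesses self.args.model_type.value, the flag tables presumably map CLI strings to enum members; a minimal sketch with assumed names (ModelType, ModelLossMode, and the loss-mode strings are not from the original project):

from enum import Enum

class ModelType(Enum):
    M_LSTM = "lstm"
    M_TLSTM = "tlstm"

class ModelLossMode(Enum):
    BIN = "bin"  # assumed loss-mode names
    MUL = "mul"

MODEL_TYPE_FLAGS = {m.value: m for m in ModelType}
MODEL_LOSS_MODES = {m.value: m for m in ModelLossMode}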