示例#1
0
                           bounds=bounds,
                           options=args['options'])

        self.baseline, self.magnitude, self.decay = map(float, opt_res.x)

        return opt_res

if __name__ == "__main__":
    # Ground-truth process used to simulate data. The three positional
    # arguments presumably correspond to (baseline, magnitude, decay), the
    # attributes `fit` assigns -- TODO confirm against ExpHawkes.__init__.
    true_process = ExpHawkes(0.5, 0.8, 1.0)

    # Simulate sequences and split the result along its first axis, one entry
    # per sequence. The meaning of (30, 100) is defined by generate_sequences.
    simulated = true_process.generate_sequences(30, 100)
    sequences = [simulated[row] for row in range(simulated.shape[0])]

    train_set = GeneralDataset(sequences)
    batch = collate_no_marks(train_set)

    # Start from deliberately wrong parameters; `fit` overwrites them with the
    # optimiser's solution.
    fitted_process = ExpHawkes(2, 2, 2)

    optimiser_options = {
        'disp': True,
        'ftol': 1e-05,
        'gtol': 1e-04
    }
    args = {
        'x0': [0.3, 0.3, 0.5],
        'method': 'L-BFGS-B',
        'options': optimiser_options
    }

    result = fitted_process.fit(batch, args)
    print(result)
示例#2
0
    model = model_gen(True)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('[Info] Number of parameters: {}'.format(num_params))

    save_name = name + '_' + model_name

    print('Start Processing:', save_name)

    data = []
    with open(save_file, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            data.append(torch.Tensor(list(map(float, row))))

    whole_dataset = GeneralDataset(data)
    whole_dataset.log_transform_rnn()

    train_dataset, validation_dataset, test_dataset = whole_dataset.train_val_test_split(
        0.6, 0.2, 0.2, seed=1)

    rnn_mean, rnn_std = train_dataset.rnn_statistics()

    train_dataset.normalise_data(rnn_mean=rnn_mean, rnn_std=rnn_std)
    validation_dataset.normalise_data(rnn_mean=rnn_mean, rnn_std=rnn_std)
    test_dataset.normalise_data(rnn_mean=rnn_mean, rnn_std=rnn_std)

    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=L2REG)
    save_path = os.path.join(ROOT_SAVE_FOLDER, save_name)
    valid_hist = model.train_model(train_dataset,
                                   EPOCHS,
示例#3
0
    def intensity(self, events):
        """Run the model on a single event sequence.

        ``events`` is wrapped in a one-element ``GeneralDataset`` and collated
        with ``collate_no_marks``. The result is passed to ``self.forward``, and
        the first element of its output is returned.
        """
        single_batch = collate_no_marks(GeneralDataset([events]))
        outputs = self.forward(single_batch)
        # Only one sequence was batched, so the first output entry is the one
        # for `events`. Presumably forward returns per-sequence results
        # indexed along dim 0 -- TODO confirm.
        return outputs[0]