コード例 #1
0
ファイル: test_regression.py プロジェクト: vikibytes/finance
def test_ridge_regression():
    """Train a Ridge regression on the test data and check the first prediction.

    The pickle at ``../clf/test_N225_ridge.pickle`` (relative to this test
    file) is removed before the test and, via ``finally``, after it, so a
    failing assertion does not leave the file behind for the next run.
    """
    stock_d = testdata()
    ti = TechnicalIndicators(stock_d)

    filename = 'test_N225_ridge.pickle'
    clffile = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
                           'clf', filename)

    # Remove any pickle left at this path by an earlier run.
    # NOTE(review): assumes Regression(filename) reads/writes ../clf/<filename>.
    if os.path.exists(clffile):
        os.remove(clffile)

    try:
        clf = Regression(filename)
        ti.calc_ret_index()
        ret = ti.stock['ret_index']
        # .iloc makes the "first value" lookup explicitly positional; plain
        # [0] on a non-integer (presumably date) index relies on pandas'
        # deprecated positional fallback.
        base = ti.stock_raw['Adj Close'].iloc[0]

        clf.train(ret, regression_type="Ridge")

        test_y = clf.predict(ret, base)

        expected = 19177.97
        r = round(test_y[0], 2)
        eq_(r, expected)
    finally:
        if os.path.exists(clffile):
            os.remove(clffile)
コード例 #2
0
ファイル: test_regression.py プロジェクト: hnjun7802/finance
def test_ridge_regression():
    """Train a Ridge regression on the test data and check the first prediction.

    The pickle at ``../clf/test_N225_ridge.pickle`` (relative to this test
    file) is removed before the test and, via ``finally``, after it, so a
    failing assertion does not leave the file behind for the next run.
    """
    stock_d = testdata()
    ti = TechnicalIndicators(stock_d)

    filename = "test_N225_ridge.pickle"
    clffile = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "clf", filename)

    # Remove any pickle left at this path by an earlier run.
    # NOTE(review): assumes Regression(filename) reads/writes ../clf/<filename>.
    if os.path.exists(clffile):
        os.remove(clffile)

    try:
        clf = Regression(filename)
        ti.calc_ret_index()
        ret = ti.stock["ret_index"]
        # .iloc makes the "first value" lookup explicitly positional; plain
        # [0] on a non-integer (presumably date) index relies on pandas'
        # deprecated positional fallback.
        base = ti.stock_raw["Adj Close"].iloc[0]

        clf.train(ret, regression_type="Ridge")

        test_y = clf.predict(ret, base)

        expected = 19177.97
        r = round(test_y[0], 2)
        eq_(r, expected)
    finally:
        if os.path.exists(clffile):
            os.remove(clffile)
コード例 #3
0
def train():
    """
    Performs training and evaluation of Regression model.

    Hyper-parameters are read from the module-level ``FLAGS``. The pickled
    dataset ``<FLAGS.data_dir>/dataset.p`` is split 70/15/15 into
    train/validation/test loaders. The model is trained for
    ``FLAGS.nr_epochs`` epochs with MSE loss. After every epoch the
    validation loss is recorded and the model state is saved to
    ``Models/Regression_<variables>.pt``. Finally the test loss is printed
    and ``plotting`` is called with the loss histories.

    Raises:
        ValueError: if ``FLAGS.optimizer`` is not one of "SGD", "Adam",
            "AdamW" or "RMSprop".
    """
    print("Training started")
    # Set the random seeds for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Get number of units in each hidden layer
    if FLAGS.dnn_hidden_units:
        dnn_hidden_units = [
            int(units) for units in FLAGS.dnn_hidden_units.split(",")
        ]
    else:
        dnn_hidden_units = []

    # convert dropout percentages
    dropout_percentages = [
        int(perc) for perc in FLAGS.dropout_percentages.split(',')
    ]

    # Use exactly one dropout value per hidden layer: pad with 0 when too
    # few were given, truncate when too many.
    hidden_len = len(dnn_hidden_units)
    dropout_percentages = (dropout_percentages + [0] * hidden_len)[:hidden_len]

    # use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device :", device)

    # extract all data and divide into train, valid and test dataloaders
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # datasets from a trusted data_dir.
    with open(os.path.join(FLAGS.data_dir, "dataset.p"), "rb") as f:
        dataset = pkl.load(f)

    len_all = len(dataset)

    train_len, valid_len = int(0.7 * len_all), int(0.15 * len_all)
    test_len = len_all - train_len - valid_len
    splits = [train_len, valid_len, test_len]
    train_data, valid_data, test_data = random_split(dataset, splits)

    train_dl = DataLoader(train_data, batch_size=64, shuffle=True)
    valid_dl = DataLoader(valid_data,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True)
    test_dl = DataLoader(test_data,
                         batch_size=64,
                         shuffle=True,
                         drop_last=True)

    # initialize MLP and loss function
    # NOTE(review): 5387 is presumably the flattened features plus the
    # 11-way one-hot label appended in the training loop; confirm against
    # the dataset.
    nn = Regression(5387, dnn_hidden_units, dropout_percentages, 1,
                    FLAGS.neg_slope, FLAGS.batchnorm).to(device)
    loss_function = torch.nn.MSELoss()

    # initialize optimizer
    if FLAGS.optimizer == "SGD":
        optimizer = torch.optim.SGD(nn.parameters(),
                                    lr=FLAGS.learning_rate,
                                    weight_decay=FLAGS.weightdecay,
                                    momentum=FLAGS.momentum)
    elif FLAGS.optimizer == "Adam":
        optimizer = torch.optim.Adam(nn.parameters(),
                                     lr=FLAGS.learning_rate,
                                     amsgrad=FLAGS.amsgrad,
                                     weight_decay=FLAGS.weightdecay)
    elif FLAGS.optimizer == "AdamW":
        optimizer = torch.optim.AdamW(nn.parameters(),
                                      lr=FLAGS.learning_rate,
                                      amsgrad=FLAGS.amsgrad,
                                      weight_decay=FLAGS.weightdecay)
    elif FLAGS.optimizer == "RMSprop":
        optimizer = torch.optim.RMSprop(nn.parameters(),
                                        lr=FLAGS.learning_rate,
                                        weight_decay=FLAGS.weightdecay,
                                        momentum=FLAGS.momentum)
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError(f"Unknown optimizer: {FLAGS.optimizer}")

    # initialization for plotting and metrics
    training_losses = []
    valid_losses = []

    # construct name for saving models and figures
    variables_string = f"{FLAGS.optimizer}_{FLAGS.learning_rate}_{FLAGS.weightdecay}_{FLAGS.dnn_hidden_units}_{FLAGS.dropout_percentages}_{FLAGS.batchnorm}_{FLAGS.nr_epochs}"
    model_path = f"Models/Regression_{variables_string}.pt"

    # training loop
    for epoch in range(FLAGS.nr_epochs):

        print(f"\nEpoch: {epoch}")
        batch_losses = []
        nn.train()

        for x, y in train_dl:

            # Append the label to the input as an 11-way one-hot vector.
            # reshape(-1) instead of squeeze(): squeeze() turns a batch of
            # one (possible for the last batch, since train_dl does not set
            # drop_last) into a 0-d tensor, which breaks the concatenation.
            # NOTE(review): feeding the target y into the input leaks the
            # label to the model; confirm this is intentional.
            onehot_y = torch.nn.functional.one_hot(
                y.reshape(-1).to(torch.int64), num_classes=11)
            # Cast explicitly so the concatenation does not depend on
            # torch.cat's mixed-dtype promotion.
            x = torch.cat((x.reshape(x.shape[0], -1), onehot_y.to(x.dtype)),
                          1)

            # put on device; targets flattened per sample
            x = x.to(device)
            y = y.reshape(y.shape[0], -1).to(device)

            optimizer.zero_grad()

            # forward pass (nn and x already live on `device`)
            pred = nn(x)

            # compute loss and backpropagate
            loss = loss_function(pred, y)
            loss.backward()

            # update the weights
            optimizer.step()

            # save training loss
            batch_losses.append(loss.item())

        avg_epoch_loss = np.mean(batch_losses)
        training_losses.append(avg_epoch_loss)
        print(
            f"Average batch loss (epoch {epoch}): {avg_epoch_loss} ({len(batch_losses)} batches)."
        )

        # get loss on validation set and evaluate
        valid_losses.append(eval_on_test(nn, loss_function, valid_dl, device))
        torch.save(nn.state_dict(), model_path)

    # compute loss on the test set
    test_loss = eval_on_test(nn, loss_function, test_dl, device)
    print(f"Loss on test set: {test_loss}")

    plotting(training_losses, valid_losses, test_loss, variables_string)
コード例 #4
0
ファイル: analysis.py プロジェクト: arippbbc/finance
    def run(self):
        """Run the full analysis for ``self.code``.

        Loads prices in one of two ways:

        * from ``self.csvfile`` when set, appending newer rows fetched from
          the web when ``self.update`` is true;
        * otherwise from the web for ``self.start``..``self.end``.

        It then saves the prices, computes technical indicators, trains and
        applies the classifier and the Ridge regressor, saves the indicator
        table and draws the chart.

        Returns:
            The ``TechnicalIndicators`` instance on success; ``None`` when the
            loaded data is empty or a ``ValueError``/``KeyError`` is raised
            during the analysis.
        """
        io = FileIO()
        # Passed to the classifier/regressor train() calls below; cleared
        # when a CSV update finds no new rows.
        will_update = self.update

        if self.csvfile:
            stock_tse = io.read_from_csv(self.code, self.csvfile)

            msg = "".join([
                "Read data from csv: ", self.code, " Records: ",
                str(len(stock_tse))
            ])
            print(msg)

            if self.update and len(stock_tse) > 0:
                # Second entry of a 2-period business-day range starting at
                # the last stored date: normally the next business day.
                next_day = pd.date_range(start=stock_tse.index[-1],
                                         periods=2,
                                         freq='B')[1]
                t = next_day.strftime('%Y-%m-%d')
                newdata = io.read_data(self.code, start=t, end=self.end)

                msg = "".join([
                    "Read data from web: ", self.code, " New records: ",
                    str(len(newdata))
                ])
                print(msg)
                if len(newdata) < 1:
                    will_update = False
                else:
                    # .ix was removed in pandas 1.0; .iloc is positional.
                    print(newdata.iloc[-1, :])

                stock_tse = stock_tse.combine_first(newdata)
                io.save_data(stock_tse, self.code, 'stock_')
        else:
            stock_tse = io.read_data(self.code, start=self.start, end=self.end)

            msg = "".join([
                "Read data from web: ", self.code, " Records: ",
                str(len(stock_tse))
            ])
            print(msg)

        if stock_tse.empty:
            msg = "".join(["Data empty: ", self.code])
            print(msg)
            return None

        if not self.csvfile:
            io.save_data(stock_tse, self.code, 'stock_')

        try:
            stock_d = stock_tse.asfreq('B').dropna()[self.days:]

            ti = TechnicalIndicators(stock_d)

            # The calc_* calls presumably add columns to ti.stock (e.g.
            # ti.stock['ret_index'] is read below). Where a family of
            # indicators is plotted, only the last call's return value is
            # kept.
            ti.calc_sma()
            for timeperiod in (5, 25, 50, 75):
                ti.calc_sma(timeperiod=timeperiod)
            for span in (5, 25, 50):
                ti.calc_ewma(span=span)
            ewma = ti.calc_ewma(span=75)
            bbands = ti.calc_bbands()
            sar = ti.calc_sar()
            draw = Draw(self.code, self.fullname)

            ret = ti.calc_ret_index()
            ti.calc_vol(ret['ret_index'])
            ti.calc_rsi(timeperiod=9)
            rsi = ti.calc_rsi(timeperiod=14)
            mfi = ti.calc_mfi()
            for timeperiod in (10, 25, 50, 75):
                ti.calc_roc(timeperiod=timeperiod)
            roc = ti.calc_roc(timeperiod=150)
            ti.calc_cci()
            ultosc = ti.calc_ultosc()
            stoch = ti.calc_stoch()
            ti.calc_stochf()
            ti.calc_macd()
            willr = ti.calc_willr()
            ti.calc_momentum(timeperiod=10)
            ti.calc_momentum(timeperiod=25)
            tr = ti.calc_tr()
            ti.calc_atr()
            ti.calc_natr()
            vr = ti.calc_volume_rate()

            ret_index = ti.stock['ret_index']
            # Label of the last row; .ix (removed in pandas 1.0) was used for
            # these writes before.
            last_row = ti.stock.index[-1]

            clf = Classifier(self.clffile)
            _, train_y = clf.train(ret_index, will_update)
            msg = "".join(["Train Records: ", str(len(train_y))])
            print(msg)
            clf_result = clf.classify(ret_index)[0]
            msg = "".join(["Classified: ", str(clf_result)])
            print(msg)
            ti.stock.loc[last_row, 'classified'] = clf_result

            reg = Regression(self.regfile, alpha=1, regression_type="Ridge")
            _, train_y = reg.train(ret_index, will_update)
            msg = "".join(["Train Records: ", str(len(train_y))])
            print(msg)
            # Positional first value (plain [0] relies on deprecated fallback).
            base = ti.stock_raw['Adj Close'].iloc[0]
            reg_result = int(reg.predict(ret_index, base)[0])
            msg = "".join(["Predicted: ", str(reg_result)])
            print(msg)
            ti.stock.loc[last_row, 'predicted'] = reg_result

            if len(self.reference) > 0:
                ti.calc_rolling_corr(self.reference)
                ref = ti.stock['rolling_corr']
            else:
                ref = []

            io.save_data(io.merge_df(stock_d, ti.stock), self.code, 'ti_')

            draw.plot(stock_d,
                      ewma,
                      bbands,
                      sar,
                      rsi,
                      roc,
                      mfi,
                      ultosc,
                      willr,
                      stoch,
                      tr,
                      vr,
                      clf_result,
                      reg_result,
                      ref,
                      axis=self.axis,
                      complexity=self.complexity)

            return ti

        except (ValueError, KeyError) as e:
            msg = "".join(["Error occured in ", self.code])
            print(msg)
            # Report what went wrong instead of silently discarding it.
            print("".join(['ErrorType: ', str(type(e))]))
            print("".join(['ErrorMessage: ', str(e)]))
            return None
コード例 #5
0
    def run(self):
        """Run the full analysis for ``self.code``, logging progress via
        ``self.logger``.

        Loads prices in one of two ways:

        * from ``self.csvfile`` when set, appending newer rows fetched from
          the web when ``self.update`` is true;
        * otherwise from the web for ``self.start``..``self.end``.

        It then saves the prices, computes technical indicators, and trains
        and applies the classifier and the Ridge regressor. The indicator
        table is saved only when ``will_update is True``. Finally it draws
        the chart with a ``_prefix`` ('long' / 'short' / 'chart') chosen from
        ``self.minus_days``.

        Returns:
            The ``TechnicalIndicators`` instance on success; ``None`` when the
            loaded data is empty or a ``ValueError``/``KeyError`` is raised
            during the analysis.
        """
        io = FileIO()
        # Passed to the classifier/regressor train() calls below; cleared
        # when a CSV update finds no new rows.
        will_update = self.update

        self.logger.info("".join(["Start Analysis: ", self.code]))

        if self.csvfile:
            stock_tse = io.read_from_csv(self.code, self.csvfile)

            self.logger.info("".join([
                "Read data from csv: ", self.code, " Records: ",
                str(len(stock_tse))
            ]))

            if self.update and len(stock_tse) > 0:
                # Second entry of a 2-period business-day range starting at
                # the last stored date: normally the next business day.
                next_day = pd.date_range(start=stock_tse.index[-1],
                                         periods=2,
                                         freq='B')[1]
                t = next_day.strftime('%Y-%m-%d')
                newdata = io.read_data(self.code, start=t, end=self.end)

                self.logger.info("".join([
                    "Read data from web: ", self.code, " New records: ",
                    str(len(newdata))
                ]))

                if len(newdata) < 1:
                    will_update = False
                else:
                    # Log (rather than print) the newest row, like the rest
                    # of this method; .ix was removed in pandas 1.0.
                    self.logger.info(newdata.iloc[-1, :])

                stock_tse = stock_tse.combine_first(newdata)
                io.save_data(stock_tse, self.code, 'stock_')
        else:
            stock_tse = io.read_data(self.code, start=self.start, end=self.end)

            self.logger.info("".join([
                "Read data from web: ", self.code, " Records: ",
                str(len(stock_tse))
            ]))

        if stock_tse.empty:

            # NOTE(review): assumes self.logger is a logging.Logger, where
            # warn() is only a deprecated alias of warning().
            self.logger.warning("".join(["Data empty: ", self.code]))

            return None

        if not self.csvfile:
            io.save_data(stock_tse, self.code, 'stock_')

        try:
            stock_d = stock_tse.asfreq('B').dropna()[self.minus_days:]

            ti = TechnicalIndicators(stock_d)

            # The calc_* calls presumably add columns to ti.stock (e.g.
            # ti.stock['ret_index'] is read below). Where a family of
            # indicators is plotted, only the last call's return value is
            # kept.
            ti.calc_sma()
            for timeperiod in (5, 25, 50, 75, 200):
                ti.calc_sma(timeperiod=timeperiod)
            for span in (5, 25, 50, 75):
                ti.calc_ewma(span=span)
            ewma = ti.calc_ewma(span=200)
            bbands = ti.calc_bbands()
            sar = ti.calc_sar()
            draw = Draw(self.code, self.fullname)

            ret = ti.calc_ret_index()
            ti.calc_vol(ret['ret_index'])
            ti.calc_rsi(timeperiod=9)
            rsi = ti.calc_rsi(timeperiod=14)
            mfi = ti.calc_mfi()
            for timeperiod in (10, 25, 50, 75):
                ti.calc_roc(timeperiod=timeperiod)
            roc = ti.calc_roc(timeperiod=150)
            ti.calc_cci()
            ultosc = ti.calc_ultosc()
            stoch = ti.calc_stoch()
            ti.calc_stochf()
            ti.calc_macd()
            willr = ti.calc_willr()
            ti.calc_momentum(timeperiod=10)
            ti.calc_momentum(timeperiod=25)
            tr = ti.calc_tr()
            ti.calc_atr()
            ti.calc_natr()
            vr = ti.calc_volume_rate()

            ret_index = ti.stock['ret_index']
            # Label of the last row; .ix (removed in pandas 1.0) was used for
            # these writes before.
            last_row = ti.stock.index[-1]

            clf = Classifier(self.clffile)
            _, train_y = clf.train(ret_index, will_update)

            self.logger.info("".join(
                ["Classifier Train Records: ",
                 str(len(train_y))]))

            clf_result = clf.classify(ret_index)[0]

            self.logger.info("".join(["Classified: ", str(clf_result)]))

            ti.stock.loc[last_row, 'classified'] = clf_result

            reg = Regression(self.regfile, alpha=1, regression_type="Ridge")
            _, train_y = reg.train(ret_index, will_update)

            self.logger.info("".join(
                ["Regression Train Records: ",
                 str(len(train_y))]))

            # Positional first value (plain [0] relies on deprecated fallback).
            base = ti.stock_raw['Adj Close'].iloc[0]
            reg_result = int(reg.predict(ret_index, base)[0])

            self.logger.info("".join(["Predicted: ", str(reg_result)]))

            ti.stock.loc[last_row, 'predicted'] = reg_result

            if will_update is True:
                io.save_data(io.merge_df(stock_d, ti.stock), self.code, 'ti_')

            if self.minus_days < -300:
                _prefix = 'long'
            elif self.minus_days >= -60:
                _prefix = 'short'
            else:
                _prefix = 'chart'

            draw.plot(stock_d,
                      _prefix,
                      ewma,
                      bbands,
                      sar,
                      rsi,
                      roc,
                      mfi,
                      ultosc,
                      willr,
                      stoch,
                      tr,
                      vr,
                      clf_result,
                      reg_result,
                      axis=self.axis,
                      complexity=self.complexity)

            self.logger.info("".join(["Finish Analysis: ", self.code]))

            return ti

        except (ValueError, KeyError) as e:
            self.logger.error("".join(
                ["Error occured in ", self.code, " at analysis.py"]))
            self.logger.error("".join(['ErrorType: ', str(type(e))]))
            self.logger.error("".join(['ErrorMessage: ', str(e)]))
            return None
コード例 #6
0
ファイル: analysis.py プロジェクト: hnjun7802/finance
    def run(self):
        """Run the full analysis for ``self.code``.

        Loads prices in one of two ways:

        * from ``self.csvfile`` when set, appending newer rows fetched from
          the web when ``self.update`` is true;
        * otherwise from the web for ``self.start``..``self.end``.

        It then saves the prices, computes technical indicators, trains and
        applies the classifier and the Ridge regressor, saves the indicator
        table and draws the chart.

        Returns:
            The ``TechnicalIndicators`` instance on success; ``None`` when the
            loaded data is empty or a ``ValueError``/``KeyError`` is raised
            during the analysis.
        """
        io = FileIO()
        # Passed to the classifier/regressor train() calls below; cleared
        # when a CSV update finds no new rows.
        will_update = self.update

        if self.csvfile:
            stock_tse = io.read_from_csv(self.code, self.csvfile)

            msg = "".join(["Read data from csv: ", self.code, " Records: ", str(len(stock_tse))])
            print(msg)

            if self.update and len(stock_tse) > 0:
                # Second entry of a 2-period business-day range starting at
                # the last stored date: normally the next business day.
                next_day = pd.date_range(start=stock_tse.index[-1], periods=2, freq="B")[1]
                t = next_day.strftime("%Y-%m-%d")
                newdata = io.read_data(self.code, start=t, end=self.end)

                msg = "".join(["Read data from web: ", self.code, " New records: ", str(len(newdata))])
                print(msg)
                if len(newdata) < 1:
                    will_update = False
                else:
                    # .ix was removed in pandas 1.0; .iloc is positional.
                    print(newdata.iloc[-1, :])

                stock_tse = stock_tse.combine_first(newdata)
                io.save_data(stock_tse, self.code, "stock_")
        else:
            stock_tse = io.read_data(self.code, start=self.start, end=self.end)

            msg = "".join(["Read data from web: ", self.code, " Records: ", str(len(stock_tse))])
            print(msg)

        if stock_tse.empty:
            msg = "".join(["Data empty: ", self.code])
            print(msg)
            return None

        if not self.csvfile:
            io.save_data(stock_tse, self.code, "stock_")

        try:
            stock_d = stock_tse.asfreq("B").dropna()[self.days :]

            ti = TechnicalIndicators(stock_d)

            # The calc_* calls presumably add columns to ti.stock (e.g.
            # ti.stock["ret_index"] is read below). Where a family of
            # indicators is plotted, only the last call's return value is
            # kept.
            ti.calc_sma()
            for timeperiod in (5, 25, 50, 75):
                ti.calc_sma(timeperiod=timeperiod)
            for span in (5, 25, 50):
                ti.calc_ewma(span=span)
            ewma = ti.calc_ewma(span=75)
            bbands = ti.calc_bbands()
            sar = ti.calc_sar()
            draw = Draw(self.code, self.name)

            ret = ti.calc_ret_index()
            ti.calc_vol(ret["ret_index"])
            ti.calc_rsi(timeperiod=9)
            rsi = ti.calc_rsi(timeperiod=14)
            mfi = ti.calc_mfi()
            for timeperiod in (10, 25, 50, 75):
                ti.calc_roc(timeperiod=timeperiod)
            roc = ti.calc_roc(timeperiod=150)
            ti.calc_cci()
            ultosc = ti.calc_ultosc()
            stoch = ti.calc_stoch()
            ti.calc_stochf()
            ti.calc_macd()
            willr = ti.calc_willr()
            ti.calc_momentum(timeperiod=10)
            ti.calc_momentum(timeperiod=25)
            tr = ti.calc_tr()
            ti.calc_atr()
            ti.calc_natr()
            vr = ti.calc_volume_rate()

            ret_index = ti.stock["ret_index"]
            # Label of the last row; .ix (removed in pandas 1.0) was used for
            # these writes before.
            last_row = ti.stock.index[-1]

            clf = Classifier(self.clffile)
            _, train_y = clf.train(ret_index, will_update)
            msg = "".join(["Train Records: ", str(len(train_y))])
            print(msg)
            clf_result = clf.classify(ret_index)[0]
            msg = "".join(["Classified: ", str(clf_result)])
            print(msg)
            ti.stock.loc[last_row, "classified"] = clf_result

            reg = Regression(self.regfile, alpha=1, regression_type="Ridge")
            _, train_y = reg.train(ret_index, will_update)
            msg = "".join(["Train Records: ", str(len(train_y))])
            print(msg)
            # Positional first value (plain [0] relies on deprecated fallback).
            base = ti.stock_raw["Adj Close"].iloc[0]
            reg_result = int(reg.predict(ret_index, base)[0])
            msg = "".join(["Predicted: ", str(reg_result)])
            print(msg)
            ti.stock.loc[last_row, "predicted"] = reg_result

            if len(self.reference) > 0:
                ti.calc_rolling_corr(self.reference)
                ref = ti.stock["rolling_corr"]
            else:
                ref = []

            io.save_data(io.merge_df(stock_d, ti.stock), self.code, "ti_")

            draw.plot(
                stock_d,
                ewma,
                bbands,
                sar,
                rsi,
                roc,
                mfi,
                ultosc,
                willr,
                stoch,
                tr,
                vr,
                clf_result,
                reg_result,
                ref,
                axis=self.axis,
                complexity=self.complexity,
            )

            return ti

        except (ValueError, KeyError) as e:
            msg = "".join(["Error occured in ", self.code])
            print(msg)
            # Report what went wrong instead of silently discarding it.
            print("".join(["ErrorType: ", str(type(e))]))
            print("".join(["ErrorMessage: ", str(e)]))
            return None
コード例 #7
0
from regression import Regression, panda_to_numpy

# Fit a linear model predicting column 'X1' from columns 'X2' and 'X3' of
# ./hollywood.xls, print one prediction, and set up a 3-D plot of the fit.
# The axis labels below suggest X1 is revenue and X2 is cost of production.

data_path = "./hollywood.xls"
data = pd.read_excel(data_path)

x_train = data['X2']
# NOTE(review): y_train is not used in the lines shown here; presumably it is
# plotted further down.
y_train = data['X1']
z_train = data['X3']

# y must be an n x 1 array.
# x must be an n x m array, where m is the number of regression variables.
# The original note says training returns a 1 x (m+1) weight array, where the
# extra weight is the constant term (confirm against the Regression class).
x = panda_to_numpy(data, 'X2', 'X3')
y = data['X1']
model = Regression(y, x)
weights = model.train(epochs=200, a=0.0001, print_loss=True)
# presumably predicts for X2 = 8, X3 = 15 — TODO confirm argument order
prediction = model.predict(8, 15)
print(prediction)

# plot the results
ones = np.ones(len(x_train))
# Despite the names, these are not normalized: they are evenly spaced grids
# over each variable's observed range. Both grids advance together, so the
# predictions trace a line through (X2, X3) space rather than a surface.
x_normalized = np.linspace(x_train.min(), x_train.max(), len(x_train))
z_normalized = np.linspace(z_train.min(), z_train.max(), len(z_train))
# The trailing column of ones multiplies the constant weight.
x_pred = np.column_stack((x_normalized, z_normalized, ones))
y_pred = weights @ x_pred.T

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_title("Prediction of hollywood movie revenue")
ax.set_zlabel("Revenue")
ax.set_xlabel("Cost of production")
コード例 #8
0
import matplotlib.animation as animation
from regression import Regression, panda_to_numpy

# Fit a linear model of column 'Y' on column 'X' of ./biocarbonate.xls, print
# one prediction, and plot the fitted line over the data. The axis labels
# below name X as pH and Y as bicarbonates in ppm.

data_path = "./biocarbonate.xls"
data = pd.read_excel(data_path)

x_train = data['X']
y_train = data['Y']

# y must be an n x 1 array.
# x must be an n x m array, where m is the number of regression variables.
# The original note says training returns a 1 x (m+1) weight array, where the
# extra weight is the constant term (confirm against the Regression class).
x = panda_to_numpy(data, 'X')
y = data['Y']
model = Regression(y, x)
weights = model.train(epochs=200000, a=0.00005, print_loss=False)
prediction = model.predict(8)
print(prediction)

# plot the results: evaluate the fit on an evenly spaced grid spanning the
# observed X range. The column of ones pairs with the constant weight.
ones = np.ones(len(x_train))
x_line = np.linspace(x_train.min(), x_train.max(), len(x_train))
x_pred = np.column_stack((x_line, ones))
# ravel() yields a 1-D prediction vector whether weights is 1-D or 1 x (m+1).
y_pred = np.ravel(weights @ x_pred.T)

plt.title("Prediction of bicarbonate")
plt.xlabel("ph")
plt.ylabel("biocarbonates ppm")
# Plot against the grid itself, not the whole design matrix x_pred. With a
# 2-D x, matplotlib draws one line per column, which adds a spurious line at
# x == 1 from the column of ones (or raises if y_pred is 2-D).
plt.plot(x_train, y_train, "ro", x_line, y_pred, "g--")
plt.axis([x_train.min(), x_train.max(), y_train.min(), y_train.max()])
plt.show()
コード例 #9
0
ファイル: train_regression.py プロジェクト: Tom-Lotze/IR2_5
def train():
    """
    Performs training and evaluation of Regression model.

    Hyper-parameters are read from the module-level ``FLAGS``:

    * The pickled dataset selected by ``FLAGS.impression``,
      ``FLAGS.reduced_classes`` and ``FLAGS.embedder`` is split 70/15/15
      into train/validation/test loaders. The test loader is also pickled
      to ``<data_dir>/test_dl.pt``.
    * The model is trained with MSE loss. Every ``FLAGS.eval_freq`` batches
      it is evaluated on the validation split and saved to
      ``Models/Regression_<variables>.pt`` whenever the validation loss
      improves.
    * The best checkpoint is then reloaded and evaluated on the test split.
      Its predictions are pickled under ``Predictions/``, significance tests
      are run, and, if ``FLAGS.plotting`` is set, the losses are plotted.

    Raises:
        ValueError: if ``FLAGS.optimizer`` is not one of "SGD", "Adam",
            "AdamW" or "RMSprop".
    """
    # Set the random seeds for reproducibility
    np.random.seed(10)
    torch.manual_seed(10)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Get number of units in each hidden layer
    if FLAGS.dnn_hidden_units:
        dnn_hidden_units = [
            int(units) for units in FLAGS.dnn_hidden_units.split(",")
        ]
    else:
        dnn_hidden_units = []

    # convert dropout probabilities
    dropout_probs = [float(prob) for prob in FLAGS.dropout_probs.split(',')]

    # Use exactly one dropout value per hidden layer: pad with 0 when too
    # few were given, truncate when too many.
    hidden_len = len(dnn_hidden_units)
    dropout_probs = (dropout_probs + [0] * hidden_len)[:hidden_len]

    # use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device :", device)

    # extract all data and divide into train, valid and test dataloaders
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # datasets from a trusted data_dir.
    dataset_filename = f"dataset_filename=MIMICS-Click.tsv_expanded=False_balance=True_impression={FLAGS.impression}_reduced_classes={FLAGS.reduced_classes}_embedder={FLAGS.embedder}.p"
    with open(os.path.join(FLAGS.data_dir, dataset_filename), "rb") as f:
        dataset = pkl.load(f)

    len_all = len(dataset)

    train_len, valid_len = int(0.7 * len_all), int(0.15 * len_all)
    test_len = len_all - train_len - valid_len
    splits = [train_len, valid_len, test_len]
    train_data, valid_data, test_data = random_split(dataset, splits)

    train_dl = DataLoader(train_data,
                          batch_size=FLAGS.batch_size,
                          shuffle=True,
                          drop_last=True)
    valid_dl = DataLoader(valid_data,
                          batch_size=FLAGS.batch_size,
                          shuffle=True,
                          drop_last=True)
    test_dl = DataLoader(test_data,
                         batch_size=FLAGS.batch_size,
                         shuffle=True,
                         drop_last=True)

    with open(f"{FLAGS.data_dir}/test_dl.pt", "wb") as f:
        pkl.dump(test_dl, f)

    # initialize MLP and loss function
    # Python 3 iterators have no .next() method; use the next() builtin.
    input_size = next(iter(train_dl))[0].shape[1]  # 5376 for BERT embeddings
    nn = Regression(input_size, dnn_hidden_units, dropout_probs, 1,
                    FLAGS.neg_slope, FLAGS.batchnorm).to(device)
    loss_function = torch.nn.MSELoss()

    if FLAGS.verbose:
        print(f"neural net:\n {[param.data for param in nn.parameters()]}")

    # initialize optimizer
    if FLAGS.optimizer == "SGD":
        optimizer = torch.optim.SGD(nn.parameters(),
                                    lr=FLAGS.learning_rate,
                                    weight_decay=FLAGS.weightdecay,
                                    momentum=FLAGS.momentum)
    elif FLAGS.optimizer == "Adam":
        optimizer = torch.optim.Adam(nn.parameters(),
                                     lr=FLAGS.learning_rate,
                                     amsgrad=FLAGS.amsgrad,
                                     weight_decay=FLAGS.weightdecay)
    elif FLAGS.optimizer == "AdamW":
        optimizer = torch.optim.AdamW(nn.parameters(),
                                      lr=FLAGS.learning_rate,
                                      amsgrad=FLAGS.amsgrad,
                                      weight_decay=FLAGS.weightdecay)
    elif FLAGS.optimizer == "RMSprop":
        optimizer = torch.optim.RMSprop(nn.parameters(),
                                        lr=FLAGS.learning_rate,
                                        weight_decay=FLAGS.weightdecay,
                                        momentum=FLAGS.momentum)
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError(f"Unknown optimizer: {FLAGS.optimizer}")

    # initialization for plotting and metrics
    training_losses = []
    valid_losses = []

    initial_train_loss = eval_on_test(nn, loss_function, train_dl, device)
    training_losses.append(initial_train_loss)
    initial_valid_loss = eval_on_test(nn, loss_function, valid_dl, device)
    valid_losses.append(initial_valid_loss)

    # construct name for saving models and figures
    variables_string = f"regression_{FLAGS.embedder}_{FLAGS.impression}_{FLAGS.reduced_classes}_{FLAGS.optimizer}_{FLAGS.learning_rate}_{FLAGS.weightdecay}_{FLAGS.momentum}_{FLAGS.dnn_hidden_units}_{FLAGS.dropout_probs}_{FLAGS.batchnorm}_{FLAGS.nr_epochs}"
    model_path = f"Models/Regression_{variables_string}.pt"

    overall_batch = 0
    # inf (rather than a fixed 10000) so the first validation result is
    # always checkpointed, whatever the scale of the loss.
    min_valid_loss = float("inf")
    optimal_batch = None

    # training loop
    for epoch in range(FLAGS.nr_epochs):

        print(f"\nEpoch: {epoch}")

        for x, y in train_dl:
            # Set training mode every batch, because the periodic validation
            # below evaluates the same model (eval_on_test presumably
            # switches it to eval mode).
            nn.train()

            # put the batch on device
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            # forward pass (nn and x already live on `device`)
            pred = nn(x)

            # compute loss and backpropagate
            loss = loss_function(pred, y)
            loss.backward()

            # update the weights
            optimizer.step()

            # save training loss
            training_losses.append(loss.item())

            # get loss on validation set and evaluate
            if overall_batch % FLAGS.eval_freq == 0 and overall_batch != 0:
                valid_loss = eval_on_test(nn, loss_function, valid_dl, device)
                valid_losses.append(valid_loss)
                print(
                    f"Training loss: {loss.item()} / Valid loss: {valid_loss}")
                if valid_loss < min_valid_loss:
                    print(
                        f"Model is saved in epoch {epoch}, overall batch: {overall_batch}"
                    )
                    torch.save(nn.state_dict(), model_path)
                    min_valid_loss = valid_loss
                    optimal_batch = overall_batch

            overall_batch += 1

    if optimal_batch is None:
        # No checkpoint was written during training: either fewer than
        # eval_freq batches ran or no validation loss compared below inf
        # (e.g. NaN). Previously this was a NameError / missing file. Fall
        # back to the final weights so the test evaluation can still run.
        torch.save(nn.state_dict(), model_path)
        optimal_batch = max(overall_batch - 1, 0)

    # Load the optimal model (lowest validation loss) and evaluate it on the
    # test set
    optimal_nn = Regression(input_size, dnn_hidden_units, dropout_probs, 1,
                            FLAGS.neg_slope, FLAGS.batchnorm).to(device)
    optimal_nn.load_state_dict(torch.load(model_path))

    test_loss, test_pred, test_true = eval_on_test(optimal_nn,
                                                   loss_function,
                                                   test_dl,
                                                   device,
                                                   verbose=FLAGS.verbose,
                                                   return_preds=True)

    # save the test predictions of the regressor
    with open(
            f"Predictions/regression_test_preds{FLAGS.embedder}_{FLAGS.reduced_classes}_{FLAGS.impression}.pt",
            "wb") as f:
        pkl.dump(test_pred, f)

    print(
        f"Loss on test set of optimal model (batch {optimal_batch}): {test_loss}"
    )

    significance_testing(test_pred, test_true, loss_function, FLAGS)

    if FLAGS.plotting:
        plotting(training_losses, valid_losses, test_loss, variables_string,
                 optimal_batch, FLAGS)