def DBN_JIT(train_features, train_labels, test_features, test_labels,
            hidden_units=None, num_epochs_LR=200):
    """Train a DBN feature extractor plus an LR classifier and predict on the test set.

    Steps performed:
      1. Build a ``DBN`` (CPU only) and fit it with ``train_static`` for 10 epochs.
      2. Run both feature sets through ``dbn_model.forward`` and append the first
         returned value, column-wise, to the original features.
      3. Train an ``LR`` model on the augmented training features using Adam
         (lr=1e-5) and binary cross-entropy for ``num_epochs_LR`` epochs.
      4. Call ``lr_model.predict`` on the test batches and return its first result.

    Training and evaluation durations are printed to stdout.

    Args:
        train_features: 2-D training feature matrix; must expose ``.shape`` and
            be accepted by both ``np.hstack`` and ``DBN.forward``
            (NOTE(review): presumably a NumPy array -- confirm against caller).
        train_labels: Training labels. A 1-D shape gives a single output unit;
            otherwise ``train_labels.shape[1]`` output units are used.
        test_features: Test feature matrix, same layout as ``train_features``.
        test_labels: Test labels; only used to build the test mini-batches.
        hidden_units: Hidden layer sizes passed to ``DBN``. Defaults to
            ``[20, 12, 12]`` when omitted or ``None``.
        num_epochs_LR: Number of training epochs for the LR model.

    Returns:
        The first element returned by ``lr_model.predict`` for the test batches.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if hidden_units is None:
        hidden_units = [20, 12, 12]

    # perf_counter is monotonic, so the interval cannot be skewed by clock changes.
    train_start = time.perf_counter()

    # ---- Train the DBN model ------------------------------------------------
    dbn_model = DBN(visible_units=train_features.shape[1],
                    hidden_units=hidden_units,
                    use_gpu=False)
    dbn_model.train_static(train_features, train_labels, num_epochs=10)

    # ---- Use the trained DBN to construct additional features ---------------
    dbn_train_features, _ = dbn_model.forward(train_features)
    dbn_test_features, _ = dbn_model.forward(test_features)
    train_features = np.hstack((train_features, dbn_train_features.numpy()))
    test_features = np.hstack((test_features, dbn_test_features.numpy()))

    num_classes = 1 if len(train_labels.shape) == 1 else train_labels.shape[1]

    # ---- Train the LR model on original + DBN features ----------------------
    lr_model = LR(input_size=train_features.shape[1], num_classes=num_classes)
    optimizer = torch.optim.Adam(lr_model.parameters(), lr=0.00001)
    # The loss module is stateless; build it once instead of once per batch.
    criterion = nn.BCELoss()

    # Test batches are built before training, in the same position as before.
    batches_test = mini_batches(X=test_features, Y=test_labels)

    for _ in range(num_epochs_LR):
        # Training batches are rebuilt at the start of every epoch.
        for x_batch, y_batch in mini_batches_update(X=train_features, Y=train_labels):
            x_batch = torch.tensor(x_batch).float()
            y_batch = torch.tensor(y_batch).float()

            optimizer.zero_grad()
            predict = lr_model.forward(x_batch)
            loss = criterion(predict, y_batch)
            loss.backward()
            optimizer.step()

    train_seconds = time.perf_counter() - train_start
    # Fixed-point with six decimals shows microsecond precision; truncating
    # str(float) would silently drop the exponent of very small durations.
    print("Train Time: %.6f s" % train_seconds)

    # ---- Evaluate -----------------------------------------------------------
    eval_start = time.perf_counter()
    # predict() returns a pair; only the first element is returned to the caller.
    y_pred, _ = lr_model.predict(data=batches_test)
    eval_seconds = time.perf_counter() - eval_start
    print("Eval Time: %.6f s" % eval_seconds)

    return y_pred