# Finish the current figure: y-axis label, title and legend, then display it.
plt.ylabel('value')
plt.title('Triple scores - +ve and -ve')
plt.legend()
plt.show()
# Draw the recorded `loss` values as a red line, then title and label the
# figure before displaying it.
plt.plot(loss, color='red')
plt.title('Margin Ranking Loss')
# NOTE(review): `loss` is plotted against its index, yet the x-axis is
# labelled 'loss' and the y-axis 'value' -- confirm the labels are intended.
plt.xlabel('loss')
plt.ylabel('value')
plt.show()

# Create an Optimizer that trains `mod` on `train_data_rdd` with a
# MarginRankingCriterion, stopping at MaxEpoch(100), with batch size 10.
# NOTE(review): optim_method=None presumably falls back to the library's
# default optimization method -- confirm that this is intended.
optimizer = Optimizer(model=mod,
                      training_rdd=train_data_rdd,
                      criterion=MarginRankingCriterion(),
                      optim_method=None,
                      end_trigger=MaxEpoch(100),
                      batch_size=10)

# Attach training and validation summaries to the optimizer.
# The application name carries a timestamp suffix, so runs started at
# different seconds get distinct names under the 'bigdl_summaries' log dir.
# It must be bound as `app_name`, because that is the name both summary
# constructors below read.
app_name = 'DISTMULT' + dt.datetime.now().strftime("%Y%m%d-%H%M%S")
# Keep the older spelling bound as well, in case other code refers to it.
appname = app_name
train_summary = TrainSummary(log_dir='bigdl_summaries', app_name=app_name)

# Trigger the "Parameters" summary with SeveralIteration(50).
train_summary.set_summary_trigger("Parameters", SeveralIteration(50))
val_summary = ValidationSummary(log_dir='bigdl_summaries', app_name=app_name)

optimizer.set_train_summary(train_summary)
optimizer.set_val_summary(val_summary)

# Validation
# Evaluate `mod` with the Loss validation method and print the first result.
# NOTE(review): this evaluates on `train_data_rdd`, the same RDD passed to the
# Optimizer as training data -- confirm that a held-out set is not intended.
# NOTE(review): Loss() is built without an explicit criterion; verify that its
# default matches the MarginRankingCriterion used for training.
# NOTE(review): no optimizer.optimize() call appears before this point --
# confirm the model is trained elsewhere before it is evaluated.
batch_size = 10
validation_results = mod.evaluate(train_data_rdd, batch_size, [Loss()])
print(validation_results[0])
# Example #2
# 0
                               input_shape=input_shape)(both_input)

# Take two slices of the shared feature tensor with index_select(1, 0) and
# index_select(1, 1).
# NOTE(review): presumably these are the encodings of the left and right
# inputs of each pair, selected along dimension 1 -- confirm.
encode_left = both_feature.index_select(1, 0)
encode_right = both_feature.index_select(1, 1)

# Absolute difference between the two encodings, fed to the classifier head.
distance = autograd.abs(encode_left - encode_right)
# Dense head with NUM_CLASS_LABEL outputs, sigmoid activation, and an L2
# weight regularizer configured from args.penalty_rate.
predict = Dense(output_dim=NUM_CLASS_LABEL,
                activation="sigmoid",
                W_regularizer=L2Regularizer(args.penalty_rate))(distance)

# The full model maps the paired input to the prediction.
siamese_net = Model(input=both_input, output=predict)

# Declare the optimizer, then train and test the model.
# Training uses Adam at args.learning_rate with CrossEntropyCriterion and
# stops at MaxEpoch(args.num_epoch).
# NOTE(review): the model's output layer uses a sigmoid activation; confirm
# that CrossEntropyCriterion is meant to receive sigmoid outputs.
optimizer = Optimizer(model=siamese_net,
                      training_rdd=train_rdd,
                      optim_method=Adam(args.learning_rate),
                      criterion=CrossEntropyCriterion(),
                      end_trigger=MaxEpoch(args.num_epoch),
                      batch_size=args.batch_size)
# Validate on `test_rdd` with Top1Accuracy, triggered by EveryEpoch().
optimizer.set_validation(batch_size=args.batch_size,
                         val_rdd=test_rdd,
                         trigger=EveryEpoch(),
                         val_method=[Top1Accuracy()])

# Set up training logs; they can be inspected with TensorBoard.
# Both summaries are written under log_dir "." with the app name "logs".
app_name = "logs"
optimizer.set_train_summary(TrainSummary(log_dir=".", app_name=app_name))
optimizer.set_val_summary(ValidationSummary(log_dir=".", app_name=app_name))

# Run the optimization with the configuration above.
optimizer.optimize()