# NOTE: the import below is an assumption; these helpers are exposed by the
# HugeCTR Python module in the versions that provide the Session API.
from hugectr import Session, solver_parser_helper, get_learning_rate_scheduler


def session_impl_test(json_file):
    # Single-node, 8-GPU training with mixed precision; the dataset is repeated,
    # so training runs for a fixed number of iterations.
    solver_config = solver_parser_helper(seed=0,
                                         batchsize=16384,
                                         batchsize_eval=16384,
                                         model_file="",
                                         embedding_files=[],
                                         vvgpu=[[0, 1, 2, 3, 4, 5, 6, 7]],
                                         use_mixed_precision=True,
                                         scaler=1024,
                                         i64_input_key=False,
                                         use_algorithm_search=True,
                                         use_cuda_graph=True,
                                         repeat_dataset=True)
    lr_sch = get_learning_rate_scheduler(json_file)
    sess = Session(solver_config, json_file)
    sess.start_data_reading()
    for i in range(10000):
        lr = lr_sch.get_next()
        sess.set_learning_rate(lr)
        sess.train()
        if i % 100 == 0:
            loss = sess.get_current_loss()
            print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))
        if i % 1000 == 0 and i != 0:
            metrics = sess.evaluation()
            print("[HUGECTR][INFO] iter: {}, {}".format(i, metrics))
def model_oversubscriber_test(json_file, temp_dir):
    # Train on five (file_list, keyset) pairs, swapping embedding-table slices in and
    # out through the model oversubscriber; file_list.5.txt is held out for evaluation.
    dataset = [("file_list." + str(i) + ".txt", "file_list." + str(i) + ".keyset")
               for i in range(5)]
    solver_config = solver_parser_helper(seed=0,
                                         batchsize=16384,
                                         batchsize_eval=16384,
                                         model_file="",
                                         embedding_files=[],
                                         vvgpu=[[0]],
                                         use_mixed_precision=False,
                                         scaler=1.0,
                                         i64_input_key=False,
                                         use_algorithm_search=True,
                                         use_cuda_graph=True,
                                         repeat_dataset=False)
    lr_sch = get_learning_rate_scheduler(json_file)
    sess = Session(solver_config, json_file, True, temp_dir)
    data_reader_train = sess.get_data_reader_train()
    data_reader_eval = sess.get_data_reader_eval()
    data_reader_eval.set_source("file_list.5.txt")
    model_oversubscriber = sess.get_model_oversubscriber()
    iteration = 0
    for file_list, keyset_file in dataset:
        data_reader_train.set_source(file_list)
        model_oversubscriber.update(keyset_file)
        while True:
            lr = lr_sch.get_next()
            sess.set_learning_rate(lr)
            good = sess.train()
            if not good:
                break
            if iteration % 100 == 0:
                sess.check_overflow()
                sess.copy_weights_for_evaluation()
                data_reader_eval = sess.get_data_reader_eval()
                good_eval = True
                j = 0
                while good_eval:
                    if j >= solver_config.max_eval_batches:
                        break
                    good_eval = sess.eval()
                    j += 1
                # The eval reader ran out of data; rewind it for the next evaluation.
                if not good_eval:
                    data_reader_eval.set_source()
                metrics = sess.get_eval_metrics()
                print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics))
            iteration += 1
        print("[HUGECTR][INFO] trained with data in {}".format(file_list))
    sess.download_params_to_files("./", iteration)
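# Optional sketch (not part of the original test): model_oversubscriber_test assumes
# that file_list.0..4.txt, the matching .keyset files, and the held-out
# file_list.5.txt already exist in the working directory. A pre-flight check like
# the hypothetical helper below can surface a missing file before training starts.
import os


def check_oversubscriber_inputs(num_rounds=5):
    """Return the expected data/keyset files that are missing from the working dir."""
    expected = ["file_list.{}.txt".format(num_rounds)]  # held-out eval file list
    for i in range(num_rounds):
        expected.append("file_list.{}.txt".format(i))
        expected.append("file_list.{}.keyset".format(i))
    return [f for f in expected if not os.path.exists(f)]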
# Multi-node variant of session_impl_test: vvgpu lists two nodes with eight GPUs
# each, so the test is expected to be launched under MPI (e.g. via mpirun).
# Renamed here to avoid shadowing the single-node session_impl_test above.
def session_impl_mpi_test(json_file):
    # Assumption: mpi4py supplies the process rank used to limit loss logging to rank 0.
    from mpi4py import MPI
    rank = MPI.COMM_WORLD.Get_rank()
    solver_config = solver_parser_helper(seed=0,
                                         batchsize=40960,
                                         batchsize_eval=40960,
                                         model_file="",
                                         embedding_files=[],
                                         vvgpu=[[0, 1, 2, 3, 4, 5, 6, 7],
                                                [0, 1, 2, 3, 4, 5, 6, 7]],
                                         use_mixed_precision=False,
                                         scaler=1.0,
                                         i64_input_key=False,
                                         use_algorithm_search=True,
                                         use_cuda_graph=True,
                                         repeat_dataset=True)
    sess = Session(solver_config, json_file)
    sess.start_data_reading()
    lr_sch = get_learning_rate_scheduler(json_file)
    for i in range(2000):
        lr = lr_sch.get_next()
        sess.set_learning_rate(lr)
        sess.train()
        if i % 200 == 0:
            loss = sess.get_current_loss()
            if rank == 0:
                print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))
        if i % 1000 == 0 and i != 0:
            sess.check_overflow()
            sess.copy_weights_for_evaluation()
            data_reader_eval = sess.get_data_reader_eval()
            for _ in range(solver_config.max_eval_batches):
                sess.eval()
            metrics = sess.get_eval_metrics()
            print("[HUGECTR][INFO] rank: {}, iter: {}, {}".format(rank, i, metrics))
def set_source_raw_test(json_file):
    # Raw-format dataset: the training file is read for two full passes by
    # resetting the training reader's source at the start of each round.
    train_data = "./train_data.bin"
    test_data = "./test_data.bin"
    solver_config = solver_parser_helper(seed=0,
                                         batchsize=16384,
                                         batchsize_eval=16384,
                                         max_eval_batches=5441,
                                         model_file="",
                                         embedding_files=[],
                                         vvgpu=[[0, 1, 2, 3, 4, 5, 6, 7]],
                                         use_mixed_precision=True,
                                         scaler=1024,
                                         i64_input_key=False,
                                         use_algorithm_search=True,
                                         use_cuda_graph=True,
                                         repeat_dataset=False)
    lr_sch = get_learning_rate_scheduler(json_file)
    sess = Session(solver_config, json_file)
    data_reader_train = sess.get_data_reader_train()
    data_reader_eval = sess.get_data_reader_eval()
    data_reader_eval.set_source(test_data)
    iteration = 1
    for cnt in range(2):
        data_reader_train.set_source(train_data)
        print("[HUGECTR][INFO] round: {}".format(cnt), flush=True)
        while True:
            lr = lr_sch.get_next()
            sess.set_learning_rate(lr)
            good = sess.train()
            if not good:
                break
            if iteration % 4000 == 0:
                sess.check_overflow()
                sess.copy_weights_for_evaluation()
                data_reader_eval = sess.get_data_reader_eval()
                good_eval = True
                j = 0
                while good_eval:
                    if j >= solver_config.max_eval_batches:
                        break
                    good_eval = sess.eval()
                    j += 1
                # The eval reader hit end-of-file; rewind it before the next evaluation.
                if not good_eval:
                    data_reader_eval.set_source()
                metrics = sess.get_eval_metrics()
                print("[HUGECTR][INFO] iter: {}, metrics: {}".format(iteration, metrics),
                      flush=True)
            iteration += 1
        print("[HUGECTR][INFO] trained with data in {}".format(train_data), flush=True)
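# Minimal driver sketch, not taken from the original tests: each test expects the
# path to a HugeCTR network-configuration JSON file. The config path and the
# temporary directory below are hypothetical placeholders; uncomment the test that
# matches your setup (the MPI variant must be launched with mpirun).
if __name__ == "__main__":
    import sys

    json_file = sys.argv[1] if len(sys.argv) > 1 else "./network.json"  # hypothetical path
    session_impl_test(json_file)
    # model_oversubscriber_test(json_file, "./temp_embedding")  # hypothetical temp dir
    # set_source_raw_test(json_file)
    # session_impl_mpi_test(json_file)  # run under mpirun with two nodes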