def train(self, output_dir='/opt/ml/model', hidden_dim=128, max_step=320000):
    """Train knowledge-graph embeddings with dglke and upload the results.

    Calls ``dglke_train.main`` with the dataset and data files configured on
    this instance, saving into ``self.train_output_key``. Afterwards every
    ``*.npy`` file found under ``self.train_output_key`` is uploaded through
    ``s3client``.

    Args:
        output_dir: Not used by this implementation; kept so existing
            callers keep working. Output goes to ``self.train_output_key``.
        hidden_dim: Requested embedding size. Half of it is passed as
            ``--hidden_dim`` (``--double_ent`` is passed alongside).
        max_step: Passed as ``--max_step``. One hundredth of it (but never
            less than 1) is passed as ``--log_interval``.
    """
    self.check_parent_dir('.', self.train_output_key)

    # Integer division yields 0 when max_step < 100; clamp to 1 so the
    # trainer never receives a zero logging interval.
    log_interval = max(1, max_step // 100)

    # NOTE(review): no '--model_name' is passed (a 'RotatE' entry was left
    # commented out), so the trainer's own default applies -- confirm this
    # is intended.
    dglke_train.main([
        '--dataset', self.kg_folder,
        '--gamma', '19.9',
        '--lr', '0.25',
        '--max_step', str(max_step),
        '--log_interval', str(log_interval),
        '--batch_size_eval', '1000',
        # The RotatE model is given half of hidden_dim.
        '--hidden_dim', str(hidden_dim // 2),
        '-adv',
        '--regularization_coef', '1.00E-09',
        '--gpu', '0',
        '--double_ent',
        '--mix_cpu_gpu',
        '--save_path', self.train_output_key,
        '--data_path', self.kg_folder,
        '--format', 'udd_hrt',
        '--data_files',
        self.kg_entity_key, self.kg_relation_key, self.kg_dbpedia_key,
        '--neg_sample_size_eval', '10000',
    ])
    print("finish training!!")

    if self.train_output_key is not None:
        print("upload to {}".format(self.train_output_key))
        # The first path component of train_output_key is used as the bucket.
        # NOTE(review): the object key is only the file's base name, so any
        # prefix after the bucket in train_output_key is dropped -- confirm.
        bucket = self.train_output_key.split('/')[0]
        for name in glob.glob(os.path.join(self.train_output_key, '*.npy')):
            print("upload {}".format(name))
            s3client.upload_file(name, bucket, os.path.basename(name))
def train(self, output_dir='/opt/ml/model', hidden_dim=128, max_step=320000):
    """Train knowledge-graph embeddings with dglke.

    Calls ``dglke_train.main`` on the ``kg`` dataset, reading the dbpedia
    entity/relation/triple files from ``self.kg_folder`` and saving the
    result to ``output_dir``.

    Args:
        output_dir: Directory passed as ``--save_path``.
        hidden_dim: Requested embedding size. Half of it is passed as
            ``--hidden_dim`` (``--double_ent`` is passed alongside).
        max_step: Passed as ``--max_step``. One hundredth of it (but never
            less than 1) is passed as ``--log_interval``.
    """
    # Integer division yields 0 when max_step < 100; clamp to 1 so the
    # trainer never receives a zero logging interval.
    log_interval = max(1, max_step // 100)

    # NOTE(review): no '--model_name' is passed (a 'RotatE' entry was left
    # commented out), so the trainer's own default applies -- confirm this
    # is intended.
    dglke_train.main([
        '--dataset', 'kg',
        '--gamma', '19.9',
        '--lr', '0.25',
        '--max_step', str(max_step),
        '--log_interval', str(log_interval),
        '--batch_size_eval', '1000',
        # The RotatE model is given half of hidden_dim.
        '--hidden_dim', str(hidden_dim // 2),
        '-adv',
        '--regularization_coef', '1.00E-09',
        '--gpu', '0',
        '--double_ent',
        '--mix_cpu_gpu',
        '--save_path', output_dir,
        '--data_path', self.kg_folder,
        '--format', 'udd_hrt',
        '--data_files',
        'entities_dbpedia.dict', 'relations_dbpedia.dict', 'kg_dbpedia.txt',
        '--neg_sample_size_eval', '10000',
    ])
"""Command-line launcher that hands control to ``dglke.train.main``."""
import re
import sys

from dglke.train import main

if __name__ == '__main__':
    # Drop a trailing "-script.pyw" or ".exe" from the program name before
    # main() runs, so sys.argv[0] carries the bare command name.
    launcher_suffix = r'(-script\.pyw|\.exe)?$'
    program_name = re.sub(launcher_suffix, '', sys.argv[0])
    sys.argv[0] = program_name
    # Use main()'s return value as the process exit status.
    sys.exit(main())