def main(_):
    """Prepare the data and sample directories, build the network, then train or generate.

    The single positional argument is accepted but never read.
    """
    # Configuration keys handed to util.get_model_dir together with conf.
    # NOTE(review): presumably these are the keys left out of the model
    # directory name -- confirm against util.get_model_dir.
    model_dir_keys = [
        'data_dir', 'sample_dir', 'max_epoch', 'test_step', 'save_step',
        'is_train', 'random_seed', 'log_level', 'display',
        'runtime_base_dir', 'occlude_start_row', 'num_generated_images',
    ]
    model_dir = util.get_model_dir(conf, model_dir_keys)

    util.preprocess_conf(conf)
    validate_parameters(conf)

    # 'color-mnist' reads from the 'mnist' data directory; the sample
    # directory below still uses the unmodified conf.data value.
    if conf.data == 'color-mnist':
        data_name = 'mnist'
    else:
        data_name = conf.data

    data_root = os.path.join(conf.runtime_base_dir, conf.data_dir, data_name)
    sample_root = os.path.join(
        conf.runtime_base_dir, conf.sample_dir, conf.data, model_dir)
    for directory in (data_root, sample_root):
        util.check_and_create_dir(directory)

    dataset = get_dataset(data_root, conf.q_levels)

    with tf.Session() as sess:
        network = Network(
            sess, conf, dataset.height, dataset.width, dataset.channels)
        stat = Statistic(
            sess, conf.data, conf.runtime_base_dir, model_dir,
            tf.trainable_variables())
        stat.load_model()

        if not conf.is_train:
            generate(network, dataset.height, dataset.width, sample_root)
        else:
            train(dataset, network, stat, sample_root)
def check_n_create_output_dir(output_dir):
    """
    Ensure the "source" and "target" sub-directories of output_dir exist.

    Both paths are passed to check_and_create_dir, source first.

    Returns:
        A tuple (source_out_dir, target_out_dir) with the two paths.
    """
    source_out_dir, target_out_dir = (
        os.path.join(output_dir, leaf) for leaf in ("source", "target")
    )
    for out_dir in (source_out_dir, target_out_dir):
        check_and_create_dir(out_dir)
    return (source_out_dir, target_out_dir)
def train(self, x_path_dir, y_path_dir, epochs, train_steps, learning_rate, epochs_to_reduce_lr, reduce_lr, output_model, output_log, b_size):
    """
    Run the training loop, writing summaries and saving a checkpoint per epoch.

    Args:
        x_path_dir: directory handed to extract_image_path for the inputs.
        y_path_dir: directory handed to extract_image_path for the targets.
        epochs: number of epochs to run.
        train_steps: number of optimizer steps per epoch.
        learning_rate: initial value fed to self.learning_rate.
        epochs_to_reduce_lr: the rate is decayed every this many epochs.
        reduce_lr: decay fraction; the rate is multiplied by (1 - reduce_lr).
        output_model: checkpoint path prefix; a type suffix is appended.
        output_log: log directory given to the summary FileWriter.
        b_size: batch size used for get_batch and fed to self.batch_size.

    On KeyboardInterrupt the current weights are saved before returning.
    The summary writer is always closed so buffered events reach disk.
    """
    # Suffix for clarification on type.
    # NOTE(review): the condition tests the path string itself, so the
    # "MultiCNN" branch is only taken for an empty output_model -- this
    # probably was meant to test a model-type flag; confirm before changing.
    if output_model:
        output_model += "AE"
    else:
        output_model += "MultiCNN"
    check_and_create_dir(output_model)

    # Load data
    x_filenames = extract_image_path([x_path_dir])
    y_filenames = extract_image_path([y_path_dir])

    # Summaries
    tf.summary.scalar('Learning rate', self.learning_rate)
    tf.summary.scalar('MSE', self.mse)
    tf.summary.scalar('MS SSIM', self.ssim)
    tf.summary.scalar('Loss', self.cost)
    tf.summary.image('BSE', self.Y)
    tf.summary.image('Ground truth', self.Y_clear)
    merged = tf.summary.merge_all()

    sess, saver = self.init_session()
    writer = tf.summary.FileWriter(output_log, sess.graph)
    l_rate = learning_rate

    try:
        for epoch_i in range(epochs):
            if ((epoch_i + 1) % epochs_to_reduce_lr) == 0:
                l_rate = l_rate * (1 - reduce_lr)
            if self.verbose:
                print("\n------------ Epoch : ",epoch_i+1)
                print("Current learning rate {}".format(l_rate))

            # Training steps
            for i in range(train_steps):
                if self.verbose:
                    print_train_steps(i+1, train_steps)
                x_batch, y_batch = get_batch(b_size, self.image_size, x_filenames, y_filenames)
                # The same feed is used for the optimizer step and the summary.
                feed = {
                    self.X: x_batch,
                    self.Y_clear: y_batch,
                    self.learning_rate: l_rate,
                    self.batch_size: b_size,
                }
                sess.run(self.optimizer, feed_dict=feed)
                # Record summaries every 50 steps, indexed by the global step.
                if i % 50 == 0:
                    summary = sess.run(merged, feed)
                    writer.add_summary(summary, i + epoch_i*train_steps)

            if self.verbose:
                print("\nSave model to {}".format(output_model))
            saver.save(sess, output_model, global_step=(epoch_i+1)*train_steps)
    except KeyboardInterrupt:
        # Keep whatever progress was made when the user aborts.
        saver.save(sess, output_model)
    finally:
        # Flush pending summary events; previously the writer was never closed.
        writer.close()
def main():
    """Command-line entry point for the PKI tool.

    With -g/--generate-ca or -d/--default-ca a self-signed root CA
    certificate and key are written to CA_CERT_FULLPATH / CA_KEY_FULLPATH,
    unless both files already exist. With -c/--create or -r/--default-req a
    client key pair is generated and a certificate signed by the root CA is
    written to 'key/client.crt' and 'key/client.key'.

    NOTE(review): -l/--list is parsed but no branch below handles it.
    """
    parser = argparse.ArgumentParser(description="Command line tool to manage a public key infrastructure.")
    parser.add_argument('-c', '--create', help="Creates client certificate and keys.", action='store_true', default=False)
    parser.add_argument('-d', '--default-ca', help="Creates the root CA with default parameters.", action='store_true', default=False)
    parser.add_argument('-g', '--generate-ca', help="Generates the root CA and ask for values.", action='store_true', default=False)
    parser.add_argument('-l', '--list', help="List all certificates and keys that are present in the pki.", action='store_true', default=False)
    parser.add_argument('-r', '--default-req', help="Creates client certificate and keys with default attributes.", action='store_true', default=False)
    args = parser.parse_args()

    check_and_create_dir(PKI_DIR)

    if args.generate_ca or args.default_ca:
        # Never overwrite an existing CA.
        if exists_and_isfile(CA_CERT_FULLPATH) and exists_and_isfile(CA_KEY_FULLPATH):
            print('There is an existing CA')
            return

        if args.default_ca:
            default_attrs = ['PL', 'Poneyland', 'kichland', 'Poney Corp', 'ROOT-CA', 'ROOT-CA Poney CORP']
        else:
            default_attrs = get_attr()

        k = crypto.PKey()
        k.generate_key(crypto.TYPE_RSA, 2048)

        # The root certificate is signed with its own key (self-signed).
        my_cert = cacert_req(default_attrs)
        my_cert.set_pubkey(k)
        my_cert.sign(k, 'sha256')

        with open(CA_CERT_FULLPATH, 'wb') as fd:
            fd.write(crypto.dump_certificate(crypto.FILETYPE_PEM, my_cert))
        with open(CA_KEY_FULLPATH, 'wb') as fd:
            fd.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))

    elif args.create or args.default_req:
        if args.default_req:
            cli_attrs = ['PL', 'Poneyland', 'kichland', 'Poney Corp', 'info', 'poney_test']
        else:
            cli_attrs = get_attr()

        # Load the root CA private key and certificate; the context managers
        # close the files, which the previous bare open(...).read() did not.
        with open(CA_KEY_FULLPATH, 'rb') as fd:
            ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, fd.read())
        with open(CA_CERT_FULLPATH, 'rb') as fd:
            ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, fd.read())

        # Create the client key pair and the signing request.
        cli_pkey = crypto.PKey()
        cli_pkey.generate_key(crypto.TYPE_RSA, 2048)

        cli_csr = csr_req(cli_attrs)
        cli_csr.set_pubkey(cli_pkey)
        cli_csr.sign(cli_pkey, 'sha256')

        # Build the client certificate from the request and sign it with the CA key.
        cli_cert = crypto.X509()
        cli_cert.set_issuer(ca_cert.get_subject())
        cli_cert.gmtime_adj_notBefore(0)
        cli_cert.gmtime_adj_notAfter(315360000)  # 3650 days, in seconds
        cli_cert.set_serial_number(cli_csr.get_serial_number())
        cli_cert.set_subject(cli_csr.get_subject())
        cli_cert.set_pubkey(cli_csr.get_pubkey())
        cli_cert.sign(ca_key, 'sha256')

        # NOTE(review): output paths are hard-coded relative paths rather than
        # derived from PKI_DIR -- confirm they point at the same directory.
        with open('key/client.crt', 'wb') as fd:
            fd.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cli_cert))
        with open('key/client.key', 'wb') as fd:
            fd.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, cli_pkey))
freeze_support()

# Configuration file: the first command-line argument when one is given,
# otherwise "config.json" in the current working directory.
all_config_filename = sys.argv[1] if len(sys.argv) > 1 else "config.json"
try:
    with open(all_config_filename, "r", encoding="UTF-8") as f:
        all_config = json.load(f)
except Exception as e:
    print("解析配置文件时出现错误,请检查配置文件!")
    print("错误详情:" + str(e))
    os.system('pause')
    # all_config is unbound at this point; continuing would only fail with a
    # NameError on the next statement, so stop with a non-zero exit status.
    sys.exit(1)

log_path = all_config['root']['logger']['log_path']
utils.check_and_create_dir(log_path)

# One log file per run, named after the start time.
logging.basicConfig(
    level=utils.get_log_level(all_config),
    format='%(asctime)s %(thread)d %(threadName)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
    datefmt='%a, %d %b %Y %H:%M:%S',
    handlers=[
        logging.FileHandler(
            os.path.join(
                log_path,
                "Main_" + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.log'),
            "a",
            encoding="utf-8")
    ])

runner_list = []
utils.add_path("./ffmpeg/bin") try: if len(sys.argv) > 1: all_config_filename = sys.argv[1] with open(all_config_filename, "r", encoding="UTF-8") as f: all_config = json.load(f) else: with open("config.json", "r", encoding="UTF-8") as f: all_config = json.load(f) except Exception as e: print("解析配置文件时出现错误,请检查配置文件!") print("错误详情:" + str(e)) os.system('pause') utils.check_and_create_dir( all_config.get('root', {}).get('data_path', "./")) utils.check_and_create_dir( all_config.get('root', {}).get('logger', {}).get('log_path', './log')) logfile_name = "Main_" + datetime.datetime.now().strftime( '%Y-%m-%d_%H-%M-%S') + '.log' logging.basicConfig( level=utils.get_log_level(all_config), format= '%(asctime)s %(thread)d %(threadName)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', datefmt='%a, %d %b %Y %H:%M:%S', handlers=[ logging.FileHandler(os.path.join( all_config.get('root', {}).get('logger', {}).get('log_path', "./log"), logfile_name), "a",
u = Uploader(p.outputs_dir, p.splits_dir, config) u.upload(p.global_start) if __name__ == "__main__": root_config_filename = sys.argv[1] spec_config_filename = sys.argv[2] with open(root_config_filename, "r") as f: root_config = json.load(f) with open(spec_config_filename, "r") as f: spec_config = json.load(f) config = { 'root': root_config, 'spec': spec_config } utils.check_and_create_dir(config['root']['global_path']['data_path']) utils.check_and_create_dir(config['root']['logger']['log_path']) logging.basicConfig(level=utils.get_log_level(config), format='%(asctime)s %(thread)d %(threadName)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', datefmt='%a, %d %b %Y %H:%M:%S', filename=os.path.join(config['root']['logger']['log_path'], datetime.datetime.now( ).strftime('%Y-%m-%d_%H-%M-%S')+'.log'), filemode='a') utils.init_data_dirs(config['root']['global_path']['data_path']) bl = BiliLive(config) prev_live_status = False while True: if not prev_live_status and bl.live_status: print("开播啦~") prev_live_status = bl.live_status start = datetime.datetime.now()