Exemplo n.º 1
0
def main(_):
  """Entry point: resolve directories, load the dataset, then train or sample."""
  excluded_keys = ['data_dir', 'sample_dir', 'max_epoch', 'test_step', 'save_step',
                   'is_train', 'random_seed', 'log_level', 'display',
                   'runtime_base_dir', 'occlude_start_row', 'num_generated_images']
  model_dir = util.get_model_dir(conf, excluded_keys)
  util.preprocess_conf(conf)
  validate_parameters(conf)

  # 'color-mnist' shares the plain mnist data directory.
  dataset_name = 'mnist' if conf.data == 'color-mnist' else conf.data
  data_path = os.path.join(conf.runtime_base_dir, conf.data_dir, dataset_name)
  sample_path = os.path.join(conf.runtime_base_dir, conf.sample_dir, conf.data, model_dir)

  for path in (data_path, sample_path):
    util.check_and_create_dir(path)

  dataset = get_dataset(data_path, conf.q_levels)

  with tf.Session() as sess:
    network = Network(sess, conf, dataset.height, dataset.width, dataset.channels)

    stat = Statistic(sess, conf.data, conf.runtime_base_dir, model_dir,
                     tf.trainable_variables())
    stat.load_model()

    if conf.is_train:
      train(dataset, network, stat, sample_path)
    else:
      generate(network, dataset.height, dataset.width, sample_path)
Exemplo n.º 2
0
def check_n_create_output_dir(output_dir):
    """
    Ensure the "source" and "target" subdirectories of *output_dir* exist.

    Returns the pair (source_dir, target_dir).
    """
    source_dir, target_dir = (os.path.join(output_dir, name)
                              for name in ("source", "target"))
    for path in (source_dir, target_dir):
        check_and_create_dir(path)
    return (source_dir, target_dir)
Exemplo n.º 3
0
    def train(self, x_path_dir, y_path_dir, epochs, train_steps, learning_rate, epochs_to_reduce_lr, reduce_lr, output_model, output_log, b_size):
        """
        Train the network.

        x_path_dir / y_path_dir: directories of input and ground-truth images.
        epochs, train_steps:     outer epoch count and optimizer steps per epoch.
        learning_rate:           initial rate; multiplied by (1 - reduce_lr)
                                 every epochs_to_reduce_lr epochs.
        epochs_to_reduce_lr:     decay period in epochs.
        reduce_lr:               fractional decay applied at each period.
        output_model:            checkpoint path prefix (a type suffix is appended).
        output_log:              directory for TensorBoard summaries.
        b_size:                  mini-batch size.
        """
        # Suffix the checkpoint path to clarify the model type.
        # FIX: the original ran `output_model += "MultiCNN"` in the else branch,
        # which raises TypeError when output_model is None; assign the default
        # name instead of appending to a missing prefix. ("" + "MultiCNN" and
        # None both now yield "MultiCNN", so behavior is backward-compatible.)
        if output_model:
            output_model += "AE"
        else:
            output_model = "MultiCNN"

        check_and_create_dir(output_model)

        # Load data (lists of image file paths).
        x_filenames = extract_image_path([x_path_dir])
        y_filenames = extract_image_path([y_path_dir])

        # Scalar and image summaries for TensorBoard.
        tf.summary.scalar('Learning rate', self.learning_rate)
        tf.summary.scalar('MSE', self.mse)
        tf.summary.scalar('MS SSIM', self.ssim)
        tf.summary.scalar('Loss', self.cost)
        tf.summary.image('BSE', self.Y)
        tf.summary.image('Ground truth', self.Y_clear)
        merged = tf.summary.merge_all()

        sess, saver = self.init_session()
        writer = tf.summary.FileWriter(output_log, sess.graph)

        l_rate = learning_rate
        try:
            for epoch_i in range(epochs):
                # Step-decay the learning rate every epochs_to_reduce_lr epochs.
                if ((epoch_i + 1) % epochs_to_reduce_lr) == 0:
                    l_rate = l_rate * (1 - reduce_lr)
                if self.verbose:
                    print("\n------------ Epoch : ", epoch_i + 1)
                    print("Current learning rate {}".format(l_rate))

                # Training steps
                for i in range(train_steps):
                    if self.verbose:
                        print_train_steps(i + 1, train_steps)
                    x_batch, y_batch = get_batch(b_size, self.image_size, x_filenames, y_filenames)

                    sess.run(self.optimizer, feed_dict={self.X: x_batch, self.Y_clear: y_batch, self.learning_rate: l_rate, self.batch_size: b_size})

                    # Record summaries every 50 steps.
                    if i % 50 == 0:
                        summary = sess.run(merged, {self.X: x_batch, self.Y_clear: y_batch, self.learning_rate: l_rate, self.batch_size: b_size})
                        writer.add_summary(summary, i + epoch_i * train_steps)
                if self.verbose:
                    print("\nSave model to {}".format(output_model))
                saver.save(sess, output_model, global_step=(epoch_i + 1) * train_steps)
        except KeyboardInterrupt:
            # Save a final checkpoint if training is interrupted manually.
            saver.save(sess, output_model)
Exemplo n.º 4
0
Arquivo: ssl.py Projeto: huggyb/pypki
def main():
    """CLI entry point: create/list PKI material (root CA, client certs) under PKI_DIR."""
    parser = argparse.ArgumentParser(description="Command line tool to manage a public key infrastructure.")
    parser.add_argument('-c', '--create', help="Creates client certificate and keys.", action='store_true',
                        default=False)
    parser.add_argument('-d', '--default-ca', help="Creates the root CA with default parameters.", action='store_true',
                        default=False)
    parser.add_argument('-g', '--generate-ca', help="Generates the root CA and ask for values.", action='store_true',
                        default=False)
    parser.add_argument('-l', '--list', help="List all certificates and keys that are present in the pki.",
                        action='store_true', default=False)
    parser.add_argument('-r', '--default-req', help="Creates client certificate and keys with default attributes.",
                        action='store_true', default=False)
    args = parser.parse_args()

    check_and_create_dir(PKI_DIR)

    if args.generate_ca or args.default_ca:
        # Refuse to clobber an existing root CA.
        if exists_and_isfile(CA_CERT_FULLPATH) and exists_and_isfile(CA_KEY_FULLPATH):
            print('There is an existing CA')
            return
        if args.default_ca:
            default_attrs = ['PL', 'Poneyland', 'kichland', 'Poney Corp', 'ROOT-CA', 'ROOT-CA Poney CORP']
        else:
            default_attrs = get_attr()

        # 2048-bit RSA key for the self-signed root certificate.
        k = crypto.PKey()
        k.generate_key(crypto.TYPE_RSA, 2048)

        my_cert = cacert_req(default_attrs)
        my_cert.set_pubkey(k)
        my_cert.sign(k, 'sha256')

        with open(CA_CERT_FULLPATH, 'wb') as fd:
            fd.write(crypto.dump_certificate(crypto.FILETYPE_PEM, my_cert))
        with open(CA_KEY_FULLPATH, 'wb') as fd:
            fd.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k))
    elif args.create or args.default_req:
        if args.default_req:
            cli_attrs = ['PL', 'Poneyland', 'kichland', 'Poney Corp', 'info', 'poney_test']
        else:
            cli_attrs = get_attr()

        # Load ROOT CA private key and certificate.
        # FIX: the original read these via bare open(...).read(), leaking the
        # file handles; use context managers so they close deterministically.
        with open(CA_KEY_FULLPATH, 'rb') as fd:
            ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, fd.read())
        with open(CA_CERT_FULLPATH, 'rb') as fd:
            ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, fd.read())

        # Create the client key pair and a CSR signed with it.
        cli_pkey = crypto.PKey()
        cli_pkey.generate_key(crypto.TYPE_RSA, 2048)
        cli_csr = csr_req(cli_attrs)
        cli_csr.set_pubkey(cli_pkey)
        cli_csr.sign(cli_pkey, 'sha256')

        # Issue the client certificate, valid for 315360000 s (~10 years).
        cli_cert = crypto.X509()
        cli_cert.set_issuer(ca_cert.get_subject())
        cli_cert.gmtime_adj_notBefore(0)
        cli_cert.gmtime_adj_notAfter(315360000)
        # NOTE(review): pyOpenSSL's X509Req does not document a
        # get_serial_number() method — confirm this call works with the
        # installed pyOpenSSL version.
        cli_cert.set_serial_number(cli_csr.get_serial_number())
        cli_cert.set_subject(cli_csr.get_subject())
        cli_cert.set_pubkey(cli_csr.get_pubkey())
        cli_cert.sign(ca_key, 'sha256')
        with open('key/client.crt', 'wb') as fd:
            fd.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cli_cert))
        with open('key/client.key', 'wb') as fd:
            fd.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, cli_pkey))
Exemplo n.º 5
0
    # freeze_support() is required for multiprocessing on Windows when the
    # script is frozen into an executable; it is a no-op otherwise.
    freeze_support()

    # Load the JSON configuration: argv[1] if a path was given, else the
    # default "config.json" in the working directory.
    try:
        if len(sys.argv) > 1:
            all_config_filename = sys.argv[1]
            with open(all_config_filename, "r", encoding="UTF-8") as f:
                all_config = json.load(f)
        else:
            with open("config.json", "r", encoding="UTF-8") as f:
                all_config = json.load(f)
    except Exception as e:
        # Messages (Chinese): "error while parsing the config file, please
        # check it!" and "error details: ...".
        print("解析配置文件时出现错误,请检查配置文件!")
        print("错误详情:" + str(e))
        os.system('pause')  # keep the Windows console window open
        # NOTE(review): execution falls through after the pause, so all_config
        # stays undefined and the next statement raises NameError — confirm
        # whether an explicit exit was intended here.

    # Ensure the log directory exists, then log to a timestamped file in it.
    utils.check_and_create_dir(all_config['root']['logger']['log_path'])
    logging.basicConfig(
        level=utils.get_log_level(all_config),
        format=
        '%(asctime)s %(thread)d %(threadName)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        handlers=[
            logging.FileHandler(os.path.join(
                all_config['root']['logger']['log_path'], "Main_" +
                datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') +
                '.log'),
                                "a",
                                encoding="utf-8")
        ])

    # Runner instances are collected here (population happens past this excerpt).
    runner_list = []
Exemplo n.º 6
0
        # Make the bundled ffmpeg binaries resolvable on PATH.
        utils.add_path("./ffmpeg/bin")

    # Load the JSON configuration: argv[1] if a path was given, else the
    # default "config.json" in the working directory.
    try:
        if len(sys.argv) > 1:
            all_config_filename = sys.argv[1]
            with open(all_config_filename, "r", encoding="UTF-8") as f:
                all_config = json.load(f)
        else:
            with open("config.json", "r", encoding="UTF-8") as f:
                all_config = json.load(f)
    except Exception as e:
        # Messages (Chinese): "error while parsing the config file, please
        # check it!" and "error details: ...".
        print("解析配置文件时出现错误,请检查配置文件!")
        print("错误详情:" + str(e))
        os.system('pause')  # keep the Windows console window open
        # NOTE(review): control falls through with all_config undefined, so
        # the .get() calls below would raise NameError — confirm intent.

    # Ensure the data and log directories exist, with conservative defaults
    # when the config keys are missing.
    utils.check_and_create_dir(
        all_config.get('root', {}).get('data_path', "./"))
    utils.check_and_create_dir(
        all_config.get('root', {}).get('logger', {}).get('log_path', './log'))
    # Timestamped log file name, e.g. "Main_2024-01-01_12-00-00.log".
    logfile_name = "Main_" + datetime.datetime.now().strftime(
        '%Y-%m-%d_%H-%M-%S') + '.log'
    logging.basicConfig(
        level=utils.get_log_level(all_config),
        format=
        '%(asctime)s %(thread)d %(threadName)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        handlers=[
            logging.FileHandler(os.path.join(
                all_config.get('root', {}).get('logger',
                                               {}).get('log_path', "./log"),
                logfile_name),
                                "a",
Exemplo n.º 7
0
    # Upload the processed outputs and splits, keyed by the global start time.
    u = Uploader(p.outputs_dir, p.splits_dir, config)
    u.upload(p.global_start)


if __name__ == "__main__":
    # Merge the root and spec JSON configs passed as argv[1] / argv[2].
    root_config_filename = sys.argv[1]
    spec_config_filename = sys.argv[2]
    with open(root_config_filename, "r") as f:
        root_config = json.load(f)
    with open(spec_config_filename, "r") as f:
        spec_config = json.load(f)
    config = {
        'root': root_config,
        'spec': spec_config
    }
    # Ensure data/log directories exist before configuring file logging.
    utils.check_and_create_dir(config['root']['global_path']['data_path'])
    utils.check_and_create_dir(config['root']['logger']['log_path'])
    logging.basicConfig(level=utils.get_log_level(config),
                        format='%(asctime)s %(thread)d %(threadName)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S',
                        filename=os.path.join(config['root']['logger']['log_path'], datetime.datetime.now(
                        ).strftime('%Y-%m-%d_%H-%M-%S')+'.log'),
                        filemode='a')
    utils.init_data_dirs(config['root']['global_path']['data_path'])
    bl = BiliLive(config)
    prev_live_status = False
    # Poll the live status; on the not-live -> live transition, record the
    # start time (loop body continues past this excerpt).
    while True:
        if not prev_live_status and bl.live_status:
            # Message (Chinese): "the stream has started!"
            print("开播啦~")
            prev_live_status = bl.live_status
            start = datetime.datetime.now()