예제 #1
0
파일: Track.py 프로젝트: nbp/clouseau
    def __get_info(self):
        """Launch the Socorro SuperSearch query backing this tracker.

        Date window:
        - It always starts at ``self.date`` (inclusive).
        - When ``self.duration`` is not negative, an exclusive upper bound at
          ``self.date + self.duration`` days is added, but only if that end
          date is not later than today.

        A first, blocking query with ``_results_number`` 0 reads the total hit
        count. That count, capped at 1000, sizes both the facets and the
        results of a second query. The second query object is stored in
        ``self.search`` and reports to ``self.__handler``.
        """
        date_filter = ['>=' + utils.get_date_str(self.date)]
        open_ended = self.duration < 0
        if not open_ended:
            end_date = self.date + timedelta(self.duration)
            if not end_date > utils.get_date_ymd('today'):
                date_filter.append('<' + utils.get_date_str(end_date))

        # Filters shared by the count query and the real query.
        base_params = {
            'product': self.product,
            'signature': '=' + self.signature,
            'date': date_filter,
            'release_channel': self.channel,
        }

        totals = []
        socorro.SuperSearch(
            params=dict(base_params, _results_number=0),
            handler=lambda payload: totals.append(payload['total'])).wait()
        hit_count = min(totals[0], 1000)

        self.search = socorro.SuperSearch(
            params=dict(base_params,
                        _sort='build_id',
                        _columns=['uuid', 'topmost_filenames'],
                        _facets=['platform_pretty_version', 'build_id',
                                 'version', 'system_memory_use_percentage',
                                 'cpu_name', 'cpu_info', 'reason', 'addons',
                                 'uptime', 'url'],
                        _facets_size=hit_count,
                        _results_number=hit_count),
            handler=self.__handler)
예제 #2
0
파일: Track.py 프로젝트: lizzard/clouseau
 def __get_info(self):
     """Submit a SuperSearch request for this signature's Firefox nightly crashes.

     The date window runs from ``self.date`` (inclusive) to
     ``self.date + self.day_delta`` days (exclusive). The request asks for up
     to 100 results sorted by build id, plus several facets. Whatever
     ``self.session.get`` returns is appended to ``self.results``, and
     ``self.__info_cb`` is passed as its ``background_callback``.
     """
     auth = {'Auth-Token': self.__get_apikey()}
     window = ['>=' + utils.get_date_str(self.date),
               '<' + utils.get_date_str(self.date + timedelta(self.day_delta))]
     query = {'product': 'Firefox',
              'signature': '=' + self.signature,
              'date': window,
              'release_channel': 'nightly',
              '_sort': 'build_id',
              '_columns': ['uuid', 'topmost_filenames'],
              '_facets': ['platform_pretty_version', 'build_id', 'version',
                          'release_channel', 'system_memory_use_percentage',
                          'addons'],
              '_results_number': 100}
     response = self.session.get(self.SUPERSEARCH_URL,
                                 params=query,
                                 headers=auth,
                                 timeout=self.TIMEOUT,
                                 background_callback=self.__info_cb)
     self.results.append(response)
예제 #3
0
def get_feature_path(exp_name, dataset, ckpt_folder, nbit=32, date_str=None):
    """Return the (relative) folder for features extracted from a checkpoint.

    Layout:
    ``../../results/features/<exp_name>_<date>/<dataset>/nbit_<nbit>/<ckpt_name>``

    ``<date>`` is ``date_str``, or ``utils.get_date_str()`` when ``date_str``
    is ``None``. ``<ckpt_name>`` is the last path component of
    ``ckpt_folder``; trailing slashes are ignored.
    """
    if date_str is None:
        date_str = utils.get_date_str()
    exp_path = "../../results/features/{}_{}".format(exp_name, date_str)
    # Strip trailing slashes so "runs/ckpt/" yields "ckpt" rather than "",
    # which would make every checkpoint share one output folder.
    ckpt_name = ckpt_folder.rstrip("/").split("/")[-1]
    return exp_path + "/{}/nbit_{}/{}".format(dataset, nbit, ckpt_name)
예제 #4
0
def main(argv=None):
    """Set up the save-directory helper for this run.

    The directory is ``FLAGS.save_name``, or ``utils.get_date_str()`` when
    no name was given. When ``FLAGS.rewrite`` is set, its previous contents
    are cleared via ``clear_save_name``.
    """
    save_dir = FLAGS.save_name or utils.get_date_str()

    save_locations = tf_easy_dir.tf_easy_dir(save_dir=save_dir)
    if FLAGS.rewrite:
        save_locations.clear_save_name()
예제 #5
0
파일: Track.py 프로젝트: La0/clouseau
    def __get_info(self):
        """Start the SuperSearch query for ``self.signature``.

        Date window:
        - It begins at ``self.date``.
        - It also gets an exclusive end at ``self.date + self.duration``
          days, but only when the duration is not negative and that end date
          is not after today.

        A blocking, count-only query (``_results_number`` 0) runs first. Its
        total, capped at 1000, becomes both ``_facets_size`` and
        ``_results_number`` of the real query. The real query is kept in
        ``self.search`` and reports to ``self.__handler``.
        """
        def query_params(**extra):
            # Filters common to both queries, followed by query-specific keys.
            params = {'product': self.product,
                      'signature': '=' + self.signature,
                      'date': search_date,
                      'release_channel': self.channel}
            params.update(extra)
            return params

        search_date = ['>=' + utils.get_date_str(self.date)]
        has_upper_bound = not self.duration < 0
        if has_upper_bound:
            end_date = self.date + timedelta(self.duration)
            if not end_date > utils.get_date_ymd('today'):
                search_date.append('<' + utils.get_date_str(end_date))

        counts = []

        def remember_total(json):
            counts.append(json['total'])

        socorro.SuperSearch(params=query_params(_results_number=0),
                            handler=remember_total).wait()
        size = 1000 if counts[0] > 1000 else counts[0]

        facets = ['platform_pretty_version', 'build_id', 'version',
                  'system_memory_use_percentage', 'cpu_name', 'cpu_info',
                  'reason', 'addons', 'uptime', 'url']
        self.search = socorro.SuperSearch(
            params=query_params(_sort='build_id',
                                _columns=['uuid', 'topmost_filenames'],
                                _facets=facets,
                                _facets_size=size,
                                _results_number=size),
            handler=self.__handler)
예제 #6
0
def main(argv=None):
    """Prepare ``Save/<name>`` and start training.

    ``<name>`` is ``FLAGS.save_name``, or ``utils.get_date_str()`` when no
    name was given. With ``FLAGS.rewrite`` the directory is cleared first.
    The resulting ``tf_easy_dir`` helper is handed to ``train``.
    """
    run_name = FLAGS.save_name if FLAGS.save_name else utils.get_date_str()
    save_locations = tf_easy_dir.tf_easy_dir(
        save_dir=os.path.join('Save', run_name))

    if FLAGS.rewrite:
        save_locations.clear_save_name()

    train(save_locations)
예제 #7
0
def get_pred_path_from_feature_path(exp_name,
                                    dataset,
                                    feat_folder,
                                    nbit=32,
                                    date_str=None):
    """Return the absolute folder for predictions computed from a feature folder.

    Layout (resolved against the current working directory):
    ``../../results/predictions/<exp_name>_<date>/<dataset>/nbit_<nbit>/<feat_name>``

    ``<date>`` is ``date_str``, or ``utils.get_date_str()`` when ``date_str``
    is ``None``. ``<feat_name>`` is the last path component of
    ``feat_folder``; trailing slashes are ignored.
    """
    if date_str is None:
        date_str = utils.get_date_str()
    exp_path = "../../results/predictions/{}_{}".format(exp_name, date_str)
    # Strip trailing slashes so "feat/run1/" yields "run1" rather than "".
    feat_name = feat_folder.rstrip("/").split("/")[-1]
    folder = exp_path + "/{}/nbit_{}/{}".format(dataset, nbit, feat_name)
    return os.path.abspath(folder)
예제 #8
0
def loop_recompute(argo_exec, date_range, tenant, job, log):
    """
    For a specific time period, loop and execute recomputations for each day

    Dates are processed sequentially, in the order given by ``date_range``.

    :param argo_exec: path to argo bin directory
    :param date_range: dates included in the period; each entry is converted
        with ``get_date_str`` before being passed to ``do_recompute``
    :param tenant: tenant name
    :param job: the job to recompute; passed unchanged to ``do_recompute``
        for every date
    :param log: logger reference
    """
    for dt in date_range:
        date_arg = get_date_str(dt)
        do_recompute(argo_exec, date_arg, tenant, job, log)
예제 #9
0
def generate_rtx_cardinfo(jokes_json, ori_url, today=True):
    """Build an RTX "news" card message with a single article.

    :param jokes_json: parsed content; only its ``'cover'`` entry is used
        (as the article picture URL). ``None`` short-circuits.
    :param ori_url: link the article points to.
    :param today: when true, the article uses the "today" title and a
        description containing ``utils.get_date_str()``; otherwise it uses
        the "past issues" title and has no description.
    :returns: ``{'msgtype': 'news', 'news': {'articles': [article]}}``, or
        ``None`` when ``jokes_json`` is ``None``. The message is also printed.
    """
    if jokes_json is None:
        print("jokes_json is None")
        return None
    if today:
        article = {
            'title': '如何正确地吐槽',
            'url': ori_url,
            'picurl': jokes_json['cover'],
            'description': '今天是 ' + utils.get_date_str()
        }
    else:
        article = {
            'title': '往期沙雕',
            'url': ori_url,
            'picurl': jokes_json['cover']
        }
    data = {'msgtype': 'news', 'news': {'articles': [article]}}
    print(data)
    return data
def main(args=None):
    """Run the ARGO status-detail computation for one tenant, job and date.

    Steps, in order:
    1. Upload sync data to HDFS (``upload_sync.py``).
    2. Clean existing status results from MongoDB (``mongo_clean_status.py``).
    3. Run ``compute-status.pig``.
    4. When ``cfg.sync_clean == "true"``, remove the HDFS sync data again.

    :param args: parsed command-line arguments providing ``date``, ``tenant``
        and ``job``. Despite the ``None`` default, it is required.
    """

    # default paths
    fn_ar_cfg = "/etc/ar-compute-engine.conf"
    arcomp_conf = "/etc/ar-compute/"
    stdl_exec = "/usr/libexec/ar-compute/bin"
    pig_script_path = "/usr/libexec/ar-compute/pig/"

    one_day_ago = utils.get_actual_date(args.date) - timedelta(days=1)
    prev_date = utils.get_date_str(one_day_ago)
    prev_date_under = utils.get_date_under(prev_date)
    date_under = utils.get_date_under(args.date)

    # Init configuration
    cfg = utils.ArgoConfiguration(fn_ar_cfg)
    cfg.load_tenant_db_conf(os.path.join(arcomp_conf, args.tenant + "_db_conf.json"))
    # Init logging
    log = init_log(cfg.log_mode, cfg.log_file, cfg.log_level, 'argo.job_status_detail')

    local_cfg_path = arcomp_conf
    # open job configuration file; ``with`` closes the handle even if
    # json.load raises
    with open(local_cfg_path + args.tenant + "_" + args.job + "_cfg.json") as json_cfg_file:
        json_cfg = json.load(json_cfg_file)

    # Inform the user whether argo runs locally or distributed
    if cfg.mode == 'local':
        log.info("ARGO compute engine runs in LOCAL mode")
        log.info("computation job will be run locally")
    else:
        log.info("ARGO compute engine runs in CLUSTER mode")
        log.info("computation job will be submitted to the hadoop cluster")

    # Proposed hdfs pathways
    hdfs_mdata_path = './' + args.tenant + "/mdata/"
    hdfs_sync_path = './scratch/sync/' + args.tenant + \
        "/" + args.job + "/" + date_under + "/"

    # Proposed local pathways
    local_mdata_path = '/tmp/' + args.tenant + "/mdata/"
    local_sync_path = '/tmp/scratch/sync/' + args.tenant + \
        '/' + args.job + '/' + date_under + '/'

    # Pick the path set for the execution mode. The ``mode`` pig parameter
    # is 'cache' on the cluster and 'local' otherwise.
    if cfg.mode == 'cluster':
        mode = 'cache'
        mdata_path = hdfs_mdata_path
        sync_path = hdfs_sync_path
        cfg_path = hdfs_sync_path
    else:
        mode = 'local'
        mdata_path = local_mdata_path
        sync_path = local_sync_path
        cfg_path = local_cfg_path

    # dictionary with necessary pig parameters
    pig_params = {}

    pig_params['mdata'] = mdata_path + 'prefilter_' + date_under + '.avro'
    pig_params['p_mdata'] = mdata_path + \
        'prefilter_' + prev_date_under + '.avro'
    pig_params['egs'] = sync_path + 'group_endpoints.avro'
    pig_params['ggs'] = sync_path + 'group_groups.avro'
    pig_params['mps'] = sync_path + 'poem_sync.avro'
    pig_params['cfg'] = cfg_path + args.tenant + '_' + args.job + '_cfg.json'
    pig_params['aps'] = cfg_path + args.tenant + '_' + args.job + '_ap.json'
    pig_params['rec'] = cfg_path + 'recomputations_' + args.tenant + '_' + date_under + '.json'
    pig_params['ops'] = cfg_path + args.tenant + '_ops.json'
    pig_params['dt'] = args.date
    pig_params['mode'] = mode
    pig_params['flt'] = '1'
    pig_params['mongo_status_metrics'] = cfg.get_mongo_uri('status', 'status_metrics')
    pig_params['mongo_status_endpoints'] = cfg.get_mongo_uri('status', 'status_endpoints')
    pig_params['mongo_status_services'] = cfg.get_mongo_uri('status', 'status_services')
    pig_params['mongo_status_endpoint_groups'] = cfg.get_mongo_uri('status', 'status_endpoint_groups')
    cmd_pig = []

    # Append pig command
    cmd_pig.append('pig')

    # Append Pig local execution mode flag
    if cfg.mode == "local":
        cmd_pig.append('-x')
        cmd_pig.append('local')

    # Append Pig Parameters (one "-param key=value" pair per entry)
    for key, value in pig_params.items():
        cmd_pig.append('-param')
        cmd_pig.append(key + '=' + value)

    # Append Pig Executionable Script
    cmd_pig.append('-f')
    cmd_pig.append(pig_script_path + 'compute-status.pig')

    # Command to clean a/r data from mongo
    cmd_clean_mongo_status = [
        os.path.join(stdl_exec, "mongo_clean_status.py"), '-d', args.date, '-t', args.tenant, '-r', json_cfg['id']]

    # Command to upload sync data to hdfs
    cmd_upload_sync = [os.path.join(
        stdl_exec, "upload_sync.py"), '-d', args.date, '-t', args.tenant, '-j', args.job]

    # Command to clean hdfs data
    cmd_clean_sync = ['hadoop', 'fs', '-rm', '-r', '-f', hdfs_sync_path]

    # Upload data to hdfs
    log.info("Uploading sync data to hdfs...")
    run_cmd(cmd_upload_sync, log)

    # Clean data from mongo
    log.info("Cleaning data from mongodb")
    run_cmd(cmd_clean_mongo_status, log)

    # Call pig
    log.info("Submitting pig compute status detail job...")
    run_cmd(cmd_pig, log)

    # Cleaning hdfs sync data
    if cfg.sync_clean == "true":
        log.info("System configured to clean sync hdfs data after job")
        run_cmd(cmd_clean_sync, log)

    log.info("Execution of status job for tenant %s for date %s completed!",
             args.tenant, args.date)
예제 #11
0
def get_folder_id(token):
    """Look up the folder named after ``get_date_str()``, creating it if absent.

    When ``folder_exist`` reports a truthy id, returns ``(folder_id,
    folder_name)``. Otherwise returns whatever ``add_folder(token,
    folder_name)`` returns.
    """
    name = get_date_str()
    existing = folder_exist(token, name)
    if not existing:
        return add_folder(token, name)
    return existing, name
예제 #12
0
    "Green laser cutter 32x18 in": "Green"
})

# group by cutters, then by start times within each cutter grouping
df = df.sort_values(by=["Cutter", "Start"])

# create dictionaries. Consider using default dicts later
# {string date: pandas.Timedelta total_time}
reds = {}
blues = {}
greens = {}
yellows = {}

for index, cutter, start, stop in df.itertuples():
    print(start.day_name())
    date = utils.get_date_str(start)
    # Initialize dictionary with default time deltas of 0
    if date not in reds or date not in blues or date not in greens or date not in yellows:
        reds[date] = pandas.Timedelta(0)
        blues[date] = pandas.Timedelta(0)
        greens[date] = pandas.Timedelta(0)
        yellows[date] = pandas.Timedelta(0)

    # add time deltas to the respective dicts
    if cutter == "Red":
        reds[date] += stop - start
    elif cutter == "Blue":
        blues[date] += stop - start
    elif cutter == "Green":
        greens[date] += stop - start
    elif cutter == "Yellow":
예제 #13
0
def main():
    """Generate a multi-object sprite dataset and save it with diagnostic plots.

    Steps:
    1. Build sprites for ``args.dataset_type`` ('dsprites' or 'binary_mnist').
    2. Compose ``n`` images of ``frame_size``, drawing object counts from
       ``count_distrib``.
    3. Save the images, labels and object counts as a compressed ``.npz``
       under ``generated/<dataset_type>/``.
    4. Write PNGs of the sprites, sample images and the object-count
       histogram.

    All outputs of one run share a single ``get_date_str()`` tag.

    Raises:
        NotImplementedError: for an unsupported ``dataset_type``.
    """

    args = parse_args()

    ### SETTINGS #############################
    n = 100000  # num images
    frame_size = (64, 64)
    patch_size = 18

    # count_distrib = {1: 1}
    count_distrib = {0: 1 / 3, 1: 1 / 3, 2: 1 / 3}
    allow_overlap = True
    ##########################################

    # One tag for the whole run so the dataset file and its diagnostic images
    # can be matched afterwards. Separate calls could differ, e.g. when the
    # slow generation step crosses a date (or time) boundary.
    run_stamp = get_date_str()

    # Generate sprites and labels
    print("generating sprites...")
    if args.dataset_type == 'dsprites':
        sprites, labels = generate_dsprites(patch_size)
    elif args.dataset_type == 'binary_mnist':
        sprites, labels = generate_binary_mnist(patch_size)
    else:
        raise NotImplementedError

    # Show sprites
    show_img_grid(8,
                  sprites,
                  random_selection=True,
                  fname='gen_{}_sprites.png'.format(run_stamp))

    # Create dataset
    print("generating dataset...")
    ch = sprites[0].shape[-1]
    img_shape = (*frame_size, ch)
    dataset, n_obj, labels = generate_multiobject_dataset(
        n,
        img_shape,
        sprites,
        labels,
        count_distrib=count_distrib,
        allow_overlap=allow_overlap)
    print("done")
    print("shape:", dataset.shape)

    # Number of objects is part of the labels
    labels['n_obj'] = n_obj

    # Save dataset
    print("saving...")
    root = os.path.join('generated', args.dataset_type)
    os.makedirs(root, exist_ok=True)
    fname = 'multi_' + args.dataset_type + '_' + run_stamp
    fname = os.path.join(root, fname)
    np.savez_compressed(fname, x=dataset, labels=labels)
    print('done')

    # Show samples and print their attributes
    print("\nAttributes of saved samples:")
    show_img_grid(4,
                  dataset,
                  labels,
                  fname='gen_{}_images.png'.format(run_stamp))

    # Show distribution of number of objects per image; integer-centred bins.
    plt.figure()
    plt.hist(n_obj, np.arange(min(n_obj) - 0.5, max(n_obj) + 0.5 + 1, 1))
    plt.title("Distribution of num objects per image")
    plt.xlabel("Number of objects")
    plt.savefig('gen_{}_distribution.png'.format(run_stamp))
    plt.close()
예제 #14
0
def main():
    """Create a new script from a template, prompting for missing settings.

    Steps:
    1. Fill in ``g_args`` interactively where values are not already set:
       type, output path, author, description, and whether to copy the
       python utilities.
    2. Copy ``TEMPLATE_OPTIONS[g_args.type]`` to ``g_args.output``.
    3. Optionally copy the utilities next to the new file.
    4. Replace the date/author/description placeholders in the new file.
    """

    # user input
    if not g_args.type:
        g_args.type = utils.read_string_input(msg="type",
                                              init_value=DEFAULT_TEMPLATE)
    if not g_args.output:
        g_args.output = utils.read_path_input(msg="file path")
    g_args.author = utils.read_string_input(msg="author",
                                            init_value=g_args.author)
    if not g_args.description:
        g_args.description = utils.read_string_input(
            msg="description", init_value=HEADER_DESCRIPTION)
    g_args.description = format_description(g_args.description)
    if g_args.type != "bash":
        g_args.copy_utils = utils.confirm("Copy utilities?",
                                          "y" if g_args.copy_utils else "n")

    template_path = utils.join_paths_str(g_args.script_dir,
                                         TEMPLATE_OPTIONS[g_args.type])

    if utils.exists_dir(g_args.output):
        print("Error: target path is a directory!")
        # This is an error: report failure to the shell, not status 0.
        sys.exit(1)
    elif utils.exists_file(g_args.output):
        # Existing file: abort via utils.exit unless the user confirms.
        if not utils.confirm_delete_file(g_args.output, "n"):
            utils.exit("Aborted")

    out_folder = utils.get_file_path(g_args.output)
    out_file = utils.get_file_name(g_args.output)

    # Create the *parent* folder. Creating g_args.output itself as a
    # directory would make the template copy below fail or land inside it.
    # The truthiness guard skips the call when there is no folder part
    # (assumes get_file_path returns an empty value then — TODO confirm).
    if out_folder and not utils.exists_dir(out_folder):
        utils.make_dir(out_folder)

    # copy template
    utils.copy_to(template_path, g_args.output)
    print(f"Created file {g_args.output}")

    if g_args.type == "class":
        # Class templates use the file name as the class name.
        utils.replace_file_text(g_args.output, CLASS_NAME, out_file)

    if g_args.type != "bash":
        if g_args.copy_utils:
            utils_folder = PY_UTILS_DIR
            out_py_utils_dir = utils.join_paths(out_folder, utils_folder)
            utils.make_dir(out_py_utils_dir)
            utils.copy_to(PY_UTILS_FILE, out_py_utils_dir)
            print(f"Created file {out_py_utils_dir}/{PY_UTILS_FILE}")
        else:
            print("""
            Important: Please make sure that python utils are available, i.e. inside PYTHONPATH.
            Clone repository via: git clone https://github.com/amplejoe/py_utils.git
            """)

    # header information
    date = utils.get_date_str()
    utils.replace_file_text(g_args.output, HEADER_DATE, date)
    utils.replace_file_text(g_args.output, HEADER_AUTHOR, g_args.author)
    if g_args.description:
        utils.replace_file_text(g_args.output, HEADER_DESCRIPTION,
                                g_args.description)
예제 #15
0
def train():
    """Train the CIFAR-10 model module named by ``FLAGS.architecture``.

    The module is imported dynamically and must provide ``inference``,
    ``loss``, ``train`` and ``correct_ones``.

    Output goes to ``Save/<FLAGS.save_name>``, or to
    ``Save/<date string><architecture>`` when no name is given. With
    ``FLAGS.rewrite``, previous output is cleared first.

    During training:
    - every step prints loss, learning rate and batch accuracy;
    - summaries (if any are defined) are written every 100 steps;
    - a checkpoint is saved every 2000 steps.
    """
    cifar10_model = importlib.import_module(FLAGS.architecture)
    if not FLAGS.save_name:
        save_dir = os.path.join('Save',
                                utils.get_date_str() + FLAGS.architecture)
    else:
        save_dir = os.path.join('Save', FLAGS.save_name)

    save_locations = tf_easy_dir.tf_easy_dir(save_dir=save_dir)
    if FLAGS.rewrite:
        save_locations.clear_save_name()

    with tf.Graph().as_default():
        global_step = tf.get_variable(name='gstep',
                                      initializer=tf.constant(0),
                                      trainable=False)
        [batch_images, batch_labels] = cifar10_inputs.inputs(FLAGS.data_dir,
                                                             FLAGS.batch_size,
                                                             isTraining=True,
                                                             isRandom=True)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('size of image input: [{:s}]'.format(', '.join(
            map(str,
                batch_images.get_shape().as_list()))))
        print('size of labels : [{:s}]'.format(', '.join(
            map(str,
                batch_labels.get_shape().as_list()))))
        print('-' * 32)
        sys.stdout.flush()

        logits = cifar10_model.inference(batch_images, isTraining=True)
        loss = cifar10_model.loss(logits=logits, labels=batch_labels)
        train_op, lr = cifar10_model.train(loss, global_step)

        correct_ones = cifar10_model.correct_ones(logits=logits,
                                                  labels=batch_labels)

        saver = tf.train.Saver(max_to_keep=None)
        # merge_all returns None when the graph defines no summaries.
        summary_op = tf.summary.merge_all()

        config = tf_utils_inner.gpu_config(FLAGS.gpu_id)
        with tf.Session(config=config) as sess:
            sess.run(tf.variables_initializer(tf.global_variables()))
            summary_writer = tf.summary.FileWriter(
                logdir=save_locations.summary_save_dir, graph=sess.graph)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            for i in range(FLAGS.max_steps):
                _, loss_, correct_ones_, lr_ = sess.run(
                    [train_op, loss, correct_ones, lr])

                assert not np.isnan(
                    loss_), 'Model diverged with loss = NaN, try again'

                print('[{:s} -- {:08d}|{:08d}]\tloss : {:.3f}\t, l-rate: {:.6f}\tcorrect ones [{:d}|{:d}]'.format(
                    save_dir, i, FLAGS.max_steps, loss_, lr_, correct_ones_,
                    FLAGS.batch_size))
                sys.stdout.flush()

                # Parentheses matter: ``i + 1 % 100`` parses as
                # ``i + (1 % 100)`` and is never 0.
                if summary_op is not None and (i + 1) % 100 == 0:
                    summary_ = sess.run(summary_op)
                    # add_summary needs a Python number for the step, so
                    # read the current value of the global-step variable.
                    summary_writer.add_summary(
                        summary_,
                        global_step=tf.train.global_step(sess, global_step))

                if (i + 1) % 2000 == 0:
                    save_path = os.path.join(save_locations.model_save_dir,
                                             'model')
                    # Saver.save accepts the step variable directly.
                    saver.save(sess=sess,
                               global_step=global_step,
                               save_path=save_path)
            coord.request_stop()
            coord.join(threads=threads)
            # Flush any pending events to disk.
            summary_writer.close()
예제 #16
0
        print("spider failed")
    myjokes = utils.get_qa_from_html(myhtml)
    if None == myjokes:
        print("parse failed")
    rtx_md = push.generate_rtx_markdown(myjokes)
    rtx_card = push.generate_rtx_cardinfo(myjokes, ori_url)
    push.push_to_rtx(rtx_card)
    push.push_to_rtx(rtx_md)


def timer_func():
    """Push once when ``utils.check_if_is_time(7, 30)`` is true, then exit.

    Returns ``False`` when it is not time yet or a push already happened.
    On the push path it sets the global ``pushed`` flag, calls ``do_push()``
    and then terminates the whole process with status 0, so it does not
    return.
    """
    global pushed
    if not utils.check_if_is_time(7, 30) or pushed:
        return False
    pushed = True
    do_push()
    import os
    import sys
    # os._exit skips interpreter cleanup, including flushing stdio buffers,
    # so flush explicitly or output printed by do_push() may be lost when
    # stdout is redirected to a file or pipe.
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)


def timer_test_func():
    """Debug callback: print a marker line to show that the timer fired."""
    marker = 'timer'
    print(marker)


if __name__ == "__main__":
    # Poll every 30 seconds: log the current date and time, then let
    # timer_func() decide whether it is time to push.
    while True:
        stamp = utils.get_date_str() + " " + utils.get_time_str()
        print(stamp)
        timer_func()
        time.sleep(30)
예제 #17
0
파일: manager.py 프로젝트: michaelvsj/sm
 def get_new_capture_folder(self):
     """Create (if missing) and return the folder for the current capture segment.

     The folder is ``self.capture_dir_base`` joined with
     ``<sys_id>/<date>/<session>/<segment as 4 zero-padded digits>``, using
     ``os.sep`` as the separator inside the relative part.
     """
     parts = (self.sys_id, get_date_str(), self.session, f"{self.segment:04d}")
     target = os.path.join(self.capture_dir_base, os.sep.join(parts))
     os.makedirs(target, exist_ok=True)
     return target
예제 #18
0
def main(argv=None):
    """Train the C3D model on the input files found in ``FLAGS.data_dir``.

    Output location:
    - ``FLAGS.save_name`` is rewritten to ``c3dSave/<FLAGS.save_name>``, or to
      ``c3dSave/<date string>`` when no name was given.
    - Output goes to that directory; with ``FLAGS.rewrite`` it is cleared
      first.

    Epoch size is the number of files matching
    ``*.<input_reader.TF_FORMAT>``; one example per file is assumed. The
    learning-rate decay interval (in steps) passed to ``c3d_model.train`` is
    ``steps_per_epoch * FLAGS.num_epoch_per_decay``.

    During training:
    - mean loss and the number of correct predictions are printed once per
      epoch;
    - summaries (if any are defined) are written every 100 steps;
    - a checkpoint is saved every 2000 steps.
    """

    if not FLAGS.save_name:
        FLAGS.save_name = os.path.join('c3dSave', utils.get_date_str())
    else:
        FLAGS.save_name = os.path.join('c3dSave', FLAGS.save_name)

    # print parameters
    tf_utils.print_gflags(FLAGS=FLAGS)

    NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = len(
        glob.glob(
            os.path.join(FLAGS.data_dir,
                         '*.{:s}'.format(input_reader.TF_FORMAT))))
    if NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN < 1:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("Check file path")
        return

    steps_per_epoch = int(
        math.ceil(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 1.0 / FLAGS.batch_size))
    print('Each epoch consists {:d} Steps'.format(steps_per_epoch))
    sys.stdout.flush()

    lr_decay_every_n_step = int(steps_per_epoch * FLAGS.num_epoch_per_decay)

    save_locations = tf_easy_dir.tf_easy_dir(save_dir=FLAGS.save_name)

    if FLAGS.rewrite:
        save_locations.clear_save_name()

    with tf.Graph().as_default():
        global_step = tf.get_variable(name='gstep',
                                      initializer=tf.constant(0),
                                      trainable=False)
        batch_images, batch_labels, batch_filenames = input_reader.inputs(
            FLAGS.data_dir, isTraining=True)
        print('size of image input: [{:s}]'.format(', '.join(
            map(str,
                batch_images.get_shape().as_list()))))
        print('size of labels : [{:s}]'.format(', '.join(
            map(str,
                batch_labels.get_shape().as_list()))))
        print('-' * 32)
        sys.stdout.flush()

        logits = c3d_model.inference_c3d(batch_images, isTraining=True)
        loss = c3d_model.loss(logits=logits, labels=batch_labels)

        train_op, lr = c3d_model.train(loss, global_step,
                                       lr_decay_every_n_step)

        correct_ones = c3d_model.correct_ones(logits=logits,
                                              labels=batch_labels)

        saver = tf.train.Saver(max_to_keep=None)
        # merge_all returns None when the graph defines no summaries.
        summary_op = tf.summary.merge_all()

        config = tf_utils.gpu_config(FLAGS.gpu_id)
        with tf.Session(config=config) as sess:
            sess.run(tf.variables_initializer(tf.global_variables()))
            summary_writer = tf.summary.FileWriter(
                logdir=save_locations.summary_save_dir, graph=sess.graph)

            coord = tf.train.Coordinator()

            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            print('Training Start!')
            sys.stdout.flush()

            cum_loss = 0
            cum_correct = 0

            for i in range(FLAGS.max_steps):
                _, loss_, correct_ones_ = sess.run(
                    [train_op, loss, correct_ones])

                assert not np.isnan(
                    loss_), 'Model diverged with loss = NaN, try again'

                cum_loss += loss_
                cum_correct += correct_ones_
                # update: print loss every epoch
                if (i + 1) % steps_per_epoch == 0:
                    lr_ = sess.run(lr)

                    # The run directory is FLAGS.save_name; this function
                    # has no ``save_dir`` local (using one raised NameError).
                    print('[{:s} -- {:08d}|{:08d}]\tloss : {:.3f}\t, correct ones [{:d}|{:d}], l-rate:{:.06f}'.format(
                        FLAGS.save_name, i, FLAGS.max_steps,
                        cum_loss / steps_per_epoch, cum_correct,
                        NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, lr_))
                    sys.stdout.flush()
                    cum_loss = 0
                    cum_correct = 0

                # Parentheses matter: ``i + 1 % 100`` parses as
                # ``i + (1 % 100)`` and is never 0.
                if summary_op is not None and (i + 1) % 100 == 0:
                    summary_ = sess.run(summary_op)
                    # add_summary needs a Python number for the step, so
                    # read the current value of the global-step variable.
                    summary_writer.add_summary(
                        summary_,
                        global_step=tf.train.global_step(sess, global_step))

                if (i + 1) % 2000 == 0:
                    save_path = os.path.join(save_locations.model_save_dir,
                                             'model')
                    # Saver.save accepts the step variable directly.
                    saver.save(sess=sess,
                               global_step=global_step,
                               save_path=save_path)
            coord.request_stop()
            coord.join(threads=threads)
            # Flush any pending events to disk.
            summary_writer.close()