Esempio n. 1
0
def get_init_values_for_pruned_layers(prune_scopes, shorten_scopes,
                                      kept_percentage, prune_info):
    """Compute initial values for the variables of newly pruned layers.

    A scratch graph is built that contains the training input pipeline and
    the network described by ``prune_info``.  ``FLAGS.checkpoint_path`` is
    loaded into a session on that graph and the session is handed to
    ``get_pruned_kernel_matrix`` together with the scopes to prune.

    Layers are pruned iteratively, so ``prune_scopes`` and ``shorten_scopes``
    should be of size one.

    Args:
        prune_scopes: Scopes forwarded to ``get_pruned_kernel_matrix``.
        shorten_scopes: Scopes forwarded to ``get_pruned_kernel_matrix``.
        kept_percentage: Kept percentage(s) forwarded to
            ``get_pruned_kernel_matrix``.
        prune_info: Pruning description used to build the network.

    Returns:
        The result of ``get_pruned_kernel_matrix``.
    """
    # Stage 0 builds the network under the checkpoint scope; every later
    # stage builds it under the pruned scope.
    if FLAGS.stage_index == 0:
        net_scope = FLAGS.net_name_scope_checkpoint
    else:
        net_scope = FLAGS.net_name_scope_pruned

    scratch_graph = tf.Graph()
    with scratch_graph.as_default():
        deployment = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        train_set = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        input_queue = train_inputs(train_set, deployment, FLAGS)
        # Only the images are needed to build the network; labels are unused.
        images, _unused_labels = input_queue.dequeue()

        build_network = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(train_set.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            prune_info=prune_info)
        # Called for its side effect of creating the variables in the graph.
        build_network(images, is_training=False, scope=net_scope)

        with tf.Session() as sess:
            load_checkpoint(sess, FLAGS.checkpoint_path)
            print("get pruned kernel matrix with kept_percentage",
                  kept_percentage)
            init_values = get_pruned_kernel_matrix(
                sess, prune_scopes, shorten_scopes, kept_percentage)

    # Drop the local reference to the scratch graph.
    del scratch_graph
    return init_values
Esempio n. 2
0
def main(_):
    """Fine-tune a pruned network for one pruning stage.

    The pruning configuration of the stage is derived from the previous
    configuration (``load_checkpoint_config``) and the option returned by
    ``choose_kp``.  Unless training is continued from a checkpoint found in
    the stage's train directory, the variables of the newly pruned layers are
    initialised from ``get_init_values_for_pruned_layers`` and the remaining
    model variables are restored from ``FLAGS.checkpoint_path``.  The network
    is then trained up to ``FLAGS.max_number_of_steps``; it is evaluated and
    saved every ``FLAGS.evaluate_every_n_steps`` steps and at the last step,
    and training stops early once the test accuracy reaches
    ``FLAGS.baseline_accuracy - FLAGS.drop_rate``.

    Args:
        _: Unused.  NOTE(review): presumably the argv list supplied by
            ``tf.app.run`` -- confirm against the caller.

    Raises:
        ValueError: If ``--dataset_dir`` is not supplied, or if a variable
            selected for re-initialisation has no computed initial value.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # For every stage after the first, checkpoint variables are looked up
    # under the pruned scope as well (presumably because that checkpoint was
    # written by a previous pruned stage -- confirm).
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    if FLAGS.stage_index:
        net_name_scope_checkpoint = net_name_scope_pruned

    # choose the pruning option for the stage
    kp = choose_kp()
    pre_config = load_checkpoint_config()
    print('Load previous config:', len(pre_config), pre_config)

    (fake_config, cur_config, new_prune_scopes,
     new_kept_percentage) = generate_config(pre_config, kp)
    print('Generate fake_config:', fake_config)
    print('Generate config:', len(cur_config), cur_config)
    print('new prune scopes:', len(new_prune_scopes), new_prune_scopes)

    def expand_config(config):
        """Return (prune scopes, kept percentages) of the pruned units."""
        scopes, kept = [], []
        for unit_id, unit_kp in enumerate(config):
            # Units kept at 100% are not pruned at all.
            if unit_kp == 1.0:
                continue
            # NOTE(review): the kept percentage is recorded twice per unit,
            # which presumably matches two prune scopes per unit -- confirm.
            kept.extend([unit_kp, unit_kp])
            scopes.extend(valid_indexed_prune_scopes_for_units[unit_id])
        return scopes, kept

    pre_prune_scopes, pre_kept_percentage = expand_config(pre_config)
    cur_prune_scopes, cur_kept_percentage = expand_config(cur_config)
    print("cur_prune_scopes", len(cur_prune_scopes), cur_prune_scopes)
    print("cur_kept_percentage", len(cur_kept_percentage), cur_kept_percentage)

    # prepare for training with the specific config
    pre_prune_info = indexed_prune_scopes_to_prune_info(
        pre_prune_scopes, pre_kept_percentage)
    cur_prune_info = indexed_prune_scopes_to_prune_info(
        cur_prune_scopes, cur_kept_percentage)

    print("pre_prune_info")
    pprint(pre_prune_info)
    print("cur_prune_info")
    pprint(cur_prune_info)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir,
                               "_".join(map(str, fake_config)))
    train_dir = os.path.join(results_dir,
                             'train_lr' + str(FLAGS.learning_rate))
    print('train_dir:', train_dir)

    # Decide ONCE whether this run starts from the pruned checkpoint or
    # resumes from train_dir.  The values computed in the branch below are
    # consumed much later inside the session, so both places must agree.
    start_from_checkpoint = (not FLAGS.continue_training) or (
        not tf.train.latest_checkpoint(train_dir))

    if start_from_checkpoint:
        prune_scopes = indexed_prune_scopes_to_prune_scopes(
            new_prune_scopes, net_name_scope_checkpoint)
        shorten_scopes = indexed_prune_scopes_to_shorten_scopes(
            new_prune_scopes, net_name_scope_checkpoint)
        print("prune_scopes", len(prune_scopes),
              "\n\t" + "\n\t".join(prune_scopes))
        print("shorten_scopes", len(shorten_scopes),
              "\n\t" + "\n\t".join(shorten_scopes))

        variables_init_value = get_init_values_for_pruned_layers(
            prune_scopes, shorten_scopes, new_kept_percentage, pre_prune_info)

        # The initial values are keyed by checkpoint-scope names; the
        # training graph uses the pruned scope.
        reinit_scopes = [
            re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
            for v in prune_scopes + shorten_scopes
        ]

        prepare_file_system(train_dir)

    def write_detailed_info(info):
        """Append one entry to train_details.txt in the train directory."""
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'learning_rate:' + str(FLAGS.learning_rate) + '\n'
    info += 'stage_index:' + str(FLAGS.stage_index) + '\n'
    info += 'configuration: ' + str(cur_config) + '\n'
    info += 'cur_prune_scopes: ' + str(cur_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(cur_kept_percentage)
    print(info)
    write_detailed_info(info)

    # write current config into file
    with open(os.path.join(train_dir, "config.txt"), 'w') as f:
        f.write(", ".join(map(str, cur_config)))

    with tf.Graph().as_default():
        # --- Config model_deploy ---
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # --- Select the dataset ---
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        # --- Select the network ---
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=cur_prune_info,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        # --- Define the model ---
        # The evaluation tower reuses the variables of the training tower.
        logits_train, _ = network_fn_pruned(images,
                                            is_training=True,
                                            is_local_train=False,
                                            reuse_variables=False,
                                            scope=net_name_scope_pruned)
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=True,
                                           scope=net_name_scope_pruned)

        cross_entropy = add_cross_entropy(logits_train, labels)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        # --- Specify the loss function ---
        tf.add_to_collection('subgraph_losses', cross_entropy)
        # get regularization loss
        regularization_losses = get_regularization_losses_within_scopes()
        print_list('regularization_losses', regularization_losses)

        # total loss and its summary
        total_loss = tf.add_n(tf.get_collection('subgraph_losses'),
                              name='total_loss')
        for l in tf.get_collection('subgraph_losses') + [total_loss]:
            tf.summary.scalar(l.op.name + '/summary', l)

        # --- Configure the optimization procedure ---
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        # --- Add train operation ---
        variables_to_train = get_trainable_variables_within_scopes()
        train_op = add_train_op(optimizer,
                                total_loss,
                                global_step,
                                var_list=variables_to_train)
        print_list("variables_to_train", variables_to_train)

        # Gather update_ops: the updates for the batch_norm variables created
        # by network_fn_pruned.
        update_ops = get_update_ops_within_scopes()
        print_list("update_ops", update_ops)

        # Fetching train_tensor runs the train op and all update ops, and
        # yields the total loss.
        update_ops.append(train_op)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            # --- Prepare for filewriter ---
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            if start_from_checkpoint:
                # --- Reinit pruned model variables ---
                variables_to_reinit = get_model_variables_within_scopes(
                    reinit_scopes)
                print_list("Initialize pruned variables", variables_to_reinit)
                assign_ops = []
                for v in variables_to_reinit:
                    # Initial values are keyed by checkpoint-scope names.
                    key = re.sub(net_name_scope_pruned,
                                 net_name_scope_checkpoint, v.op.name)
                    if key not in variables_init_value:
                        raise ValueError(
                            "Key not in variables_init_value, key=%s" % key)
                    value = variables_init_value.get(key)
                    assign_ops.append(
                        tf.assign(v,
                                  tf.convert_to_tensor(value),
                                  validate_shape=True))
                assign_op = tf.group(*assign_ops)
                sess.run(assign_op)

                # --- Restore unchanged model variables ---
                variables_to_restore = {
                    re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                           v.op.name): v
                    for v in get_model_variables_within_scopes()
                    if v not in variables_to_reinit
                }
                print_list("restore model variables",
                           variables_to_restore.values())
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)
            else:
                # --- Restore all variables from the train_dir checkpoint ---
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            # --- Init whatever is still uninitialized ---
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init unitialized variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('\nExit: init_global_step_value (%d) >= FLAG.max_number_of_steps (%d)\n\n' \
                    %(init_global_step_value, FLAGS.max_number_of_steps))
                # Flush the graph event written by the FileWriter.
                train_writer.close()
                return

            # --- Kick off the training ---
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
            print('HG: # of threads=', len(threads))

            # duration / duration_cnt are reset after every loss log entry;
            # train_time / train_only_cnt accumulate over the whole run.
            # All four only count plain training steps (no trace, no summary).
            duration = 0
            duration_cnt = 0
            train_time = 0
            train_only_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                if i > init_global_step_value:
                    # train while collecting run metadata
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_value = sess.run(train_tensor,
                                              options=run_options,
                                              run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)

                        # Create the Timeline object, and write it to a json
                        # file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = (
                            fetched_timeline.generate_chrome_trace_format())
                        timeline_path = os.path.join(
                            train_dir, 'timeline_' + str(i) + '.json')
                        with open(timeline_path, 'w') as f:
                            f.write(chrome_trace)

                    # train while recording summary
                    elif i % FLAGS.summary_every_n_steps == 0:
                        train_summary, loss_value = sess.run(
                            [summary_op, train_tensor])
                        train_writer.add_summary(train_summary, i)

                    # train only
                    else:
                        start_time = time.time()
                        loss_value = sess.run(train_tensor)
                        train_only_cnt += 1
                        train_time += time.time() - start_time
                        duration_cnt += 1
                        duration += time.time() - start_time

                    # log loss information
                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        format_str = (
                            '%s: step %d, loss = %.3f (%.1f examples/sec; %.3f sec/batch)'
                        )
                        print(format_str % (datetime.now(), i, loss_value,
                                            examples_per_sec, sec_per_batch))
                        duration = 0
                        duration_cnt = 0

                        info = format_str % (datetime.now(), i, loss_value,
                                             examples_per_sec, sec_per_batch)
                        write_detailed_info(info)
                else:
                    # First iteration: only evaluate the loss, do not train.
                    train_summary, loss_value = sess.run(
                        [summary_op, total_loss])
                    train_writer.add_summary(train_summary, i)
                    format_str = ('%s: step %d, loss = %.3f')
                    print(format_str % (datetime.now(), i, loss_value))
                    info = format_str % (datetime.now(), i, loss_value)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                    # NOTE(review): test_images/test_labels are passed twice;
                    # confirm against evaluate_accuracy's signature.
                    test_accuracy, run_metadata = evaluate_accuracy(
                        sess,
                        coord,
                        test_dataset.num_samples,
                        test_images,
                        test_labels,
                        test_images,
                        test_labels,
                        correct_prediction,
                        FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, i)

                    info = ('%s: step %d, test_accuracy = %.6f') % (
                        datetime.now(), i, test_accuracy)
                    print(info)
                    write_detailed_info(info)

                    # --- Save model parameters ---
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

                    # check accuracy
                    if test_accuracy >= FLAGS.baseline_accuracy - FLAGS.drop_rate:
                        print(
                            '\nStop training becuase accuracy threshold (%f) reached!\n\n'
                            % (FLAGS.baseline_accuracy - FLAGS.drop_rate))
                        write_detailed_info(
                            'Stop training becuase accuracy threshold (%f) reached!'
                            % (FLAGS.baseline_accuracy - FLAGS.drop_rate))
                        break

            coord.request_stop()
            coord.join(threads)
            # Make sure every buffered summary reaches the event file.
            train_writer.close()

            total_time = time.time() - tic
            # The loop may end before any plain training step ran (e.g. the
            # accuracy threshold is already met at the first evaluation);
            # avoid dividing by zero in that case.
            if train_only_cnt > 0:
                train_speed = train_time * 1.0 / train_only_cnt
            else:
                train_speed = 0.0
            # Extrapolated time for max_number_of_steps plain training steps.
            train_time = train_speed * FLAGS.max_number_of_steps
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            write_detailed_info(info)
Esempio n. 3
0
def main(_):
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # initialize constants
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    kp_options = sorted([float(x) for x in FLAGS.kept_percentages.split(',')])
    num_blocks = int(len(valid_layer_names) / FLAGS.block_size)
    if len(valid_layer_names) % FLAGS.block_size != 0:
        print('ERROR: len(valid_layer_names)%FLAGS.block_size!=0')
        return

    # if FLAGS.block_id >= num_blocks:
    #     print('ERROR: block_id=%d should be smaller than the number of blocks=%d' %(FLAGS.block_id, num_blocks))
    #     return
    print('HG: kp_options', kp_options)
    print('HG: block size:', FLAGS.block_size)
    print('HG: number of blocks:',
          num_blocks)  #, ', block_id:', FLAGS.block_id)
    if len(kp_options) > 1:
        print('ERROR: only support kp options = 1')
        return

    # prepare file system
    # block_config_str = '_'.join(map(str, block_config))
    if FLAGS.last_conv_pruned:
        foldername = 'last_conv_pruned'
    else:
        foldername = 'last_conv_unpruned'
    results_dir = os.path.join(FLAGS.train_dir, foldername,
                               'kp' + str(FLAGS.kept_percentages))
    train_dir = os.path.join(results_dir, 'train')
    print('HG: train_dir:', train_dir)
    if not (FLAGS.continue_training and tf.train.latest_checkpoint(train_dir)):
        prepare_file_system(train_dir)

    def write_log_info(info):
        with open(os.path.join(FLAGS.train_dir, 'log.txt'), 'a') as f:
            f.write(info + '\n')

    def write_detailed_info(info):
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'kp_options:' + str(kp_options) + '\n'
    log_info = info + '\n'
    write_detailed_info(info)

    # set prune info
    prune_info = {}
    block_layer_names_list = []
    pruned_layer_names_list = []
    block_config_list = []

    prune_scopes = []
    shorten_scopes = []
    network_config = []

    for block_id in xrange(num_blocks):
        #------------------------
        # get block layer names
        #------------------------
        start_layer_id = FLAGS.block_size * block_id
        end_layer_id = start_layer_id + FLAGS.block_size
        block_layer_names = valid_layer_names[start_layer_id:end_layer_id]
        is_last_block = valid_layer_names[-1] in block_layer_names
        # print('HG: block_layer_names:', block_layer_names)
        # print('HG: is_last_block:', is_last_block)
        block_layer_names_list.append(block_layer_names)

        #------------------------
        # get pruned layer names so that the pruned block can fit into the test network.
        #------------------------
        pruned_layer_names = valid_layer_names[start_layer_id:end_layer_id]
        config_length = FLAGS.block_size
        # note that last layer cannot be pruned.
        if is_last_block:
            pruned_layer_names.remove(valid_layer_names[-1])
            config_length -= 1

        # if the block is not the first block, prune also the layer before the block
        if block_id != 0:
            pruned_layer_names = [valid_layer_names[start_layer_id - 1]
                                  ] + pruned_layer_names
            config_length += 1
        # print('HG: pruned_layer_names:', pruned_layer_names)
        pruned_layer_names_list.append(pruned_layer_names)

        #------------------------
        # get the pruning configuration for the block
        #------------------------
        # given N=#options, m=block_size. first block has #configs=N^m, other blocks have #configs=N^{m+1},
        block_configurations = list(
            itertools.product(kp_options, repeat=config_length))
        # print('HG: number of block variants:', len(block_configurations))
        if FLAGS.block_config_id >= len(block_configurations):
            print(
                'ERROR: block_config_id=%d should be smaller than number of block variants=%d'
                % (FLAGS.block_config_id, len(block_configurations)))
            return
        block_config = list(block_configurations[FLAGS.block_config_id])
        if not FLAGS.last_conv_pruned:
            # fix the input layer to the block to be 1.0
            if block_id != 0:
                block_config[0] = 1.0
            # fix the last conv layer in the block to be 1.0
            if not is_last_block:
                block_config[-1] = 1.0
        # block_config=[0.5, 0.5, 0.5]
        # print('HG: block confiugrations:', block_config)
        block_config_list.append(block_config)

        #------------------------
        # prepare prune_info with the config
        #------------------------
        for i in xrange(len(block_config)):
            layer_name = pruned_layer_names[i]
            if layer_name in prune_info:
                if prune_info[layer_name]['kp'] != block_config[i]:
                    print(
                        "ERROR: layer name already in prune info but prune_info[layer_name]['kp'] != block_config[i]"
                    )
                    return
            prune_info[layer_name] = {'kp': block_config[i]}

        #------------------------
        # get pruned scopes and shorten scopes
        #------------------------
        for layer_name, layer_config in zip(pruned_layer_names, block_config):
            prune_scope = layer_name_to_prune_scope(layer_name,
                                                    net_name_scope_checkpoint)
            if prune_scope not in prune_scopes:
                prune_scopes.append(prune_scope)
                shorten_scopes.append(
                    layer_name_to_shorten_scope(layer_name,
                                                net_name_scope_checkpoint))
                network_config.append(layer_config)

    #------------------------
    # get variables init value and pruned filter indexes
    #------------------------
    print('HG: prune_scopes', len(prune_scopes), prune_scopes)
    print('HG: network_config', network_config)
    variables_init_value, pruned_filter_indexes = get_init_values_for_pruned_layers(
        prune_scopes, shorten_scopes, network_config)

    # print('HG: prune_info')
    # pprint(prune_info)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network#
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        _, end_points = network_fn(images, is_training=False)
        # for item in end_points.iteritems():
        #     print(item)
        # return

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        # #########################################
        # # Configure the optimization procedure. #
        # #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        ####################
        # Define the model #
        ####################
        # prune_info = {layer_name_1: {'kp': 0.3, 'inputs': inputs}, layer_name_2:{'kp':0.5}}
        # checkpoint_prune_info = {pruned_layer_names[-1]:{'kp':block_config[-1]}}
        # _, end_points = network_fn_pruned(images,
        #                                 prune_info = checkpoint_prune_info,
        #                                 is_training=True,
        #                                 is_local_train=False,
        #                                 reuse_variables=False,
        #                                 scope = net_name_scope_checkpoint)
        # --------------------------
        # set inputs for prune_info
        # --------------------------
        print_list('pruned_layer_names_list', pruned_layer_names_list)
        for block_id in xrange(num_blocks):
            block_layer_names = block_layer_names_list[block_id]
            pruned_layer_names = pruned_layer_names_list[block_id]

            if block_id == 0:
                block_inputs = images
            else:
                # the original inputs may have a different last-dimension size than this block requires;
                # use pruned_filter_indexes to drop the pruned indexes from original_inputs
                with tf.name_scope('inputs_selector_' + str(block_id)):
                    # get pruned filter indexes

                    prune_scope = net_name_scope_checkpoint + '/' + pruned_layer_names[
                        0]
                    filter_indexes = pruned_filter_indexes[prune_scope]

                    # get original inputs
                    inputs_layer_id = all_layer_names.index(
                        block_layer_names[0]) - 1
                    inputs_scope = net_name_scope_checkpoint + '/' + all_layer_names[
                        inputs_layer_id]
                    original_inputs = end_points[inputs_scope]

                    # downsample original inputs
                    num_dim = original_inputs.get_shape().as_list()[-1]
                    block_inputs = tf.stack([
                        original_inputs[:, :, :, i]
                        for i in xrange(num_dim) if i not in filter_indexes
                    ],
                                            axis=-1)
                    print('HG: inputs_scope:', inputs_scope,
                          original_inputs.get_shape().as_list(), ' to ',
                          block_inputs.get_shape().as_list())
            # set inputs for this block
            prune_info[block_layer_names[0]]['inputs'] = block_inputs

        print('HG: prune_info:')
        pprint(prune_info)

        # generate the pruned network for training
        _, end_points_pruned = network_fn_pruned(images,
                                                 prune_info=prune_info,
                                                 is_training=True,
                                                 is_local_train=True,
                                                 reuse_variables=False,
                                                 scope=net_name_scope_pruned)
        # generate the pruned network for testing
        logits, _ = network_fn_pruned(test_images,
                                      prune_info=prune_info,
                                      is_training=False,
                                      is_local_train=False,
                                      reuse_variables=True,
                                      scope=net_name_scope_pruned)
        # add correct prediction to the testing network
        correct_prediction = add_correct_prediction(logits, test_labels)

        #############################
        # Specify the loss functions #
        #############################
        total_losses = []
        train_tensors = []
        for block_id in xrange(num_blocks):
            print('\n\nblock_id', block_id)
            block_layer_names = block_layer_names_list[block_id]
            pruned_layer_names = pruned_layer_names_list[block_id]
            is_last_block = (block_id == num_blocks - 1)

            outputs_scope = net_name_scope_checkpoint + '/' + block_layer_names[
                -1]
            outputs_scope_pruned = net_name_scope_pruned + '/' + block_layer_names[
                -1]
            print('HG: outputs_scope:', outputs_scope)

            # add reconstruction loss
            collection_name = 'subgraph_losses_' + str(block_id)
            if outputs_scope not in end_points:
                raise ValueError(
                    'end_points does not contain the outputs_scope: %s',
                    outputs_scope)
            outputs = end_points[outputs_scope]

            if outputs_scope_pruned not in end_points_pruned:
                raise ValueError(
                    'end_points_pruned does not contain the outputs_scope_pruned: %s',
                    outputs_scope_pruned)
            outputs_pruned = end_points_pruned[outputs_scope_pruned]

            # TODO: cannot use l2_loss directly since outputs and outputs_pruned do not have the same dimension.
            if is_last_block:
                outputs_gnd = outputs
            else:
                with tf.name_scope('output_selector_' + str(block_id)):
                    filter_indexes = pruned_filter_indexes[outputs_scope]
                    num_dim = outputs.get_shape().as_list()[-1]
                    outputs_gnd = tf.stack([
                        outputs[:, :, :, i]
                        for i in xrange(num_dim) if i not in filter_indexes
                    ],
                                           axis=-1)
                    print('HG: ouputs selector:',
                          outputs.get_shape().as_list(), ' to ',
                          outputs_gnd.get_shape().as_list())
            l2_loss = add_l2_loss(outputs_gnd,
                                  outputs_pruned,
                                  add_to_collection=True,
                                  collection_name=collection_name)

            # get regularization loss
            train_scopes = [
                net_name_scope_pruned + '/' + item
                for item in block_layer_names
            ]
            print_list('train_scopes', train_scopes)
            regularization_losses = get_regularization_losses_within_scopes(
                train_scopes,
                add_to_collection=True,
                collection_name=collection_name)
            print_list('regularization_losses', regularization_losses)

            # total loss and its summary
            total_loss = tf.add_n(tf.get_collection(collection_name),
                                  name='total_loss')
            for l in tf.get_collection(collection_name) + [total_loss]:
                tf.summary.scalar(l.op.name + '/summary', l)
            total_losses.append(total_loss)

            #############################
            # Add train operation       #
            #############################

            variables_to_train = get_trainable_variables_within_scopes(
                train_scopes)
            print_list("variables_to_train", variables_to_train)
            train_op = add_train_op(optimizer,
                                    total_loss,
                                    global_step,
                                    var_list=variables_to_train)
            with tf.control_dependencies([train_op]):
                train_tensor = tf.identity(total_loss, name='train_op')
            train_tensors.append(train_tensor)

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))
        print_list(
            'model_variables but not trainable variables',
            list(
                set(tf.model_variables()).difference(
                    tf.trainable_variables())))
        print_list(
            'global_variables but not model variables',
            list(set(tf.global_variables()).difference(tf.model_variables())))
        print(
            "HG: trainable_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_trainable_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: trainable_variables from " + net_name_scope_pruned + "=",
            len(
                get_trainable_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_pruned + "=",
            len(
                get_model_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_global_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_pruned + "=",
            len(
                get_global_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)

        with tf.Session(config=sess_config) as sess:
            ###########################
            # prepare for filewritter #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            if not (FLAGS.continue_training
                    and tf.train.latest_checkpoint(train_dir)):
                ###########################################
                # Restore original model variable values. #
                ###########################################

                variables_to_restore = get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])
                print_list("restore model variables for original",
                           variables_to_restore)
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

                #################################################
                # Init  pruned networks  with  well-trained model #
                #################################################
                variables_to_reinit = get_model_variables_within_scopes(
                    [net_name_scope_pruned + '/'])
                print_list("init pruned model variables for pruned network",
                           variables_to_reinit)
                assign_ops = []
                for v in variables_to_reinit:
                    key = re.sub(net_name_scope_pruned,
                                 net_name_scope_checkpoint, v.op.name)
                    if key in variables_init_value:
                        value = variables_init_value.get(key)
                        # print(key, value)
                        assign_ops.append(
                            tf.assign(v,
                                      tf.convert_to_tensor(value),
                                      validate_shape=True))
                        # v.set_shape(value.shape)
                    else:
                        raise ValueError(
                            "Key not in variables_init_value, key=", key)
                assign_op = tf.group(*assign_ops)
                sess.run(assign_op)

            else:
                # restore all variables from checkpoint
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            #################################################
            # Init uninitialized global variables. #
            #################################################

            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init uninitialized_variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)' \
                    %(init_global_step_value, FLAGS.max_number_of_steps))
                return

            ###########################
            # Record CPU usage  #
            ###########################
            # mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            # os.system("mpstat -P ALL 1 > " + mpstat_output_filename + " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('HG: # of threads=', len(threads))

            # saver for models
            if FLAGS.max_to_keep <= 0:
                max_to_keep = int(2 * FLAGS.max_number_of_steps /
                                  FLAGS.evaluate_every_n_steps)
            else:
                max_to_keep = FLAGS.max_to_keep
            saver = tf.train.Saver(max_to_keep=max_to_keep)

            train_time = 0  # the amount of time spending on sgd training only.
            duration = 0  # used to estimate the training speed
            train_only_cnt = 0  # used to calculate the true training time.
            duration_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):

                # run optional meta data, or summary, while run train tensor
                if i > init_global_step_value:  # FLAGS.max_number_of_steps:
                    # run metadata
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_values = sess.run(train_tensors,
                                               options=run_options,
                                               run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)

                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format(
                        )
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)

                    # record summary
                    elif i % FLAGS.summary_every_n_steps == 0:
                        results = sess.run([summary_op] + train_tensors)
                        train_summary, loss_values = results[0], results[-1]
                        train_writer.add_summary(train_summary, i)
                        # print('HG: train with summary')
                        # only run train op
                    else:
                        start_time = time.time()
                        loss_values = sess.run(train_tensors)
                        train_only_cnt += 1
                        duration_cnt += 1
                        train_time += time.time() - start_time
                        duration += time.time() - start_time

                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        # record speed
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        info = (
                            '%s: step %d, loss = %s (%.1f examples/sec; %.3f sec/batch)'
                        ) % (datetime.now(), i, str(loss_values),
                             examples_per_sec, sec_per_batch)
                        print(info)
                        duration = 0
                        duration_cnt = 0

                        write_detailed_info(info)
                else:
                    # run only total loss when i=0
                    results = sess.run(
                        [summary_op] +
                        total_losses)  #loss_value = sess.run(total_loss)
                    train_summary, loss_values = results[0], results[-1]
                    train_writer.add_summary(train_summary, i)
                    format_str = ('%s: step %d, loss = %s')
                    print(format_str % (datetime.now(), i, str(loss_values)))
                    info = format_str % (datetime.now(), i, str(loss_values))
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:

                    test_accuracy, run_metadata = evaluate_accuracy(
                        sess,
                        coord,
                        test_dataset.num_samples,
                        test_images,
                        test_labels,
                        test_images,
                        test_labels,
                        correct_prediction,
                        FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, i)
                    # if run_meta:
                    # eval_writer.add_run_metadata(run_metadata, 'step%d-eval' % i)
                    info = ('%s: step %d, test_accuracy = %s') % (
                        datetime.now(), i, str(test_accuracy))
                    print(info)
                    if i == init_global_step_value or is_last_step:
                        # write_log_info(info)
                        log_info += info + '\n'
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    # saver = tf.train.Saver()
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic

            train_speed = train_time * 1.0 / train_only_cnt
            train_time = train_speed * (FLAGS.max_number_of_steps)
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            log_info += info + '\n\n'
            write_log_info(log_info)
            write_detailed_info(info)
Esempio n. 4
0
def main(_):
    tic = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')
    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    block_names = valid_block_names
    kept_percentages_dict = get_kept_percentages_dict_from_path(
        FLAGS.checkpoint_path)
    kept_percentages = sorted(map(float, FLAGS.kept_percentages.split(',')))

    # check networks with the kps are pre-trained.
    for kp in kept_percentages:
        if kp not in kept_percentages_dict:
            raise Error('kept_percentage=' + str(kp) + ' not in folder:' +
                        FLAGS.checkpoint_path)

    num_options = len(kept_percentages)
    num_units = len(block_names)
    print('num_options=%d, num_blocks=%d' % (num_options, num_units))
    print('HG: total number of configurations=%d' % (num_options**num_units))

    if FLAGS.configuration_type == 'sample':
        configs = get_sampled_configurations(num_units, num_options,
                                             FLAGS.total_num_configurations)
    elif FLAGS.configuration_type == 'special':
        configs = get_special_configurations(num_units, num_options)
    num_configurations = len(configs)

    #Getting MPI rank integer
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank >= num_configurations:
        print("ERROR: rank(%d) > num_configurations(%d)" %
              (rank, num_configurations))
        return
    FLAGS.configuration_index = FLAGS.start_configuration_index + rank
    config = configs[FLAGS.configuration_index]
    print('HG: kept_percentages=%s, start_config_index=%d, num_configs=%d, rank=%d, config_index=%d' \
           %(str(kept_percentages), FLAGS.start_configuration_index, num_configurations, rank, FLAGS.configuration_index))

    # prepare for training with the specific config
    kept_percentage = config_to_kept_percentage_sequence(
        config, block_names, kept_percentages)
    prune_info = kept_percentage_sequence_to_prune_info(
        kept_percentage, block_names)
    print('HG: prune_info:')
    pprint(prune_info)

    # prepare file system
    results_dir = os.path.join(
        FLAGS.train_dir, "id" +
        str(FLAGS.configuration_index))  #+'_'+str(FLAGS.max_number_of_steps))
    train_dir = os.path.join(results_dir, 'train')

    if (not FLAGS.continue_training) or (
            not tf.train.latest_checkpoint(train_dir)):
        print('Start a new training')
        prepare_file_system(train_dir)
    else:
        print('Continue training')

    def write_detailed_info(info):
        """Append ``info`` followed by a newline to train_details.txt under train_dir."""
        details_path = os.path.join(train_dir, 'train_details.txt')
        with open(details_path, 'a') as details_file:
            details_file.write(info + '\n')

    info = 'train_dir: ' + train_dir + '\n'
    info += 'options:' + str(kept_percentages) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network#
        ######################
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        ####################
        # Define the model #
        ####################
        logits_train, _ = network_fn_pruned(images,
                                            prune_info=prune_info,
                                            is_training=True,
                                            is_local_train=False,
                                            reuse_variables=False,
                                            scope=net_name_scope_pruned)

        logits_eval, _ = network_fn_pruned(test_images,
                                           prune_info=prune_info,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=True,
                                           scope=net_name_scope_pruned)
        cross_entropy = add_cross_entropy(logits_train, labels)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        #############################
        # Specify the loss functions #
        #############################
        collection_name = 'subgraph_losses'
        tf.add_to_collection(collection_name, cross_entropy)
        # get regularization loss
        regularization_losses = get_regularization_losses_within_scopes()
        print_list('regularization_losses', regularization_losses)
        # total loss and its summary
        total_loss = tf.add_n(tf.get_collection(collection_name),
                              name='total_loss')
        for l in tf.get_collection(collection_name) + [total_loss]:
            tf.summary.scalar(l.op.name + '/summary', l)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        #############################
        # Add train operation       #
        #############################
        variables_to_train = get_trainable_variables_within_scopes()
        train_op = add_train_op(optimizer,
                                total_loss,
                                global_step,
                                var_list=variables_to_train)
        print_list("variables_to_train", variables_to_train)

        # Gather update_ops: the updates for the batch_norm variables created by network_fn_pruned.
        update_ops = get_update_ops_within_scopes()
        print_list("update_ops", update_ops)

        update_ops.append(train_op)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))
        # print_list('model_variables but not trainable variables', list(set(tf.model_variables()).difference(tf.trainable_variables())))
        # print_list('global_variables but not model variables', list(set(tf.global_variables()).difference(tf.model_variables())))

        # get train scopes for each kept_percentage
        block_names_dict = {}
        for block_name, block_kept_percentage in zip(block_names,
                                                     kept_percentage):
            if block_kept_percentage not in block_names_dict:
                block_names_dict[block_kept_percentage] = []
            block_names_dict[block_kept_percentage].append(block_name)

        #print_list("train_scopes", train_scopes)
        print('HG: block_names_dict:')
        pprint(block_names_dict)

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # prepare for filewritter #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            # if restart the training or there is no checkpoint in the train_dir
            if (not FLAGS.continue_training) or (
                    not tf.train.latest_checkpoint(train_dir)):
                #################################################
                # Restore  pruned model variable values. #
                #################################################
                all_variables_to_train = []
                for block_kept_percentage, block_name in block_names_dict.items(
                ):
                    print('HG: kept_percentage', block_kept_percentage)
                    checkpoint_path = os.path.join(
                        FLAGS.checkpoint_path,
                        kept_percentages_dict[block_kept_percentage][0],
                        'train')
                    #    'model.ckpt-'+str(FLAGS.local_train_steps))

                    variables_to_train = {
                        re.sub(
                            net_name_scope_pruned, net_name_scope_pruned +
                            "_p" + str(block_kept_percentage), v.op.name): v
                        for v in get_model_variables_with_block_names(
                            net_name_scope_pruned, block_name)
                    }
                    print_list("restore pruned model variables",
                               variables_to_train.values())
                    load_checkpoint(sess,
                                    checkpoint_path,
                                    var_list=variables_to_train)
                    all_variables_to_train.extend(variables_to_train.values())

                #################################################
                # Restore original model variable values. #
                #################################################
                variables_to_restore = {
                    re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                           v.op.name): v
                    for v in get_model_variables_within_scopes()
                    if v not in set(all_variables_to_train)
                }
                print_list("restore original model variables",
                           variables_to_restore.values())
                load_checkpoint(sess,
                                checkpoint_path,
                                var_list=variables_to_restore)

            else:
                ###########################################
                ## Restore all variables from checkpoint ##
                ###########################################
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            #################################################
            # Init uninitialized global variables. #
            #################################################
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init unitialized variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)' \
                    %(init_global_step_value, FLAGS.max_number_of_steps))
                return

            ###########################
            # Record CPU usage  #
            ###########################
            mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            os.system("mpstat -P ALL 1 > " + mpstat_output_filename +
                      " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
            print('HG: # of threads=', len(threads))

            duration = 0
            duration_cnt = 0
            train_time = 0
            train_only_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                #train_step = i+FLAGS.local_train_steps
                train_step = i
                # run optional meta data, or summary, while run train tensor
                if i > init_global_step_value:
                    #if i < FLAGS.max_number_of_steps:

                    # run metadata and train
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_value = sess.run(train_tensor,
                                              options=run_options,
                                              run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)

                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format(
                        )
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)

                    # record summary and train
                    elif i % FLAGS.summary_every_n_steps == 0:
                        train_summary, loss_value = sess.run(
                            [summary_op, train_tensor])
                        train_writer.add_summary(train_summary, train_step)

                    # train only
                    else:
                        start_time = time.time()
                        loss_value = sess.run(train_tensor)
                        train_only_cnt += 1
                        train_time += time.time() - start_time
                        duration_cnt += 1
                        duration += time.time() - start_time

                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, train_step)
                        format_str = (
                            '%s: step %d, loss = %.3f (%.1f examples/sec; %.3f sec/batch)'
                        )
                        print(format_str % (datetime.now(), i, loss_value,
                                            examples_per_sec, sec_per_batch))
                        duration = 0
                        duration_cnt = 0

                        info = format_str % (datetime.now(), i, loss_value,
                                             examples_per_sec, sec_per_batch)
                        write_detailed_info(info)
                else:
                    # run only total loss when i=0
                    train_summary, loss_value = sess.run(
                        [summary_op,
                         total_loss])  #loss_value = sess.run(total_loss)
                    train_writer.add_summary(train_summary, train_step)
                    format_str = ('%s: step %d, loss = %.3f')
                    print(format_str % (datetime.now(), i, loss_value))
                    info = format_str % (datetime.now(), i, loss_value)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:

                    test_accuracy, run_metadata = evaluate_accuracy(
                        sess,
                        coord,
                        test_dataset.num_samples,
                        test_images,
                        test_labels,
                        test_images,
                        test_labels,
                        correct_prediction,
                        FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, train_step)

                    info = ('%s: step %d, test_accuracy = %.6f') % (
                        datetime.now(), train_step, test_accuracy)
                    print(info)
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic

            train_speed = train_time / train_only_cnt
            train_time = (FLAGS.max_number_of_steps) * train_speed
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f \n" % (
                train_time / 60.0, total_time / 60.0)

            print(info)
            write_detailed_info(info)
def main(_):
    """Train one pruned network per kept percentage, all in the same graph.

    The original network is built once (inference mode) and, for every value
    in ``FLAGS.kept_percentages``, a pruned network is built under the scope
    ``FLAGS.net_name_scope_pruned + '_p' + str(kept_percentage)``.  For every
    (kept percentage, block name) pair a separate loss is created from
    ``add_l2_loss`` on the original/pruned block outputs plus the block's
    regularization losses, and a separate train op updates only that block's
    trainable variables.  All train tensors are run together at each step.

    On a fresh start the original variables are restored from
    ``FLAGS.checkpoint_path``, the pruned variables are assigned the values
    returned by ``get_pruned_kernel_matrix`` and the remaining pruned-scope
    model variables are restored from the same checkpoint under renamed keys.
    When continuing, every variable is restored from ``train_dir``.

    Files written: ``train_details.txt``, ``cpu-usage.log``,
    ``timeline_<step>.json``, summaries and ``model.ckpt-<step>`` under
    ``train_dir``; ``log.txt`` under ``FLAGS.train_dir``.

    Args:
        _: Unused (presumably the argv list handed over by the app runner --
            TODO confirm against the caller).

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty, or if a variable chosen
            for re-initialization has no entry in the pruned init values.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # init
    print('HG: Train pruned blocks for all valid layers concurrently')
    block_names = valid_block_names
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    kept_percentages = sorted(
        [float(x) for x in FLAGS.kept_percentages.split(',')])
    print_list('kept_percentages', kept_percentages)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir, 'kp' + FLAGS.kept_percentages)
    train_dir = os.path.join(results_dir, 'train')
    # Decide once whether this is a fresh run; the same decision drives both
    # the directory preparation here and the variable restoration below.
    start_new_training = (not FLAGS.continue_training) or (
        not tf.train.latest_checkpoint(train_dir))
    if start_new_training:
        print('Start a new training')
        prepare_file_system(train_dir)
    else:
        print('Continue training')

    def write_log_info(info):
        """Append `info` to the run-level log under FLAGS.train_dir."""
        with open(os.path.join(FLAGS.train_dir, 'log.txt'), 'a') as f:
            f.write(info + '\n')

    def write_detailed_info(info):
        """Append `info` to the per-run details file under train_dir."""
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir
    # log_info is accumulated and written to log.txt once, at the very end.
    log_info = info + '\n'
    write_detailed_info(info)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        # The original network is only used as a fixed reference, hence
        # is_training=False.
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        _, end_points = network_fn(images, is_training=False)

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        ##########################################
        # Configure the optimization procedure.  #
        ##########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        ####################
        # Define the model #
        ####################
        # each kept_percentage corresponds to a pruned network.
        train_tensors = []
        total_losses = []
        pruned_net_name_scopes = []
        correct_predictions = []
        prune_infos = []
        for kept_percentage in kept_percentages:

            prune_info = kept_percentage_sequence_to_prune_info(
                kept_percentage, block_names)
            set_prune_info_inputs(prune_info, end_points)
            prune_infos.append(prune_info)

            #  the pruned network scope
            net_name_scope_pruned = FLAGS.net_name_scope_pruned + '_p' + str(
                kept_percentage)
            pruned_net_name_scopes.append(net_name_scope_pruned)

            # generate the pruned network for training
            _, end_points_pruned = network_fn_pruned(
                images,
                prune_info=prune_info,
                is_training=True,
                is_local_train=True,
                reuse_variables=False,
                scope=net_name_scope_pruned)
            # generate the pruned network for testing (shares the variables
            # created just above via reuse_variables=True)
            logits, _ = network_fn_pruned(test_images,
                                          prune_info=prune_info,
                                          is_training=False,
                                          is_local_train=False,
                                          reuse_variables=True,
                                          scope=net_name_scope_pruned)
            # add correct prediction to the testing network
            correct_prediction = add_correct_prediction(logits, test_labels)
            correct_predictions.append(correct_prediction)

            ##############################
            # Specify the loss functions #
            ##############################
            for i, block_name in enumerate(block_names):
                print('HG: i=%d, block_name=%s' % (i, block_name))
                # Every (kept_percentage, block) pair gets its own loss
                # collection so its losses can be summed independently.
                appendix = '_p' + str(kept_percentage) + '_' + str(i)
                collection_name = 'subgraph_losses' + appendix

                # add l2 losses
                outputs = end_points[block_name]
                outputs_pruned = end_points_pruned[block_name]
                l2_loss = add_l2_loss(outputs,
                                      outputs_pruned,
                                      add_to_collection=True,
                                      collection_name=collection_name)

                # get regularization loss
                regularization_losses = get_regularization_losses_with_block_names(
                    net_name_scope_pruned,
                    block_name,
                    add_to_collection=True,
                    collection_name=collection_name)
                print_list('regularization_losses', regularization_losses)

                # total loss and its summary
                total_loss = tf.add_n(tf.get_collection(collection_name),
                                      name='total_loss')
                for loss in tf.get_collection(collection_name) + [total_loss]:
                    tf.summary.scalar(loss.op.name + appendix + '/summary',
                                      loss)
                total_losses.append(total_loss)

                #############################
                # Add train operation       #
                #############################
                variables_to_train = get_trainable_variables_with_block_names(
                    net_name_scope_pruned, block_name)
                print_list("variables_to_train", variables_to_train)

                # add train_op
                # Only the very first train op is handed the real global
                # step; every other one gets its own private counter so the
                # shared step is not passed to more than one train op.
                if i == 0 and kept_percentage == kept_percentages[0]:
                    global_step_tmp = global_step
                else:
                    global_step_tmp = tf.Variable(0,
                                                  trainable=False,
                                                  name='global_step' +
                                                  appendix)
                train_op = add_train_op(optimizer,
                                        total_loss,
                                        global_step_tmp,
                                        var_list=variables_to_train)

                # Gather update_ops: the updates for the batch_norm variables created by network_fn_pruned.
                update_ops = get_update_ops_with_block_names(
                    net_name_scope_pruned, block_name)
                print_list("update_ops", update_ops)

                update_ops.append(train_op)
                update_op = tf.group(*update_ops)
                # Fetching train_tensor forces the update ops (and thus the
                # train op) to run and yields the loss value.
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss,
                                               name='train_op' + appendix)
                    train_tensors.append(train_tensor)

        # add summary op
        summary_op = tf.summary.merge_all()

        # Diagnostics.  NOTE: net_name_scope_pruned still holds the scope of
        # the LAST kept percentage built in the loop above.
        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))
        print_list(
            'model_variables but not trainable variables',
            list(
                set(tf.model_variables()).difference(
                    tf.trainable_variables())))
        print_list(
            'global_variables but not model variables',
            list(set(tf.global_variables()).difference(tf.model_variables())))
        print(
            "HG: trainable_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_trainable_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: trainable_variables from " + net_name_scope_pruned + "=",
            len(
                get_trainable_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_pruned + "=",
            len(
                get_model_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_checkpoint + "=",
            len(
                get_global_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_pruned + "=",
            len(
                get_global_variables_within_scopes(
                    [net_name_scope_pruned + '/'])))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)

        with tf.Session(config=sess_config) as sess:
            ###########################
            # prepare for filewritter #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            if start_new_training:
                ###########################################
                # Restore original model variable values. #
                ###########################################
                variables_to_restore = get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])
                print_list("restore model variables for original",
                           variables_to_restore)
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

                ###################################################
                # Init  pruned networks  with  well-trained model #
                ###################################################
                for net_name_scope_pruned, prune_info in zip(
                        pruned_net_name_scopes, prune_infos):
                    print('net_name_scope_pruned=', net_name_scope_pruned)

                    ## init pruned variables .
                    variables_init_value = get_pruned_kernel_matrix(
                        sess, prune_info, net_name_scope_checkpoint)
                    # Scope names are literal text, not regular expressions
                    # (the pruned scope contains '.' from the kept
                    # percentage), so escape them before substituting.
                    reinit_scopes = [
                        re.sub(re.escape(net_name_scope_checkpoint),
                               net_name_scope_pruned, name)
                        for name in variables_init_value.keys()
                    ]
                    variables_to_reinit = get_model_variables_within_scopes(
                        reinit_scopes)
                    print_list("Initialize pruned variables",
                               variables_to_reinit)

                    assign_ops = []
                    for v in variables_to_reinit:
                        # Map the pruned variable name back to the name used
                        # as key in variables_init_value.
                        key = re.sub(re.escape(net_name_scope_pruned),
                                     net_name_scope_checkpoint, v.op.name)
                        if key in variables_init_value:
                            value = variables_init_value[key]
                            assign_ops.append(
                                tf.assign(v,
                                          tf.convert_to_tensor(value),
                                          validate_shape=True))
                        else:
                            raise ValueError(
                                "Key not in variables_init_value, key=", key)
                    assign_op = tf.group(*assign_ops)
                    sess.run(assign_op)

                    #################################################
                    # Restore unchanged model variable. #
                    #################################################
                    # Checkpoint key (original scope) -> pruned-scope variable
                    # for everything that was not re-initialized above.
                    variables_to_restore = {
                        re.sub(re.escape(net_name_scope_pruned),
                               net_name_scope_checkpoint, v.op.name): v
                        for v in get_model_variables_within_scopes(
                            [net_name_scope_pruned + '/'])
                        if v not in variables_to_reinit
                    }
                    print_list(
                        "restore model variables for " + net_name_scope_pruned,
                        variables_to_restore.values())
                    load_checkpoint(sess,
                                    FLAGS.checkpoint_path,
                                    var_list=variables_to_restore)
            else:
                # restore all variables from checkpoint
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            #################################################
            # init unitialized global variable. #
            #################################################
            uninitialized_variables = [
                x.decode('utf-8')
                for x in sess.run(tf.report_uninitialized_variables())
            ]
            print_list('uninitialized variables', uninitialized_variables)
            # Build the lookup set once instead of once per variable.
            uninitialized_names = set(uninitialized_variables)
            variables_to_init = [
                v for v in tf.global_variables()
                if v.name.split(':')[0] in uninitialized_names
            ]
            print_list("variables_to_init", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)' \
                    %(init_global_step_value, FLAGS.max_number_of_steps))
                # Flush the graph event queued when the writer was created.
                train_writer.close()
                return

            ###########################
            # Record CPU usage  #
            ###########################
            # mpstat is started in the background through the shell; this
            # function does not stop it afterwards.
            mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            os.system("mpstat -P ALL 1 > " + mpstat_output_filename +
                      " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('HG: # of threads=', len(threads))

            # saver for models
            if FLAGS.max_to_keep <= 0:
                max_to_keep = int(2 * FLAGS.max_number_of_steps /
                                  FLAGS.evaluate_every_n_steps)
            else:
                max_to_keep = FLAGS.max_to_keep
            saver = tf.train.Saver(max_to_keep=max_to_keep)

            train_time = 0  # the amount of time spending on sgd training only.
            duration = 0  # used to estimate the training speed
            train_only_cnt = 0  # used to calculate the true training time.
            duration_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):

                # run optional meta data, or summary, while run train tensor
                if i > init_global_step_value:

                    # run metadata
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_values = sess.run(train_tensors,
                                               options=run_options,
                                               run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)

                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format(
                        )
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)

                    # record summary
                    elif i % FLAGS.summary_every_n_steps == 0:
                        results = sess.run([summary_op] + train_tensors)
                        train_summary, loss_values = results[0], results[1:]
                        train_writer.add_summary(train_summary, i)

                    # only run train op; these are the only steps that are
                    # timed, so tracing/summary overhead is excluded.
                    else:
                        start_time = time.time()
                        loss_values = sess.run(train_tensors)
                        train_only_cnt += 1
                        duration_cnt += 1
                        train_time += time.time() - start_time
                        duration += time.time() - start_time

                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        # record speed
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        info = (
                            '%s: step %d, loss = %s (%.1f examples/sec; %.3f sec/batch)'
                        ) % (datetime.now(), i, str(loss_values),
                             examples_per_sec, sec_per_batch)
                        print(info)
                        duration = 0
                        duration_cnt = 0

                        write_detailed_info(info)
                else:
                    # First iteration (i == init_global_step_value): only
                    # evaluate the losses, without running any train op.
                    results = sess.run([summary_op] + total_losses)
                    train_summary, loss_values = results[0], results[1:]
                    train_writer.add_summary(train_summary, i)
                    # Format once so stdout and the details file carry the
                    # same timestamp.
                    info = '%s: step %d, loss = %s' % (datetime.now(), i,
                                                       str(loss_values))
                    print(info)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:

                    # test accuracy; each kept_percentage corresponds to a pruned network, and thus an accuracy.
                    test_accuracies = []
                    for kept_percentage, correct_prediction in zip(
                            kept_percentages, correct_predictions):
                        appendix = '_p' + str(kept_percentage)
                        test_accuracy, _eval_metadata = evaluate_accuracy(
                            sess,
                            coord,
                            test_dataset.num_samples,
                            test_images,
                            test_labels,
                            test_images,
                            test_labels,
                            correct_prediction,
                            FLAGS.test_batch_size,
                            run_meta=False)
                        summary = tf.Summary()
                        summary.value.add(tag='accuracy' + appendix,
                                          simple_value=test_accuracy)
                        train_writer.add_summary(summary, i)
                        test_accuracies.append(
                            (kept_percentage, test_accuracy))
                    acc_str = '[' + ', '.join([
                        '(%s, %.6f)' % (str(kp), acc)
                        for kp, acc in test_accuracies
                    ]) + ']'
                    info = ('%s: step %d, test_accuracy = %s') % (
                        datetime.now(), i, str(acc_str))
                    print(info)
                    # Only the accuracies at step 0 and at the last step go
                    # into the run-level log.
                    if i == 0 or is_last_step:
                        log_info += info + '\n'
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic

            # train_only_cnt stays 0 when every executed step was a metadata
            # or summary step (e.g. a very short or resumed run); report NaN
            # instead of crashing with ZeroDivisionError.
            if train_only_cnt > 0:
                train_speed = train_time / train_only_cnt
            else:
                train_speed = float('nan')
            # Extrapolate the pure training time to the full number of steps.
            train_time = (FLAGS.max_number_of_steps) * train_speed
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f \n" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            log_info += info
            write_log_info(log_info)
            write_detailed_info(info)
            # Flush and close the event file so no queued summaries are lost.
            train_writer.close()
Esempio n. 6
0
def main(_):
    """Build the pruned network for one configuration and log its variables.

    The configuration is chosen as ``FLAGS.start_configuration_index`` plus
    the MPI rank of this process.  The pruned network is built on the test
    inputs inside a fresh graph, and the name and static shape of every
    model variable is appended to ``eval_details.txt`` in
    ``<FLAGS.train_dir>/id<configuration_index>``.  No session is created
    here, so nothing is actually executed on the graph.

    Args:
        _: Unused (presumably the argv list handed over by the app runner).

    Raises:
        ValueError: If ``FLAGS.configuration_type`` is neither ``'sample'``
            nor ``'special'``.
    """
    tic = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)

    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    indexed_prune_scopes_for_units = valid_indexed_prune_scopes_for_units
    kept_percentages = sorted(map(float, FLAGS.kept_percentages.split(',')))

    num_options = len(kept_percentages)
    num_units = len(indexed_prune_scopes_for_units)
    print('num_options=%d, num_blocks=%d' % (num_options, num_units))
    print('HG: total number of configurations=%d' % (num_options**num_units))

    # find the configurations to evaluate
    if FLAGS.configuration_type == 'sample':
        configs = get_sampled_configurations(num_units, num_options,
                                             FLAGS.total_num_configurations)
    elif FLAGS.configuration_type == 'special':
        configs = get_special_configurations(num_units, num_options)
    else:
        # Fail early with a clear message instead of hitting an unbound
        # `configs` a few lines further down.
        raise ValueError('Unknown configuration_type: %s' %
                         FLAGS.configuration_type)
    num_configurations = len(configs)

    # Getting MPI rank integer; each rank handles one configuration.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank >= num_configurations:
        print("ERROR: rank(%d) >= num_configurations(%d)" %
              (rank, num_configurations))
        return
    FLAGS.configuration_index = FLAGS.start_configuration_index + rank
    # The bounds check must precede the lookup, otherwise an out-of-range
    # index raises IndexError before the check can report it.
    if FLAGS.configuration_index >= num_configurations:
        print('configuration_index >= num_configurations',
              FLAGS.configuration_index, num_configurations)
        return
    config = configs[FLAGS.configuration_index]
    print(
        'HG: kept_percentages=%s, start_config_index=%d, num_configs=%d, rank=%d, config_index=%d'
        % (str(kept_percentages), FLAGS.start_configuration_index,
           num_configurations, rank, FLAGS.configuration_index))

    # prepare for evaluate the number of parameters with the specific config
    combination = config
    indexed_prune_scopes, kept_percentage = config_to_indexed_prune_scopes(
        combination, indexed_prune_scopes_for_units, kept_percentages)
    prune_scopes = indexed_prune_scopes_to_prune_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    shorten_scopes = indexed_prune_scopes_to_shorten_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    # NOTE(review): reinit_scopes is computed but not read again in this
    # function; kept so the helper calls above still run as before.
    reinit_scopes = [
        re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
        for v in prune_scopes + shorten_scopes
    ]

    # prepare file system
    eval_dir = os.path.join(FLAGS.train_dir,
                            "id" + str(FLAGS.configuration_index))
    prepare_file_system(eval_dir)

    def write_detailed_info(info):
        """Append `info` plus a newline to eval_details.txt in eval_dir."""
        with open(os.path.join(eval_dir, 'eval_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'eval_dir:' + eval_dir + '\n'
    info += 'options:' + FLAGS.kept_percentages + '\n'
    info += 'combination: ' + str(combination) + '\n'
    info += 'indexed_prune_scopes: ' + str(indexed_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)

        ######################
        # Select the network#
        ######################
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(test_dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        ####################
        # Define the model #
        ####################
        prune_info = indexed_prune_scopes_to_prune_info(
            indexed_prune_scopes, kept_percentage)
        print('HG: prune_info:')
        pprint(prune_info)

        logits, _ = network_fn_pruned(test_images,
                                      prune_info=prune_info,
                                      is_training=False,
                                      is_local_train=False,
                                      reuse_variables=False,
                                      scope=net_name_scope_pruned)

        # The result is not read below; the call is kept because it adds
        # ops to the graph exactly as the original did.
        correct_prediction = add_correct_prediction(logits, test_labels)
        model_variables = get_model_variables_within_scopes()
        # One "<variable name>\t<static shape>" line per model variable.
        name_size_strs = [
            x.op.name + '\t' + str(x.get_shape().as_list())
            for x in model_variables
        ]
        write_detailed_info('\n'.join(name_size_strs))

    total_time = time.time() - tic
    info = 'Evaluate network total_time(s)=%.3f \n' % (total_time)
    print(info)
    write_detailed_info(info)
def main(_):
    """Train the pruned network for one sampled kept-percentage configuration.

    The configuration with index ``FLAGS.start_config_id`` (the rank offset
    is fixed to 0 here) is taken from ``get_sampled_configurations`` and its
    option indices are mapped to kept percentages.  When
    ``FLAGS.last_conv_pruned`` is false, every ``FLAGS.block_size``-th entry
    is reset to 1.0.  A training graph and an evaluation graph sharing the
    same variables are built; variables are initialised either from the
    latest checkpoint in the train directory (when ``FLAGS.continue_training``
    is set and such a checkpoint exists) or from ``FLAGS.checkpoint_path``
    together with the values returned by ``get_init_values_for_pruned_layers``.
    The loop then runs up to ``FLAGS.max_number_of_steps`` with periodic
    summaries, evaluation and checkpointing; progress lines are appended to
    ``train_details.txt`` in the train directory.

    Args:
        _: Unused (presumably the argv list handed over by the app runner).

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty, or if a variable that
            must be re-initialised has no entry in the computed init values.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # initialize constants
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    kp_options = sorted([float(x) for x in FLAGS.kept_percentages.split(',')])
    num_options = len(kp_options)
    pruned_layer_names = valid_layer_names[:-1]
    num_units = len(pruned_layer_names)
    print_list('kp_options', kp_options)
    print('HG: num_options=%d, num_units=%d' % (num_options, num_units))
    print('HG: total number of configurations=%d' % (num_options**num_units))

    # find the configurations to evaluate
    configs = get_sampled_configurations(num_units, num_options,
                                         FLAGS.total_num_configs)
    num_configs = len(configs)

    # The MPI rank lookup is disabled in this script; rank is used only as
    # an offset on top of FLAGS.start_config_id.
    # comm = MPI.COMM_WORLD
    # rank = comm.Get_rank() # use rank as offset
    rank = 0
    config_id = FLAGS.start_config_id + rank
    print('HG: start_config_index=%d, rank=%d,  config_index=%d' %
          (FLAGS.start_config_id, rank, config_id))
    if config_id >= num_configs:
        print("ERROR: config_id(%d) >= num_configs(%d)" %
              (config_id, num_configs))
        return

    # get the specific configuration: option indices -> kept percentages
    config = configs[config_id]
    config = [kp_options[i] for i in config]
    if not FLAGS.last_conv_pruned:
        # if the last conv in a block is not pruned. reset the config
        for i in range(len(config)):
            if (i + 1) % FLAGS.block_size == 0:
                config[i] = 1.0
    print('HG: selected config=', config)

    # prepare for training with the specific config
    prune_info = {
        pruned_layer_names[i]: {'kp': kp}
        for i, kp in enumerate(config)
    }

    # prepare file system
    if FLAGS.last_conv_pruned:
        foldername = 'last_conv_pruned'
    else:
        foldername = 'last_conv_unpruned'
    results_dir = os.path.join(FLAGS.train_dir, foldername,
                               'id' + str(config_id))
    train_dir = os.path.join(results_dir, 'train')

    # Decide once whether training resumes from a checkpoint in train_dir,
    # so that the init values below are computed exactly when they are used.
    resume_from_train_dir = bool(FLAGS.continue_training
                                 and tf.train.latest_checkpoint(train_dir))

    if not resume_from_train_dir:
        prune_scopes = [
            layer_name_to_prune_scope(layer_name, net_name_scope_checkpoint)
            for layer_name in pruned_layer_names
        ]
        shorten_scopes = [
            layer_name_to_shorten_scope(layer_name, net_name_scope_checkpoint)
            for layer_name in pruned_layer_names
        ]
        variables_init_value, _ = get_init_values_for_pruned_layers(
            prune_scopes, shorten_scopes, config)
        reinit_scopes = [
            re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
            for v in prune_scopes + shorten_scopes
        ]

        prepare_file_system(train_dir)

    def write_detailed_info(info):
        """Append `info` plus a newline to train_details.txt in train_dir."""
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'kp_options:' + str(kp_options) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    # print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():

        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network#
        ######################

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=prune_info,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        print('HG: prune_info:')
        pprint(prune_info)

        ####################
        # Define the model #
        ####################
        # The eval tower reuses the variables created by the train tower.
        logits_train, _ = network_fn_pruned(images,
                                            is_training=True,
                                            is_local_train=False,
                                            reuse_variables=False,
                                            scope=net_name_scope_pruned)
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=True,
                                           scope=net_name_scope_pruned)

        cross_entropy = add_cross_entropy(logits_train, labels)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        #############################
        # Specify the loss function #
        #############################
        tf.add_to_collection('subgraph_losses', cross_entropy)
        # get regularization loss
        regularization_losses = get_regularization_losses_within_scopes()
        print_list('regularization_losses', regularization_losses)

        # total loss and its summary
        total_loss = tf.add_n(tf.get_collection('subgraph_losses'),
                              name='total_loss')
        for l in tf.get_collection('subgraph_losses') + [total_loss]:
            tf.summary.scalar(l.op.name + '/summary', l)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        #############################
        # Add train operation       #
        #############################
        variables_to_train = get_trainable_variables_within_scopes()
        train_op = add_train_op(optimizer,
                                total_loss,
                                global_step,
                                var_list=variables_to_train)
        print_list("variables_to_train", variables_to_train)

        # Gather update_ops: the updates for the batch_norm variables created by network_fn_pruned.
        update_ops = get_update_ops_within_scopes()
        print_list("update_ops", update_ops)

        # add train_tensor: fetching it runs the update ops and the train op
        # and yields the total loss.
        update_ops.append(train_op)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # Prepare for filewriter. #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            # if restart the training and there is a checkpoint in the train_dir
            if resume_from_train_dir:
                ###########################################
                ## Restore all variables from checkpoint ##
                ###########################################
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)
            else:
                #########################################
                # Reinit  pruned model variable  #
                #########################################
                variables_to_reinit = get_model_variables_within_scopes(
                    reinit_scopes)
                print_list("Initialize pruned variables", variables_to_reinit)
                assign_ops = []
                for v in variables_to_reinit:
                    # Init values are keyed by the checkpoint-scope name.
                    key = re.sub(net_name_scope_pruned,
                                 net_name_scope_checkpoint, v.op.name)
                    if key not in variables_init_value:
                        raise ValueError(
                            "Key not in variables_init_value, key=%s" % key)
                    value = variables_init_value.get(key)
                    assign_ops.append(
                        tf.assign(v,
                                  tf.convert_to_tensor(value),
                                  validate_shape=True))
                assign_op = tf.group(*assign_ops)
                sess.run(assign_op)

                #################################################
                # Restore unchanged model variable. #
                #################################################
                variables_to_restore = {
                    re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                           v.op.name): v
                    for v in get_model_variables_within_scopes()
                    if v not in variables_to_reinit
                }
                print_list("restore model variables",
                           variables_to_restore.values())
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

            #################################################
            # init unitialized global variable. #
            #################################################
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init unitialized variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAG.max_number_of_steps (%d)' \
                    %(init_global_step_value, FLAGS.max_number_of_steps))
                # Flush and release the event file before leaving early.
                train_writer.close()
                return

            ###########################
            # Record CPU usage  #
            ###########################
            # mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            # os.system("mpstat -P ALL 1 > " + mpstat_output_filename + " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
            print('HG: # of threads=', len(threads))

            # duration/duration_cnt are reset at every loss log line;
            # train_time/train_only_cnt accumulate over the whole run.
            duration = 0
            duration_cnt = 0
            train_time = 0
            train_only_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                # run optional meta data, or summary, while run train tensor
                if i > init_global_step_value:
                    # train while run metadata
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_value = sess.run(train_tensor,
                                              options=run_options,
                                              run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)

                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format(
                        )
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)

                    # train while record summary
                    elif i % FLAGS.summary_every_n_steps == 0:
                        train_summary, loss_value = sess.run(
                            [summary_op, train_tensor])
                        train_writer.add_summary(train_summary, i)

                    # train only: the only branch that is timed
                    else:
                        start_time = time.time()
                        loss_value = sess.run(train_tensor)
                        train_only_cnt += 1
                        train_time += time.time() - start_time
                        duration_cnt += 1
                        duration += time.time() - start_time

                    # log loss information
                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        format_str = (
                            '%s: step %d, loss = %s (%.1f examples/sec; %.3f sec/batch)'
                        )
                        info = format_str % (datetime.now(), i,
                                             str(loss_value), examples_per_sec,
                                             sec_per_batch)
                        print(info)
                        write_detailed_info(info)
                        duration = 0
                        duration_cnt = 0
                else:
                    # First iteration: only evaluate the loss and summaries,
                    # without running a training step.
                    train_summary, loss_value = sess.run(
                        [summary_op, total_loss])
                    train_writer.add_summary(train_summary, i)
                    info = '%s: step %d, loss = %s' % (datetime.now(), i,
                                                       str(loss_value))
                    print(info)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                    test_accuracy, _ = evaluate_accuracy(
                        sess,
                        coord,
                        test_dataset.num_samples,
                        test_images,
                        test_labels,
                        test_images,
                        test_labels,
                        correct_prediction,
                        FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, i)

                    info = ('%s: step %d, test_accuracy = %.6f') % (
                        datetime.now(), i, test_accuracy)
                    print(info)
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            # Make sure buffered summaries reach the event file.
            train_writer.close()
            total_time = time.time() - tic
            # No "train only" step may have run (e.g. a short resumed run
            # where every step was a summary/metadata step); avoid dividing
            # by zero in that case.
            if train_only_cnt > 0:
                train_speed = train_time * 1.0 / train_only_cnt
            else:
                train_speed = 0.0
            # Extrapolated training time for max_number_of_steps steps at
            # the measured per-batch speed.
            train_time = train_speed * FLAGS.max_number_of_steps
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            write_detailed_info(info)
def main(_):
    """Evaluate the test accuracy after pruning a single block.

    The block ``FLAGS.block_id`` is pruned to ``FLAGS.kept_percentage``.  The
    pruned variables are initialised with the values returned by
    ``get_init_values_for_pruned_layers``, all other model variables are
    restored from ``FLAGS.checkpoint_path``, and the accuracy on the test
    dataset is measured once.  Results are appended to ``eval_details.txt``
    in ``<FLAGS.train_dir>/id<block_id>/kp<kept_percentage>``.

    Args:
        _: Unused (presumably the argv list handed over by the app runner).

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty, or if a variable that
            must be re-initialised has no entry in the computed init values.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')
    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    indexed_prune_scopes_for_units = valid_indexed_prune_scopes_for_units
    kept_percentage = FLAGS.kept_percentage

    # set the configuration: should be a 16-length vector
    # (only used for logging; 1.0 everywhere except the selected block)
    config = [1.0] * len(indexed_prune_scopes_for_units)
    config[FLAGS.block_id] = kept_percentage
    print("config:", config)

    # prepare for training with the specific config
    indexed_prune_scopes = indexed_prune_scopes_for_units[FLAGS.block_id]
    prune_info = indexed_prune_scopes_to_prune_info(indexed_prune_scopes,
                                                    kept_percentage)
    print("prune_info:", prune_info)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir, 'id' + str(FLAGS.block_id))
    train_dir = os.path.join(results_dir, 'kp' + str(kept_percentage))

    prune_scopes = indexed_prune_scopes_to_prune_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    shorten_scopes = indexed_prune_scopes_to_shorten_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    variables_init_value = get_init_values_for_pruned_layers(
        prune_scopes, shorten_scopes, kept_percentage)
    # Same scopes, expressed under the pruned network's name scope.
    reinit_scopes = [
        re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
        for v in prune_scopes + shorten_scopes
    ]

    prepare_file_system(train_dir)

    def write_detailed_info(info):
        """Append `info` plus a newline to eval_details.txt in train_dir."""
        with open(os.path.join(train_dir, 'eval_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'block_id:' + str(FLAGS.block_id) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    info += 'indexed_prune_scopes: ' + str(indexed_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():

        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        # Only the test split is needed: this script does not train.
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)

        ######################
        # Select the network#
        ######################

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=prune_info,
            num_classes=(test_dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        print('HG: prune_info:')
        pprint(prune_info)

        ####################
        # Define the model #
        ####################
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=False,
                                           scope=net_name_scope_pruned)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            #########################################
            # Reinit  pruned model variable  #
            #########################################
            variables_to_reinit = get_model_variables_within_scopes(
                reinit_scopes)
            print_list("Initialize pruned variables", variables_to_reinit)
            assign_ops = []
            for v in variables_to_reinit:
                # Init values are keyed by the checkpoint-scope name.
                key = re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                             v.op.name)
                if key not in variables_init_value:
                    raise ValueError(
                        "Key not in variables_init_value, key=%s" % key)
                value = variables_init_value.get(key)
                assign_ops.append(
                    tf.assign(v,
                              tf.convert_to_tensor(value),
                              validate_shape=True))
            assign_op = tf.group(*assign_ops)
            sess.run(assign_op)

            #################################################
            # Restore unchanged model variable. #
            #################################################
            variables_to_restore = {
                re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                       v.op.name): v
                for v in get_model_variables_within_scopes()
                if v not in variables_to_reinit
            }
            print_list("restore model variables",
                       variables_to_restore.values())
            load_checkpoint(sess,
                            FLAGS.checkpoint_path,
                            var_list=variables_to_restore)

            #############################
            # Kicks off the evaluation. #
            #############################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('HG: # of threads=', len(threads))

            eval_start = time.time()
            test_accuracy, _ = evaluate_accuracy(
                sess,
                coord,
                test_dataset.num_samples,
                test_images,
                test_labels,
                test_images,
                test_labels,
                correct_prediction,
                FLAGS.test_batch_size,
                run_meta=False)
            eval_time = time.time() - eval_start

            info = ('%s: test_accuracy = %.6f') % (datetime.now(),
                                                   test_accuracy)
            print(info)
            write_detailed_info(info)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic

            # The measured span covers only evaluate_accuracy, so report it
            # as evaluation time (no training happens in this script).
            info = "HG: evaluation time(min): %.1f, total time(min): %.1f" % (
                eval_time / 60.0, total_time / 60.0)
            print(info)
            write_detailed_info(info)