def main(_):
    """Build the training graph, train the network and validate once per epoch.

    Configuration is read from the module-level ``FLAGS`` (``Distillation``,
    ``main_scope``, ``teacher``, ``train_dir``).  For every distillation
    method other than ``'DML'`` the variables in the ``'Teacher'`` collection
    are filled from ``<home_path>/pre_trained/<FLAGS.teacher>.mat`` before
    training starts.  When training finishes, the trainable variables plus
    the ``'BN_collection'`` variables are written to
    ``<FLAGS.train_dir>/train_params.mat``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by the app runner -- TODO confirm against the caller).
    """
    ### define path and hyper-parameter
    model_name = 'ResNet'
    Learning_rate = 1e-1

    batch_size = 128
    val_batch_size = 200
    train_epoch = 100
    # FitNet, FSP and AB get 40 additional epochs ahead of the train_epoch
    # epochs; every other setting starts the main training immediately.
    init_epoch = 40 if FLAGS.Distillation in ('FitNet', 'FSP', 'AB') else 0

    total_epoch = init_epoch + train_epoch
    weight_decay = 5e-4

    should_log = 200
    save_summaries_secs = 20
    tf.logging.set_verbosity(tf.logging.INFO)
    gpu_num = '0'

    # The flag arrives as a string; map the literal 'None' to a real None.
    if FLAGS.Distillation == 'None':
        FLAGS.Distillation = None

    (train_images, train_labels), (val_images, val_labels) = load_data()

    dataset_len, *image_size = train_images.shape
    num_label = int(np.max(train_labels) + 1)
    with tf.Graph().as_default() as graph:
        # make placeholder for inputs
        image_ph = tf.placeholder(tf.uint8, [None] + image_size)
        label_ph = tf.placeholder(tf.int32, [None])
        # Training mode is fed as an int (1 = training, 0 = validation) and
        # converted to a boolean tensor for the model.
        is_training_ph = tf.placeholder(tf.int32, [])
        is_training = tf.equal(is_training_ph, 1)

        # pre-processing
        image = pre_processing(image_ph, is_training)
        label = tf.contrib.layers.one_hot_encoding(label_ph,
                                                   num_label,
                                                   on_value=1.0)

        # make global step
        global_step = tf.train.create_global_step()
        epoch = tf.floor_div(
            tf.cast(global_step, tf.float32) * batch_size, dataset_len)
        max_number_of_steps = int(dataset_len * total_epoch) // batch_size + 1

        # make learning rate scheduler
        LR = learning_rate_scheduler(Learning_rate,
                                     [epoch, init_epoch, train_epoch],
                                     [0.3, 0.6, 0.8], 0.1)

        ## load Net
        class_loss, accuracy = MODEL(model_name,
                                     FLAGS.main_scope,
                                     weight_decay,
                                     image,
                                     label,
                                     is_training,
                                     reuse=False,
                                     drop=True,
                                     Distillation=FLAGS.Distillation)

        # make training operator
        if FLAGS.Distillation != 'DML':
            train_op = op_util.Optimizer_w_Distillation(
                class_loss, LR, epoch, init_epoch, global_step,
                FLAGS.Distillation)
        else:
            # DML yields two train ops; the teacher op is run first each step.
            teacher_train_op, train_op = op_util.Optimizer_w_DML(
                class_loss, LR, epoch, init_epoch, global_step)

        ## collect summary ops for plotting in tensorboard
        summary_op = tf.summary.merge(tf.get_collection(
            tf.GraphKeys.SUMMARIES),
                                      name='summary_op')

        ## make placeholder and summary op for training and validation results
        train_acc_place = tf.placeholder(dtype=tf.float32)
        val_acc_place = tf.placeholder(dtype=tf.float32)
        val_summary = [
            tf.summary.scalar('accuracy/training_accuracy', train_acc_place),
            tf.summary.scalar('accuracy/validation_accuracy', val_acc_place)
        ]
        val_summary_op = tf.summary.merge(list(val_summary),
                                          name='val_summary_op')

        ## start training
        train_writer = tf.summary.FileWriter('%s' % FLAGS.train_dir,
                                             graph,
                                             flush_secs=save_summaries_secs)
        config = ConfigProto()
        config.gpu_options.visible_device_list = gpu_num
        config.gpu_options.allow_growth = True

        # Validation samples beyond a multiple of val_batch_size are dropped.
        val_itr = len(val_labels) // val_batch_size
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            if FLAGS.Distillation is not None and FLAGS.Distillation != 'DML':
                ## load and assign the teacher's variables from the .mat file;
                ## slower than restoring a checkpoint but easier to modify
                teacher_variables = tf.get_collection('Teacher')
                teacher = sio.loadmat(home_path +
                                      '/pre_trained/%s.mat' % FLAGS.teacher)
                n = 0
                for v in teacher_variables:
                    # v.name ends with ':0'; the .mat keys omit that suffix.
                    key = v.name[:-2]
                    if teacher.get(key) is not None:
                        sess.run(
                            v.assign(teacher[key].reshape(
                                v.get_shape().as_list())))
                        n += 1
                print('%d Teacher params assigned' % n)

            sum_train_accuracy = []
            time_elapsed = []
            total_loss = []
            # idx is a queue of shuffled sample indices; each step consumes
            # the first batch_size entries.
            idx = list(range(train_labels.shape[0]))
            shuffle(idx)
            epoch_ = 0
            # step is 1-based: it counts the batches processed so far.
            for step in range(1, max_number_of_steps + 1):
                start_time = time.time()

                ## feed data
                batch_idx = idx[:batch_size]
                train_feed = {
                    image_ph: train_images[batch_idx],
                    label_ph: np.squeeze(train_labels[batch_idx]),
                    is_training_ph: 1
                }
                if FLAGS.Distillation == 'DML':
                    sess.run([teacher_train_op], feed_dict=train_feed)

                tl, log, train_acc = sess.run(
                    [train_op, summary_op, accuracy], feed_dict=train_feed)

                time_elapsed.append(time.time() - start_time)
                total_loss.append(tl)
                sum_train_accuracy.append(train_acc)
                del idx[:batch_size]
                if len(idx) < batch_size:
                    # Refill the queue with a fresh shuffled pass.
                    idx_ = list(range(train_labels.shape[0]))
                    shuffle(idx_)
                    idx += idx_

                current_epoch = (step * batch_size) // dataset_len
                if current_epoch >= init_epoch + epoch_:
                    ## do validation
                    sum_val_accuracy = []
                    for i in range(val_itr):
                        val_slice = slice(i * val_batch_size,
                                          (i + 1) * val_batch_size)
                        acc = sess.run(
                            accuracy,
                            feed_dict={
                                image_ph: val_images[val_slice],
                                label_ph: np.squeeze(val_labels[val_slice]),
                                is_training_ph: 0
                            })
                        sum_val_accuracy.append(acc)

                    # Until the init epochs are over, 1. is reported as a
                    # placeholder training accuracy.
                    sum_train_accuracy = np.mean(
                        sum_train_accuracy
                    ) * 100 if current_epoch > init_epoch else 1.
                    sum_val_accuracy = np.mean(sum_val_accuracy) * 100
                    tf.logging.info(
                        'Epoch %s Step %s - train_Accuracy : %.2f%%  val_Accuracy : %.2f%%'
                        % (str(epoch_).rjust(3, '0'), str(step).rjust(
                            6, '0'), sum_train_accuracy, sum_val_accuracy))

                    result_log = sess.run(val_summary_op,
                                          feed_dict={
                                              train_acc_place:
                                              sum_train_accuracy,
                                              val_acc_place: sum_val_accuracy
                                          })
                    if current_epoch == init_epoch and init_epoch > 0:
                        # re-initialize Momentum for fair comparison w/
                        # initialization and multi-task learning methods.
                        # The slots are looked up among ALL global variables
                        # and matched by name suffix; the zero array is built
                        # from the shape list (np.zeros takes the shape as a
                        # single argument).
                        for v in tf.get_collection(
                                tf.GraphKeys.GLOBAL_VARIABLES):
                            if v.name.endswith('Momentum:0'):
                                sess.run(
                                    v.assign(
                                        np.zeros(
                                            v.get_shape().as_list(),
                                            dtype=v.dtype.base_dtype.
                                            as_numpy_dtype)))

                    if step == max_number_of_steps:
                        train_writer.add_summary(result_log, train_epoch)
                    else:
                        train_writer.add_summary(result_log, epoch_)
                    sum_train_accuracy = []

                    epoch_ += 1

                if step % should_log == 0:
                    tf.logging.info(
                        'global step %s: loss = %.4f (%.3f sec/step)',
                        str(step).rjust(6, '0'), np.mean(total_loss),
                        np.mean(time_elapsed))
                    train_writer.add_summary(log, step)
                    time_elapsed = []
                    total_loss = []

                elif (step * batch_size) % dataset_len == 0:
                    train_writer.add_summary(log, step)

            ## save variables to use for something
            var = {}
            variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES) + tf.get_collection(
                    'BN_collection')
            for v in variables:
                var[v.name[:-2]] = sess.run(v)
            sio.savemat(FLAGS.train_dir + '/train_params.mat', var)

            ## close all
            tf.logging.info('Finished training! Saving model to disk.')
            train_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.STOP))
            train_writer.close()
# Example no. 2
# 0
def main(_):
    """Train the network with a separate weight-sharing validation graph.

    Configuration is read from the module-level ``FLAGS`` (``Distillation``,
    ``main_scope``, ``teacher``, ``train_dir``).  The graph hard-codes
    32x32x3 inputs, 100 classes and 50000 training samples per epoch.  A
    second model instance built with ``reuse=True`` is used for validation.
    Whenever the validation accuracy improves, the current parameters are
    written to ``<FLAGS.train_dir>/best_params.mat``; at the end of training
    they are written to ``<FLAGS.train_dir>/train_params.mat``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by the app runner -- TODO confirm against the caller).
    """
    ### define path and hyper-parameter
    model_name = 'ResNet'
    Learning_rate = 1e-1  # initialization methods : 1e-2, others : 1e-1

    batch_size = 128
    val_batch_size = 200
    train_epoch = 100
    # FitNet, FSP and AB get 40 additional epochs ahead of the train_epoch
    # epochs; every other setting starts the main training immediately.
    init_epoch = 40 if FLAGS.Distillation in ('FitNet', 'FSP', 'AB') else 0

    total_epoch = init_epoch + train_epoch
    weight_decay = 5e-4

    should_log = 200
    save_summaries_secs = 20
    tf.logging.set_verbosity(tf.logging.INFO)
    gpu_num = '0'

    # The flag arrives as a string; map the literal 'None' to a real None.
    if FLAGS.Distillation == 'None':
        FLAGS.Distillation = None

    with tf.Graph().as_default() as graph:
        # make placeholder for inputs
        sz = [32, 32, 3]
        train_image = tf.placeholder(tf.uint8, [batch_size] + sz)
        train_label = tf.placeholder(tf.int32, [batch_size])

        # pre-processing
        image = pre_processing(train_image, is_training=True)
        label = tf.contrib.layers.one_hot_encoding(train_label,
                                                   100,
                                                   on_value=1.0)

        # make global step
        global_step = tf.train.create_global_step()
        decay_steps = 50000 // batch_size  # training steps per epoch
        epoch = tf.floor_div(tf.cast(global_step, tf.float32), decay_steps)
        max_number_of_steps = int(50000 * total_epoch) // batch_size

        # make learning rate scheduler
        LR = learning_rate_scheduler(Learning_rate,
                                     [epoch, init_epoch, train_epoch],
                                     [0.3, 0.6, 0.8], 0.1)

        ## load Net
        class_loss, train_accuracy = MODEL(model_name,
                                           FLAGS.main_scope,
                                           weight_decay,
                                           image,
                                           label,
                                           is_training=True,
                                           reuse=False,
                                           drop=True,
                                           Distillation=FLAGS.Distillation)

        # make training operator
        train_op = op_util.Optimizer_w_Distillation(class_loss, LR, epoch,
                                                    init_epoch, global_step,
                                                    FLAGS.Distillation)

        ## collect summary ops for plotting in tensorboard
        summary_op = tf.summary.merge(tf.get_collection(
            tf.GraphKeys.SUMMARIES),
                                      name='summary_op')

        ## make clone model to validate (shares weights through reuse=True)
        val_image = tf.placeholder(tf.float32, [val_batch_size] +
                                   image.get_shape().as_list()[1:])
        val_label = tf.placeholder(tf.int32, [val_batch_size])
        val_label_onhot = tf.contrib.layers.one_hot_encoding(val_label,
                                                             100,
                                                             on_value=1.0)
        val_image_ = pre_processing(val_image, is_training=False)
        val_loss, val_accuracy = MODEL(model_name,
                                       FLAGS.main_scope,
                                       0.,
                                       val_image_,
                                       val_label_onhot,
                                       is_training=False,
                                       reuse=True,
                                       drop=False,
                                       Distillation=FLAGS.Distillation)

        ## make placeholder and summary op for training and validation results
        train_acc_place = tf.placeholder(dtype=tf.float32)
        val_acc_place = tf.placeholder(dtype=tf.float32)
        val_summary = [
            tf.summary.scalar('accuracy/training_accuracy', train_acc_place),
            tf.summary.scalar('accuracy/validation_accuracy', val_acc_place)
        ]
        val_summary_op = tf.summary.merge(list(val_summary),
                                          name='val_summary_op')

        ## start training
        highest = 0  # best summed validation accuracy seen so far
        train_writer = tf.summary.FileWriter('%s' % FLAGS.train_dir,
                                             graph,
                                             flush_secs=save_summaries_secs)
        val_saver = tf.train.Saver()
        config = ConfigProto()
        config.gpu_options.visible_device_list = gpu_num
        config.gpu_options.allow_growth = True

        (train_images, train_labels), (val_images, val_labels) = load_data()
        # Validation samples beyond a multiple of val_batch_size are dropped.
        val_itr = len(val_labels) // val_batch_size
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            if FLAGS.Distillation is not None:
                ## load and assign the teacher's variables from the .mat file;
                ## slower than restoring a checkpoint but easier to modify
                global_variables = tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES)
                p = sio.loadmat(home_path +
                                '/pre_trained/%s.mat' % FLAGS.teacher)
                n = 0
                for v in global_variables:
                    # v.name ends with ':0'; the .mat keys omit that suffix.
                    key = v.name[:-2]
                    if p.get(key) is not None:
                        sess.run(
                            v.assign(p[key].reshape(
                                v.get_shape().as_list())))
                        n += 1
                print('%d Teacher params assigned' % n)

            sum_train_accuracy = 0
            time_elapsed = 0
            total_loss = 0
            # idx is a queue of shuffled sample indices; each step consumes
            # the first batch_size entries.
            idx = list(range(train_labels.shape[0]))
            shuffle(idx)
            for step in range(max_number_of_steps):
                start_time = time.time()
                batch_idx = idx[:batch_size]
                train_batch = train_images[batch_idx]

                ## tf.__version < 1.10
                #                random_seed = [1]*(batch_size//2) + [-1]*(batch_size//2)
                #                shuffle(random_seed)
                #                train_batch = [ti if seed > 0 else np.fliplr(ti)
                #                               for seed, ti in zip(random_seed, train_batch)]

                ## feed data
                tl, log, train_acc = sess.run(
                    [train_op, summary_op, train_accuracy],
                    feed_dict={
                        train_image: train_batch,
                        train_label: np.squeeze(train_labels[batch_idx])
                    })
                time_elapsed += time.time() - start_time
                del idx[:batch_size]
                if len(idx) < batch_size:
                    # Refill the queue with a fresh shuffled pass.
                    idx_ = list(range(train_labels.shape[0]))
                    shuffle(idx_)
                    idx += idx_

                total_loss += tl
                sum_train_accuracy += train_acc
                current_epoch = step // decay_steps
                if step % decay_steps == 0 and current_epoch >= init_epoch:
                    ## do validation
                    sum_val_accuracy = 0

                    for i in range(val_itr):
                        val_slice = slice(i * val_batch_size,
                                          (i + 1) * val_batch_size)
                        acc = sess.run(
                            val_accuracy,
                            feed_dict={
                                val_image: val_images[val_slice],
                                val_label: np.squeeze(val_labels[val_slice])
                            })
                        sum_val_accuracy += acc

                    # Until the init epochs are over, 1. is reported as a
                    # placeholder training accuracy.
                    if current_epoch > init_epoch:
                        train_acc_pct = sum_train_accuracy * 100 / decay_steps
                    else:
                        train_acc_pct = 1.
                    val_acc_pct = sum_val_accuracy * 100 / val_itr

                    tf.logging.info(
                        'Epoch %s Step %s - train_Accuracy : %.2f%%  val_Accuracy : %.2f%%'
                        % (str(current_epoch).rjust(3, '0'),
                           str(step).rjust(6, '0'), train_acc_pct,
                           val_acc_pct))

                    result_log = sess.run(val_summary_op,
                                          feed_dict={
                                              train_acc_place: train_acc_pct,
                                              val_acc_place: val_acc_pct,
                                          })
                    if current_epoch == init_epoch and init_epoch > 0:
                        # re-initialize Momentum for fair comparison w/
                        # initialization and multi-task learning methods.
                        # Slots are matched by name suffix, and the zero
                        # array is built from the shape list (np.zeros takes
                        # the shape as a single argument).
                        for v in tf.get_collection(
                                tf.GraphKeys.GLOBAL_VARIABLES):
                            if v.name.endswith('Momentum:0'):
                                sess.run(
                                    v.assign(
                                        np.zeros(
                                            v.get_shape().as_list(),
                                            dtype=v.dtype.base_dtype.
                                            as_numpy_dtype)))

                    if step == max_number_of_steps - 1:
                        train_writer.add_summary(result_log, train_epoch)
                    else:
                        train_writer.add_summary(result_log,
                                                 current_epoch - init_epoch)
                    sum_train_accuracy = 0
                    if sum_val_accuracy > highest:
                        highest = sum_val_accuracy
                        var = {}
                        variables = tf.get_collection(
                            tf.GraphKeys.TRAINABLE_VARIABLES
                        ) + tf.get_collection('BN_collection')
                        for v in variables:
                            var[v.name[:-2]] = sess.run(v)
                        sio.savemat(FLAGS.train_dir + '/best_params.mat', var)

                    # NOTE(review): this checkpoint is written at every
                    # validation, not only on improvement, despite the
                    # "best_model" file name -- confirm whether it belongs
                    # inside the `if` above.
                    val_saver.save(sess,
                                   "%s/best_model.ckpt" % FLAGS.train_dir)

                if step % should_log == 0 and step > 0:
                    tf.logging.info(
                        'global step %s: loss = %.4f (%.3f sec/step)',
                        str(step).rjust(6, '0'), total_loss / should_log,
                        time_elapsed / should_log)
                    time_elapsed = 0
                    total_loss = 0

                elif step % (decay_steps // 2) == 0:
                    train_writer.add_summary(log, step)

            ## save variables to use for something
            var = {}
            variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES) + tf.get_collection(
                    'BN_collection')
            for v in variables:
                var[v.name[:-2]] = sess.run(v)
            sio.savemat(FLAGS.train_dir + '/train_params.mat', var)

            ## close all
            tf.logging.info('Finished training! Saving model to disk.')
            train_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.STOP))
            train_writer.close()
def main(_):
    """Train a LeNet-5 teacher or half-size student, optionally with distillation.

    Configuration is read from the module-level ``FLAGS`` (``main_scope``,
    ``Distillation``, ``teacher``, ``train_dir``) and from ``Hyper_params``,
    which supplies the learning rate and epoch count for the scope.
    ``FLAGS.Distillation`` may carry a suffix: ``'Soft_logits-<percent>'``
    trains on a per-class subsample of the data, and ``'ZSKD-<name>'``
    replaces the training set by the contents of
    ``<home_path>/DI/DI-<name>.mat``.

    Parameters are written to ``<FLAGS.train_dir>/best_params.mat`` whenever
    validation accuracy improves and to ``<FLAGS.train_dir>/train_params.mat``
    at the end; when the scope is ``'Teacher'`` they are also written to
    ``<home_path>/pre_trained/<model_name>.mat``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by the app runner -- TODO confirm against the caller).
    """
    ### Define fixed hyper-parameters
    model_name = 'Lenet5_half' if FLAGS.main_scope == 'Student' else 'Lenet5'
    Learning_rate, train_epoch = Hyper_params[FLAGS.main_scope]

    batch_size = 512
    val_batch_size = 200

    weight_decay = 1e-4
    should_log = 50
    save_summaries_secs = 20
    gpu_num = '0'

    ### Load dataset
    (train_images, train_labels), (val_images, val_labels) = load_data()

    ### Resize image size to follow the author's configuration.
    train_images = np.expand_dims(
        np.array([cv2.resize(ti, (32, 32)) for ti in train_images]), -1)
    val_images = np.expand_dims(
        np.array([cv2.resize(vi, (32, 32)) for vi in val_images]), -1)
    num_label = int(np.max(val_labels) + 1)

    if FLAGS.Distillation == 'None' or FLAGS.Distillation is None:
        ### Prevent error
        FLAGS.Distillation = None

    else:
        # '<method>-<argument>' flags are split once and dispatched on the
        # method part; other methods are passed through unchanged.
        method_parts = FLAGS.Distillation.split('-')

        if method_parts[0] == 'Soft_logits':
            ### Sample the data at a defined rate
            data_per_label = train_labels.shape[0] // num_label
            sample_rate = int(method_parts[1]) / 100
            n_per_label = int(data_per_label * sample_rate)
            sampled = []
            for i in range(num_label):
                # Draw only from the samples of class i.
                class_idx = np.where(train_labels == i)[0]
                # Classes smaller than the requested count contribute all
                # their samples instead of making the draw fail.
                sampled.append(
                    np.random.choice(class_idx,
                                     min(n_per_label, class_idx.size),
                                     replace=False))
            sampled_idx = np.hstack(sampled)
            train_images = train_images[sampled_idx]
            train_labels = train_labels[sampled_idx]
            FLAGS.Distillation = 'Soft_logits'

        elif method_parts[0] == 'ZSKD':
            ### Load data impression for zero-shot knowledge distillation
            data = sio.loadmat(home_path + '/DI/DI-%s.mat' % method_parts[1])
            train_images = data['train_images']
            train_labels = np.expand_dims(
                np.argmax(data['train_labels'], 1), -1)
            # Disabled augmentation experiment kept for reference (original
            # author's note: "I implement them but not helpful for me :("):
            #
            # if re.split('-',FLAGS.Distillation)[1] == '40':
            #     scale_90 = np.expand_dims(np.array([np.pad(cv2.resize(i,(28,28)),[[2,2],[2,2]],'constant') for i in train_images]),-1)
            #     scale_75 = np.expand_dims(np.array([np.pad(cv2.resize(i,(24,24)),[[4,4],[4,4]],'constant') for i in train_images]),-1)
            #     scale_60 = np.expand_dims(np.array([np.pad(cv2.resize(i,(20,20)),[[6,6],[6,6]],'constant') for i in train_images]),-1)
            #
            #     translate_left  = np.pad(train_images[:,:,6:],[[0,0],[0,0],[0,6],[0,0]],'constant')
            #     translate_right = np.pad(train_images[:,:,:-6],[[0,0],[0,0],[6,0],[0,0]],'constant')
            #     translate_up    = np.pad(train_images[:,6:,:],[[0,0],[0,6],[0,0],[0,0]],'constant')
            #     translate_down  = np.pad(train_images[:,:-6,:],[[0,0],[6,0],[0,0],[0,0]],'constant')
            #     train_images = np.vstack([train_images,
            #                               scale_90, scale_75, scale_60,
            #                               translate_left, translate_right, translate_up,translate_down])
            #     train_labels = np.vstack([train_labels]*8)

            FLAGS.Distillation = 'ZSKD'

    dataset_len, *image_size = train_images.shape
    with tf.Graph().as_default() as graph:
        ### Make placeholder
        image_ph = tf.placeholder(tf.float32, [None] + image_size)
        label_ph = tf.placeholder(tf.int32, [None])
        is_training = tf.placeholder(tf.bool, [])

        ### Pre-processing
        image = pre_processing(image_ph, is_training)
        label = tf.contrib.layers.one_hot_encoding(label_ph,
                                                   num_label,
                                                   on_value=1.0)

        ### Make global step
        global_step = tf.train.create_global_step()
        max_number_of_steps = int(dataset_len * train_epoch) // batch_size + 1

        ### Load Network
        class_loss, accuracy = MODEL(model_name,
                                     FLAGS.main_scope,
                                     weight_decay,
                                     image,
                                     label,
                                     is_training,
                                     Distillation=FLAGS.Distillation)

        ### Make training operator
        train_op = op_util.Optimizer_w_Distillation(class_loss, Learning_rate,
                                                    global_step,
                                                    FLAGS.Distillation)

        ### Collect summary ops for plotting in tensorboard
        summary_op = tf.summary.merge(tf.get_collection(
            tf.GraphKeys.SUMMARIES),
                                      name='summary_op')

        ### Make placeholder and summary op for training and validation results
        train_acc_place = tf.placeholder(dtype=tf.float32)
        val_acc_place = tf.placeholder(dtype=tf.float32)
        val_summary = [
            tf.summary.scalar('accuracy/training_accuracy', train_acc_place),
            tf.summary.scalar('accuracy/validation_accuracy', val_acc_place)
        ]
        val_summary_op = tf.summary.merge(list(val_summary),
                                          name='val_summary_op')

        ### Make a summary writer and configure GPU options
        train_writer = tf.summary.FileWriter('%s' % FLAGS.train_dir,
                                             graph,
                                             flush_secs=save_summaries_secs)
        config = ConfigProto()
        config.gpu_options.visible_device_list = gpu_num
        config.gpu_options.allow_growth = True

        # Validation samples beyond a multiple of val_batch_size are dropped.
        val_itr = len(val_labels) // val_batch_size
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            if FLAGS.Distillation is not None:
                ### Load teacher network's parameters
                teacher_variables = tf.get_collection('Teacher')
                teacher = sio.loadmat(home_path +
                                      '/pre_trained/%s.mat' % FLAGS.teacher)
                n = 0
                for v in teacher_variables:
                    # v.name ends with ':0'; the .mat keys omit that suffix.
                    key = v.name[:-2]
                    if teacher.get(key) is not None:
                        sess.run(
                            v.assign(teacher[key].reshape(
                                v.get_shape().as_list())))
                        n += 1
                print('%d Teacher params assigned' % n)

            sum_train_accuracy = []
            time_elapsed = []
            total_loss = []
            # idx is a queue of shuffled sample indices; each step consumes
            # the first batch_size entries.
            idx = np.random.choice(dataset_len, dataset_len,
                                   replace=False).tolist()
            epoch_ = 0  # next epoch at which validation runs
            best = 0  # best validation accuracy (percent) seen so far
            # step is 1-based: it counts the batches processed so far.
            for step in range(1, max_number_of_steps + 1):
                ### Train network
                start_time = time.time()
                if len(idx) < batch_size:
                    # Refill the queue with a fresh shuffled pass.
                    idx += np.random.choice(dataset_len,
                                            dataset_len,
                                            replace=False).tolist()

                batch_idx = idx[:batch_size]
                tl, log, train_acc = sess.run(
                    [train_op, summary_op, accuracy],
                    feed_dict={
                        image_ph: train_images[batch_idx],
                        label_ph: np.squeeze(train_labels[batch_idx]),
                        is_training: True
                    })
                time_elapsed.append(time.time() - start_time)
                total_loss.append(tl)
                sum_train_accuracy.append(train_acc)
                del idx[:batch_size]

                if (step * batch_size) // dataset_len >= epoch_:
                    ## Do validation
                    sum_val_accuracy = []
                    for i in range(val_itr):
                        val_slice = slice(i * val_batch_size,
                                          (i + 1) * val_batch_size)
                        acc = sess.run(
                            accuracy,
                            feed_dict={
                                image_ph: val_images[val_slice],
                                label_ph: np.squeeze(val_labels[val_slice]),
                                is_training: False
                            })
                        sum_val_accuracy.append(acc)

                    sum_train_accuracy = np.mean(sum_train_accuracy) * 100
                    sum_val_accuracy = np.mean(sum_val_accuracy) * 100
                    print(
                        'Epoch %s Step %s - train_Accuracy : %.2f%%  val_Accuracy : %.2f%%'
                        % (str(epoch_).rjust(3, '0'), str(step).rjust(
                            6, '0'), sum_train_accuracy, sum_val_accuracy))

                    result_log = sess.run(val_summary_op,
                                          feed_dict={
                                              train_acc_place:
                                              sum_train_accuracy,
                                              val_acc_place: sum_val_accuracy
                                          })
                    if step == max_number_of_steps:
                        train_writer.add_summary(result_log, train_epoch)
                    else:
                        train_writer.add_summary(result_log, epoch_)
                    sum_train_accuracy = []

                    if sum_val_accuracy > best:
                        # Remember the new best so that later, worse
                        # validations do not overwrite best_params.mat.
                        best = sum_val_accuracy
                        var = {}
                        variables = tf.get_collection(
                            tf.GraphKeys.TRAINABLE_VARIABLES
                        ) + tf.get_collection('BN_collection')
                        for v in variables:
                            var[v.name[:-2]] = sess.run(v)
                        sio.savemat(FLAGS.train_dir + '/best_params.mat', var)
                    epoch_ += 10  # validate interval

                if step % should_log == 0:
                    ### Log when it should log
                    print('global step %s: loss = %.4f (%.3f sec/step)' %
                          (str(step).rjust(len(str(train_epoch)), '0'),
                           np.mean(total_loss), np.mean(time_elapsed)))
                    train_writer.add_summary(log, step)
                    time_elapsed = []
                    total_loss = []

                elif (step * batch_size) % dataset_len == 0:
                    train_writer.add_summary(log, step)

            ### Save variables to use for something
            var = {}
            variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES) + tf.get_collection(
                    'BN_collection')
            for v in variables:
                var[v.name[:-2]] = sess.run(v)
            sio.savemat(FLAGS.train_dir + '/train_params.mat', var)
            if FLAGS.main_scope == 'Teacher':
                sio.savemat(home_path + '/pre_trained/%s.mat' % model_name,
                            var)

            ### close all
            print('Finished training! Saving model to disk.')
            train_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.STOP))
            train_writer.close()
def main(_):
    """Build the graph, train for ``train_epoch`` epochs and dump the results.

    Reads ``FLAGS.dataset``, ``FLAGS.model_name``, ``FLAGS.main_scope``,
    ``FLAGS.Distillation`` and ``FLAGS.train_dir``. Validation is run every
    time a new epoch boundary is crossed. When training ends, the variables
    whose top-level scope equals ``FLAGS.main_scope`` are written to
    ``train_params.mat`` and the accuracy history to ``log.mat``, both inside
    ``FLAGS.train_dir``.

    Args:
        _: unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm against the entry point).
    """
    ### define path and hyper-parameter
    Learning_rate = 1e-1

    batch_size = 128
    val_batch_size = 200
    train_epoch = 200
    weight_decay = 5e-4

    should_log = 100
    save_summaries_secs = 20
    tf.logging.set_verbosity(tf.logging.INFO)
    gpu_num = '0'

    # the flag carries the literal string 'None'; turn it into a real None
    if FLAGS.Distillation == 'None':
        FLAGS.Distillation = None

    train_images, train_labels, val_images, val_labels, pre_processing, teacher = Dataloader(
        FLAGS.dataset, home_path, FLAGS.model_name)
    num_label = int(np.max(train_labels) + 1)

    # optional sub-sampling: keep the leading `rate` fraction of every class
    rate = 1.
    if rate < 1:
        def leading_fraction(x, rate):
            # named so that it does not shadow the builtin `slice`
            return x[:int(x.shape[0] * rate)]
        idxes = np.hstack([leading_fraction(np.where(train_labels == i)[0], rate)
                           for i in range(num_label)])
        train_images = train_images[idxes]
        train_labels = train_labels[idxes]

    dataset_len, *image_size = train_images.shape

    with tf.Graph().as_default() as graph:
        # make placeholder for inputs
        image_ph = tf.placeholder(tf.uint8, [None] + image_size)
        label_ph = tf.placeholder(tf.int32, [None])

        is_training_ph = tf.placeholder(tf.bool, [])

        # pre-processing
        image = pre_processing(image_ph, is_training_ph)
        label = tf.contrib.layers.one_hot_encoding(label_ph, num_label, on_value=1.0)

        # make global step
        global_step = tf.train.create_global_step()
        epoch = tf.floor_div(tf.cast(global_step, tf.float32) * batch_size, dataset_len)
        max_number_of_steps = int(dataset_len * train_epoch) // batch_size + 1

        # make learning rate scheduler
        LR = learning_rate_scheduler(Learning_rate, [epoch, train_epoch], [0.3, 0.6, 0.8], 0.2)

        ## load Net
        class_loss, accuracy = MODEL(FLAGS.model_name, FLAGS.main_scope, weight_decay, image, label,
                                     is_training_ph, Distillation=FLAGS.Distillation)

        # make training operator
        train_op = op_util.Optimizer_w_Distillation(class_loss, LR, epoch, global_step, FLAGS.Distillation)

        ## collect summary ops for plotting in tensorboard
        summary_op = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES), name='summary_op')

        ## make placeholder and summary op for training and validation results
        train_acc_place = tf.placeholder(dtype=tf.float32)
        val_acc_place = tf.placeholder(dtype=tf.float32)
        val_summary = [tf.summary.scalar('accuracy/training_accuracy', train_acc_place),
                       tf.summary.scalar('accuracy/validation_accuracy', val_acc_place)]
        val_summary_op = tf.summary.merge(list(val_summary), name='val_summary_op')

        ## start training
        train_writer = tf.summary.FileWriter('%s' % FLAGS.train_dir, graph, flush_secs=save_summaries_secs)
        config = ConfigProto()
        config.gpu_options.visible_device_list = gpu_num
        config.gpu_options.allow_growth = True

        # NOTE: integer division -- validation samples beyond
        # val_itr * val_batch_size are never evaluated
        val_itr = len(val_labels) // val_batch_size
        logs = {'training_acc': [], 'validation_acc': []}
        with tf.Session(config=config) as sess:
            if FLAGS.Distillation is not None:
                # Replace the initial value / initializer op of every variable
                # that has an entry in `teacher`. This touches private
                # attributes of the TF1 variable; presumably the initializer
                # run below then loads the teacher weights -- verify when
                # upgrading TensorFlow.
                global_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                n = 0
                for v in global_variables:
                    if teacher.get(v.name[:-2]) is not None:
                        v._initial_value = tf.constant(teacher[v.name[:-2]].reshape(*v.get_shape().as_list()))
                        v._initializer_op = tf.assign(v._variable, v._initial_value,
                                                      name=v.name[:-2] + '/Assign').op
                        n += 1
                print('%d Teacher params assigned' % n)
            sess.run(tf.global_variables_initializer())

            sum_train_accuracy = []
            time_elapsed = []
            total_loss = []
            # queue of shuffled sample indices; refilled before it runs dry
            idx = list(range(train_labels.shape[0]))
            shuffle(idx)
            epoch_ = 0

            # `step` is the 1-based count of completed updates
            for step in range(1, max_number_of_steps + 1):
                start_time = time.time()

                ## feed data
                batch_idx = idx[:batch_size]
                tl, log, train_acc = sess.run([train_op, summary_op, accuracy],
                                              feed_dict={image_ph: train_images[batch_idx],
                                                         label_ph: np.squeeze(train_labels[batch_idx]),
                                                         is_training_ph: True})
                time_elapsed.append(time.time() - start_time)
                total_loss.append(tl)
                sum_train_accuracy.append(train_acc)
                del idx[:batch_size]
                if len(idx) < batch_size:
                    idx_ = list(range(train_labels.shape[0]))
                    shuffle(idx_)
                    idx += idx_

                if (step * batch_size) // dataset_len >= epoch_:
                    ## do validation
                    sum_val_accuracy = []
                    for i in range(val_itr):
                        lo, hi = i * val_batch_size, (i + 1) * val_batch_size
                        acc = sess.run(accuracy, feed_dict={image_ph: val_images[lo:hi],
                                                            label_ph: np.squeeze(val_labels[lo:hi]),
                                                            is_training_ph: False})
                        sum_val_accuracy.append(acc)

                    sum_train_accuracy = np.mean(sum_train_accuracy) * 100
                    sum_val_accuracy = np.mean(sum_val_accuracy) * 100
                    tf.logging.info('Epoch %s Step %s - train_Accuracy : %.2f%%  val_Accuracy : %.2f%%'
                                    % (str(epoch_).rjust(3, '0'), str(step).rjust(6, '0'),
                                       sum_train_accuracy, sum_val_accuracy))

                    result_log = sess.run(val_summary_op, feed_dict={train_acc_place: sum_train_accuracy,
                                                                     val_acc_place: sum_val_accuracy})
                    logs['training_acc'].append(sum_train_accuracy)
                    logs['validation_acc'].append(sum_val_accuracy)

                    # the very last evaluation is filed under `train_epoch`
                    if step == max_number_of_steps:
                        train_writer.add_summary(result_log, train_epoch)
                    else:
                        train_writer.add_summary(result_log, epoch_)
                    sum_train_accuracy = []

                    epoch_ += 1

                if step % should_log == 0:
                    tf.logging.info('global step %s: loss = %.4f (%.3f sec/step)', str(step).rjust(6, '0'),
                                    np.mean(total_loss), np.mean(time_elapsed))
                    train_writer.add_summary(log, step)
                    time_elapsed = []
                    total_loss = []

                elif (step * batch_size) % dataset_len == 0:
                    train_writer.add_summary(log, step)

            ## save variables to use for something
            var = {}
            variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + tf.get_collection('BN_collection')
            for v in variables:
                # keep only variables whose top-level scope is the trained one
                if v.name.split('/')[0] == FLAGS.main_scope:
                    var[v.name[:-2]] = sess.run(v)

            sio.savemat(FLAGS.train_dir + '/train_params.mat', var)
            sio.savemat(FLAGS.train_dir + '/log.mat', logs)

            ## close all
            tf.logging.info('Finished training! Saving model to disk.')
            train_writer.add_session_log(tf.SessionLog(status=tf.SessionLog.STOP))
            train_writer.close()
# Example no. 5
# 0
def main(_):
    """Train a ResNet (optionally with distillation) and checkpoint hint layers.

    Reads ``FLAGS.dataset``, ``FLAGS.main_scope``, ``FLAGS.Distillation``,
    ``FLAGS.hintLayerIndex``, ``FLAGS.guidedLayerIndex`` and
    ``FLAGS.train_dir``. For the distillation methods listed below,
    ``init_epoch`` extra epochs are run first; validation starts once those
    are over and is then repeated at every epoch boundary.

    At the end the variables whose name starts with ``"Teacher"`` are written
    to ``train_params.mat`` (as arrays) and to ``TeacherHint.ckpt``; the ones
    starting with ``"Student_w_FitNet"`` go to ``studentGuided.ckpt``.

    Args:
        _: unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm against the entry point).
    """
    ### define path and hyper-parameter
    model_name = 'ResNet'
    Learning_rate = 1e-1

    batch_size = 128
    val_batch_size = 200
    train_epoch = 10
    init_epoch = 4 if FLAGS.Distillation in {
        'FitNet', 'FSP', 'FT', 'AB', 'MHGD'
    } else 0

    total_epoch = init_epoch + train_epoch
    weight_decay = 5e-4

    should_log = 200
    save_summaries_secs = 20
    tf.logging.set_verbosity(tf.logging.INFO)
    gpu_num = '0'

    # the flag carries the literal string 'None'; turn it into a real None
    if FLAGS.Distillation == 'None':
        FLAGS.Distillation = None

    train_images, train_labels, val_images, val_labels, pre_processing, teacher = Dataloader(
        FLAGS.dataset, home_path)
    num_label = int(np.max(train_labels) + 1)

    dataset_len, *image_size = train_images.shape

    with tf.Graph().as_default() as graph:

        # make placeholder for inputs
        image_ph = tf.placeholder(tf.uint8, [None] + image_size)
        label_ph = tf.placeholder(tf.int32, [None])
        is_training_ph = tf.placeholder(tf.bool, [])

        # pre-processing
        image = pre_processing(image_ph, is_training_ph)
        label = tf.contrib.layers.one_hot_encoding(label_ph,
                                                   num_label,
                                                   on_value=1.0)

        # make global step
        global_step = tf.train.create_global_step()
        epoch = tf.floor_div(
            tf.cast(global_step, tf.float32) * batch_size, dataset_len)
        max_number_of_steps = int(dataset_len * total_epoch) // batch_size + 1

        # make learning rate scheduler
        LR = learning_rate_scheduler(Learning_rate,
                                     [epoch, init_epoch, train_epoch],
                                     [0.3, 0.6, 0.8], 0.1)

        ## load Net
        class_loss, accuracy = MODEL(model_name,
                                     FLAGS.main_scope,
                                     weight_decay,
                                     image,
                                     label,
                                     is_training_ph,
                                     reuse=False,
                                     drop=True,
                                     Distillation=FLAGS.Distillation,
                                     hintLayerIndex=FLAGS.hintLayerIndex,
                                     guidedLayerIndex=FLAGS.guidedLayerIndex)

        # make training operator
        train_op = op_util.Optimizer_w_Distillation(class_loss, LR, epoch,
                                                    init_epoch, global_step,
                                                    FLAGS.Distillation)

        ## collect summary ops for plotting in tensorboard
        summary_op = tf.summary.merge(tf.get_collection(
            tf.GraphKeys.SUMMARIES),
                                      name='summary_op')

        ## make placeholder and summary op for training and validation results
        train_acc_place = tf.placeholder(dtype=tf.float32)
        val_acc_place = tf.placeholder(dtype=tf.float32)
        val_summary = [
            tf.summary.scalar('accuracy/training_accuracy', train_acc_place),
            tf.summary.scalar('accuracy/validation_accuracy', val_acc_place)
        ]
        val_summary_op = tf.summary.merge(list(val_summary),
                                          name='val_summary_op')

        ## start training
        train_writer = tf.summary.FileWriter('%s' % FLAGS.train_dir,
                                             graph,
                                             flush_secs=save_summaries_secs)
        config = ConfigProto()
        config.gpu_options.visible_device_list = gpu_num
        config.gpu_options.allow_growth = True

        # NOTE: integer division -- validation samples beyond
        # val_itr * val_batch_size are never evaluated
        val_itr = len(val_labels) // val_batch_size
        with tf.Session(config=config) as sess:
            # fetched unconditionally: the Momentum reset below needs it too
            global_variables = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES)
            if FLAGS.Distillation is not None and FLAGS.Distillation != 'DML':
                # Replace the initial value / initializer op of every variable
                # that has an entry in `teacher`. The private attribute is
                # `_initializer_op`; writing to `initializer_op` (no leading
                # underscore) only creates an unrelated attribute. Presumably
                # the initializer run below then loads the teacher weights --
                # verify when upgrading TensorFlow.
                n = 0
                for v in global_variables:
                    if teacher.get(v.name[:-2]) is not None:
                        v._initial_value = tf.constant(
                            teacher[v.name[:-2]].reshape(
                                *v.get_shape().as_list()))
                        v._initializer_op = tf.assign(v._variable,
                                                      v._initial_value,
                                                      name=v.name[:-2] +
                                                      'Assign').op
                        n += 1
                print('%d Teacher params assigned' % n)
            sess.run(tf.global_variables_initializer())

            sum_train_accuracy = []
            time_elapsed = []
            total_loss = []
            # queue of shuffled sample indices; refilled before it runs dry
            idx = list(range(train_labels.shape[0]))
            shuffle(idx)
            epoch_ = 0

            # `step` is the 1-based count of completed updates
            for step in range(1, max_number_of_steps + 1):
                start_time = time.time()
                ## feed data
                batch_idx = idx[:batch_size]
                tl, log, train_acc = sess.run(
                    [train_op, summary_op, accuracy],
                    feed_dict={
                        image_ph: train_images[batch_idx],
                        label_ph: np.squeeze(train_labels[batch_idx]),
                        is_training_ph: True
                    })

                time_elapsed.append(time.time() - start_time)
                total_loss.append(tl)
                sum_train_accuracy.append(train_acc)
                del idx[:batch_size]
                if len(idx) < batch_size:
                    idx_ = list(range(train_labels.shape[0]))
                    shuffle(idx_)
                    idx += idx_

                finished_epochs = (step * batch_size) // dataset_len
                if finished_epochs >= init_epoch + epoch_:
                    ## do validation
                    sum_val_accuracy = []
                    for i in range(val_itr):
                        lo, hi = i * val_batch_size, (i + 1) * val_batch_size
                        acc = sess.run(accuracy,
                                       feed_dict={
                                           image_ph: val_images[lo:hi],
                                           label_ph: np.squeeze(
                                               val_labels[lo:hi]),
                                           is_training_ph: False
                                       })
                        sum_val_accuracy.append(acc)

                    # until the first epoch after the init phase has finished,
                    # the placeholder value 1. is reported as training accuracy
                    if finished_epochs > init_epoch:
                        sum_train_accuracy = np.mean(sum_train_accuracy) * 100
                    else:
                        sum_train_accuracy = 1.
                    sum_val_accuracy = np.mean(sum_val_accuracy) * 100
                    tf.logging.info(
                        'Epoch %s Step %s - train_Accuracy : %.2f%%  val_Accuracy : %.2f%%'
                        % (str(epoch_).rjust(3, '0'), str(step).rjust(
                            6, '0'), sum_train_accuracy, sum_val_accuracy))

                    result_log = sess.run(val_summary_op,
                                          feed_dict={
                                              train_acc_place:
                                              sum_train_accuracy,
                                              val_acc_place: sum_val_accuracy
                                          })

                    if init_epoch > 0 and finished_epochs == init_epoch:
                        # re-initialize Momentum for fair comparison w/ initialization and multi-task learning methods
                        for v in global_variables:
                            # match on the name *suffix*; the shape goes to
                            # np.zeros as one list, not as unpacked arguments
                            if v.name.endswith('Momentum:0'):
                                zeros = np.zeros(
                                    v.get_shape().as_list(),
                                    dtype=v.dtype.base_dtype.as_numpy_dtype)
                                sess.run(v.assign(zeros))

                    # the very last evaluation is filed under `train_epoch`
                    if step == max_number_of_steps:
                        train_writer.add_summary(result_log, train_epoch)
                    else:
                        train_writer.add_summary(result_log, epoch_)
                    sum_train_accuracy = []

                    epoch_ += 1

                if step % should_log == 0:
                    tf.logging.info(
                        'global step %s: loss = %.4f (%.3f sec/step)',
                        str(step).rjust(6, '0'), np.mean(total_loss),
                        np.mean(time_elapsed))
                    train_writer.add_summary(log, step)
                    time_elapsed = []
                    total_loss = []

                elif (step * batch_size) % dataset_len == 0:
                    train_writer.add_summary(log, step)

            ## save variables to use for something
            hintLayerIndex = FLAGS.hintLayerIndex
            guidedLayerIndex = FLAGS.guidedLayerIndex
            variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES) + tf.get_collection(
                    'BN_collection')

            # Teacher variables, counted from 1 within their own group; the
            # one at position hintLayerIndex + 1 is stored as "hintLayerNext".
            var = {}
            teacher_vars = [v for v in variables
                            if v.name.startswith("Teacher")]
            for position, v in enumerate(teacher_vars, start=1):
                if position == hintLayerIndex + 1:
                    var["hintLayerNext"] = v
                else:
                    var[v.name[:-2]] = v

            # saving the student layer only, with its own counter; the one at
            # position guidedLayerIndex is stored as "guidedLayerNext".
            varToHin = {}
            student_vars = [v for v in variables
                            if v.name.startswith("Student_w_FitNet")]
            for position, v_itm in enumerate(student_vars, start=1):
                if position == guidedLayerIndex:
                    varToHin["guidedLayerNext"] = v_itm
                else:
                    varToHin[v_itm.name[:-2]] = v_itm

            print(var)
            print(varToHin)

            # the .mat file needs evaluated arrays, not tf.Variable objects
            sio.savemat(FLAGS.train_dir + '/train_params.mat',
                        {name: sess.run(v) for name, v in var.items()})

            # a Saver cannot be built from an empty dict, so skip empty groups
            # instead of aborting before the writer is closed
            if var:
                save_path = tf.train.Saver(var).save(sess, "TeacherHint.ckpt")
                print("Teacher saved in path: %s" % save_path)
            else:
                tf.logging.warning(
                    'No Teacher variables found; TeacherHint.ckpt not written.')
            if varToHin:
                save_pathToHint = tf.train.Saver(varToHin).save(
                    sess, "studentGuided.ckpt")
                print("Student saved in path: %s" % save_pathToHint)
            else:
                tf.logging.warning(
                    'No Student_w_FitNet variables found; '
                    'studentGuided.ckpt not written.')

            ## close all
            tf.logging.info('Finished training! Saving model to disk.')
            train_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.STOP))
            train_writer.close()