Example #1
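
The four examples below are Python 2 / TensorFlow 1.x entry points from the same repository; the model builders (baseline_model, clustering_model, distilled_model, hybrid_model), the data readers, the config module cg, and the batch iterator are imported from elsewhere and not shown. Each script is driven by docopt, whose __doc__ usage string is also omitted; a minimal string consistent with the positional arguments read in Example #1 might look like this (the script name is a placeholder):

"""Train a distilled model from a cumbersome source model.

Usage:
    train_distilled.py <distilled_model_id> <cumbersome_model_id> <lambda> <temperature>
"""
from docopt import docopt
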
def main():
    args = docopt(__doc__)
    lambda_ = args['<lambda>']
    temperature = args['<temperature>']
    param = getattr(cg, args['<distilled_model_id>'])(
        lambda_=float(lambda_), temperature=float(temperature))

    if param['resume_training']:
        param['exp_id'] = param['resume_exp_id']
    else:
        param['exp_id'] = args['<distilled_model_id>'] + '_l' \
                          + lambda_.replace('.', '-') + '_t' + temperature \
                          + '_' + time.strftime("%Y-%b-%d-%H-%M-%S")

    param['save_folder'] = os.path.join(param['save_path'], param['exp_id'])
    param_cumb = getattr(cg, args['<cumbersome_model_id>'])()

    if param['dataset_name'] != param_cumb['dataset_name']:
        raise ValueError(
            'Distilled model must use same dataset as source model')

    if param['dataset_name'] not in ['CIFAR10', 'CIFAR100']:
        raise ValueError('Unsupported dataset name!')

    # read data from file (validated above, so the else branch is CIFAR100)
    param['denom_const'] = 255.0
    if param['dataset_name'] == 'CIFAR10':
        input_data = read_CIFAR10(param['data_folder'])
    else:
        input_data = read_CIFAR100(param['data_folder'])
    print 'Reading data done!'

    # save parameters
    if not os.path.isdir(param['save_folder']):
        os.mkdir(param['save_folder'])

    with open(os.path.join(param['save_folder'], 'hyper_param.txt'), 'w') as f:
        for key, value in param.iteritems():
            f.write('{}: {}\n'.format(key, value))

    if param['model_name'] in ['hybrid_spatial', 'hybrid_sample']:
        param['num_layer_cnn'] = len(
            [xx for xx in param['num_cluster_cnn'] if xx])
        param['num_layer_mlp'] = len(
            [xx for xx in param['num_cluster_mlp'] if xx])
        param['num_cluster'] = param['num_cluster_cnn'] \
                               + param['num_cluster_mlp']
        num_layer_reg = param['num_layer_cnn'] + param['num_layer_mlp']

        param['num_layer_reg'] = num_layer_reg
        hist_label = [np.zeros(xx) if xx is not None else None
                      for xx in param['num_cluster']]
        reg_val = np.zeros(num_layer_reg)

    # build cumbersome model
    if param_cumb['model_name'] == 'baseline':
        cumb_model_ops = baseline_model(param_cumb)
    elif param_cumb['model_name'] == 'parsimonious':
        cumb_model_ops = clustering_model(param_cumb)
    else:
        raise ValueError('Unsupported cumbersome model')
    cumb_op_names = ['logits']
    cumb_ops = [cumb_model_ops[i] for i in cumb_op_names]
    cumb_vars = tf.global_variables()
    print 'Rebuilding cumbersome model done!'

    # restore session of cumbersome model
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    saver_cumb = tf.train.Saver(var_list=cumb_vars)
    saver_cumb.restore(sess, os.path.join(
        param_cumb['test_folder'], param_cumb['test_model_name']))
    print 'Restoring cumbersome model done!'

    # build distilled model
    if param['model_name'] == 'distilled':
        with tf.variable_scope('dist'):
            model_ops = distilled_model(param)

        # initialize the variables of the new distilled model
        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope='dist')
        sess.run(tf.variables_initializer(var_list))
    elif param['model_name'] in ['hybrid_spatial', 'hybrid_sample']:
        with tf.variable_scope('hybrid'):
            model_ops = hybrid_model(param)
        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope='hybrid')
        sess.run(tf.variables_initializer(var_list))
    else:
        raise ValueError('Unsupported model name!')

    saver = tf.train.Saver(var_list=var_list)

    train_op_names = ['train_step', 'loss']
    val_op_names = ['scaled_logits']
    train_ops = [model_ops[i] for i in train_op_names]
    val_ops = [model_ops[i] for i in val_op_names]
    print 'Building new model done!\n'

    num_train_img = input_data['train_img'].shape[0]
    num_val_img = input_data['test_img'].shape[0]
    # float division: plain `/` on ints floors in Python 2, so the ceil
    # would otherwise drop the final partial batch
    epoch_iter = int(math.ceil(num_train_img / float(param['bat_size'])))
    max_val_iter = int(math.ceil(num_val_img / float(param['bat_size'])))
    train_iterator = MiniBatchIterator(
        idx_start=0, bat_size=param['bat_size'], num_sample=num_train_img,
        train_phase=True, is_permute=True)
    val_iterator = MiniBatchIterator(
        idx_start=0, bat_size=param['bat_size'], num_sample=num_val_img,
        train_phase=False, is_permute=False)

    train_iter_start = 0

    for train_iter in xrange(train_iter_start, param['max_train_iter']):
        # generate a batch
        idx_train_bat = train_iterator.get_batch()

        bat_imgs = (input_data['train_img'][idx_train_bat, :, :, :].astype(
            np.float32) - input_data['mean_img']) / param['denom_const']
        bat_labels = input_data['train_label'][idx_train_bat].astype(np.int32)

        feed_data = {
            cumb_model_ops['input_images']: bat_imgs,
            cumb_model_ops['input_labels']: bat_labels
        }

        # get logits from cumbersome model
        source_model_logits = sess.run(
            cumb_model_ops['logits'], feed_dict=feed_data)

        feed_data = {
            model_ops['input_images']: bat_imgs,
            model_ops['input_labels']: bat_labels,
            model_ops['source_model_logits']: source_model_logits
        }

        if param['model_name'] == 'distilled':
            results = sess.run(train_ops, feed_dict=feed_data)

            train_results = {}
            for res, name in zip(results, train_op_names):
                train_results[name] = res

            loss = train_results['loss']

        elif param['model_name'] in ['hybrid_spatial', 'hybrid_sample']:
            feed_data[model_ops['input_eta']] = param['eta']

            # deal with drifted clusters
            if (train_iter + 1) % epoch_iter == 0:
                update_cluster_centers(
                    sess, input_data, model_ops, hist_label, train_iterator,
                    param)

            # get CE/Reg values
            results = sess.run([model_ops['loss']] + model_ops['reg_ops'] +
                               model_ops['cluster_label'], feed_dict=feed_data)
            loss = results[0]
            for ii in xrange(num_layer_reg):
                reg_val[ii] = results[1 + ii]

            cluster_label = results[1 + num_layer_reg:]

            cluster_idx = 0
            for ii, xx in enumerate(param['num_cluster']):
                if xx:
                    tmp_label = cluster_label[cluster_idx]

                    for jj in xrange(tmp_label.shape[0]):
                        hist_label[ii][tmp_label[jj]] += 1

                    cluster_idx += 1

            # run clustering every iteration
            for iter_clustering in xrange(param['clustering_iter']):
                sess.run(model_ops['clustering_ops'], feed_dict=feed_data)

            if (train_iter + 1) % epoch_iter == 0:
                for ii in xrange(len(hist_label)):
                    if hist_label[ii] is not None:
                        hist_label[ii].fill(0)

            # run optimization
            sess.run(model_ops['train_step'], feed_dict=feed_data)

        # display statistic
        if (train_iter + 1) % param['disp_iter'] == 0 or train_iter == 0:
            disp_str = 'Train Step = {:06d} || CE loss = {:e}'.format(
                train_iter + 1, loss)

            if param['model_name'] in ['hybrid_spatial', 'hybrid_sample']:
                disp_str += ' || Clustering '
                for ii in xrange(num_layer_reg):
                    disp_str += 'Reg_{:d} = {:e} '.format(ii + 1, reg_val[ii])

            print disp_str

        # valid model
        if (train_iter + 1) % param['valid_iter'] == 0 or train_iter == 0:
            num_correct = 0.0

            if param['resume_training']:
                print 'Resume Exp ID = {}'.format(param['exp_id'])
            else:
                print 'Exp ID = {}'.format(param['exp_id'])

            for val_iter in xrange(max_val_iter):
                idx_val_bat = val_iterator.get_batch()

                bat_imgs = (input_data['test_img'][idx_val_bat, :, :, :].astype(
                    np.float32) - input_data['mean_img']) / param['denom_const']
                bat_labels = input_data['test_label'][
                    idx_val_bat].astype(np.int32)

                feed_data[model_ops['input_images']] = bat_imgs
                feed_data[model_ops['input_labels']] = bat_labels

                results = sess.run(val_ops, feed_dict=feed_data)

                val_results = {}
                for res, name in zip(results, val_op_names):
                    val_results[name] = res

                pred_label = np.argmax(val_results['scaled_logits'], axis=1)
                num_correct += np.sum(np.equal(pred_label,
                                               bat_labels).astype(np.float32))

            val_acc = (num_correct / num_val_img)
            print "Val accuracy = {:3f}".format(val_acc * 100)

        # snapshot a model
        if (train_iter + 1) % param['save_iter'] == 0:
            saver.save(sess, os.path.join(
                param['save_folder'], '{}_snapshot_{:07d}.ckpt'.format(
                    param['model_name'], train_iter + 1)))

    sess.close()
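
Example #1 runs the restored cumbersome model on each batch and feeds its logits into the distilled graph via the source_model_logits placeholder. The repository's distilled_model is not shown here; as a sketch, its loss presumably follows standard knowledge distillation, blending a temperature-softened cross-entropy against the teacher with the hard-label loss via lambda_ (an assumption, not the actual implementation):

import tensorflow as tf

def soft_target_loss(logits, source_model_logits, labels,
                     temperature, lambda_):
    # teacher probabilities softened at temperature T (assumed form)
    soft_targets = tf.nn.softmax(source_model_logits / temperature)
    soft_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=soft_targets, logits=logits / temperature))
    # standard cross-entropy against the hard labels at T = 1
    hard_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits))
    # the T^2 factor keeps the soft-gradient scale comparable across T
    return lambda_ * (temperature ** 2) * soft_loss \
        + (1.0 - lambda_) * hard_loss
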

Example #2
def main():
    # get exp parameters
    args = docopt(__doc__)
    param = getattr(cg, args['<exp_id>'])()

    if param['resume_training']:
        param['exp_id'] = param['resume_exp_id']
    else:
        param['exp_id'] = args['<exp_id>'] + '_' + \
            time.strftime("%Y-%b-%d-%H-%M-%S")

    param['save_folder'] = os.path.join(param['save_path'], param['exp_id'])

    # save parameters
    if not os.path.isdir(param['save_folder']):
        os.mkdir(param['save_folder'])

    with open(os.path.join(param['save_folder'], 'hyper_param.txt'), 'w') as f:
        for key, value in param.iteritems():
            f.write('{}: {}\n'.format(key, value))

    if param['model_name'] == 'parsimonious':
        param['num_layer_cnn'] = len(
            [xx for xx in param['num_cluster_cnn'] if xx])
        param['num_layer_mlp'] = len(
            [xx for xx in param['num_cluster_mlp'] if xx])
        param['num_cluster'] = param['num_cluster_cnn'] + param[
            'num_cluster_mlp']
        num_layer_reg = param['num_layer_cnn'] + param['num_layer_mlp']

        param['num_layer_reg'] = num_layer_reg
        hist_label = [
            np.zeros(xx) if xx is not None else None
            for xx in param['num_cluster']
        ]
        reg_val = np.zeros(num_layer_reg)

    # read data from file
    if param['dataset_name'] not in ['CIFAR10', 'CIFAR100']:
        raise ValueError('Unsupported dataset name!')

    param['denom_const'] = 255.0
    if param['dataset_name'] == 'CIFAR10':
        input_data = read_CIFAR10(param['data_folder'])
    else:
        input_data = read_CIFAR100(param['data_folder'])

    print 'Reading data done!'

    # build model
    if param['model_name'] == 'baseline':
        model_ops = baseline_model(param)
    elif param['model_name'] == 'parsimonious':
        model_ops = clustering_model(param)
    else:
        raise ValueError('Unsupported model name!')

    train_op_names = ['train_step', 'CE_loss']
    val_op_names = ['scaled_logits']
    train_ops = [model_ops[i] for i in train_op_names]
    val_ops = [model_ops[i] for i in val_op_names]
    print 'Building model done!'

    # run model
    if param['merge_valid']:
        input_data['train_img'] = np.concatenate(
            [input_data['train_img'], input_data['val_img']], axis=0)
        input_data['train_label'] = np.concatenate(
            [input_data['train_label'], input_data['val_label']])

    num_train_img = input_data['train_img'].shape[0]
    num_val_img = input_data['test_img'].shape[0]
    epoch_iter = int(math.ceil(num_train_img / float(param['bat_size'])))
    max_val_iter = int(math.ceil(num_val_img / float(param['bat_size'])))
    train_iterator = MiniBatchIterator(idx_start=0,
                                       bat_size=param['bat_size'],
                                       num_sample=num_train_img,
                                       train_phase=True,
                                       is_permute=True)
    val_iterator = MiniBatchIterator(idx_start=0,
                                     bat_size=param['bat_size'],
                                     num_sample=num_val_img,
                                     train_phase=False,
                                     is_permute=False)

    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)

    train_iter_start = 0
    if param['resume_training']:
        saver.restore(
            sess, os.path.join(param['save_folder'],
                               param['resume_model_name']))
        train_iter_start = param['resume_step']
    else:
        sess.run(tf.global_variables_initializer())

    print 'Graph initialization done!'

    for train_iter in xrange(train_iter_start, param['max_train_iter']):
        # generate a batch
        idx_train_bat = train_iterator.get_batch()

        bat_imgs = (input_data['train_img'][idx_train_bat, :, :, :].astype(
            np.float32) - input_data['mean_img']) / param['denom_const']
        bat_labels = input_data['train_label'][idx_train_bat].astype(np.int32)

        feed_data = {
            model_ops['input_images']: bat_imgs,
            model_ops['input_labels']: bat_labels
        }

        # run a batch
        if param['model_name'] == 'baseline':
            results = sess.run(train_ops, feed_dict=feed_data)

            train_results = {}
            for res, name in zip(results, train_op_names):
                train_results[name] = res

            CE_loss = train_results['CE_loss']

        elif param['model_name'] == 'parsimonious':
            feed_data[model_ops['input_eta']] = param['eta']

            # deal with drifted clusters
            if (train_iter + 1) % epoch_iter == 0:
                update_cluster_centers(sess, input_data, model_ops, hist_label,
                                       train_iterator, param)

            # get CE/Reg values
            results = sess.run([model_ops['CE_loss']] + model_ops['reg_ops'] +
                               model_ops['cluster_label'],
                               feed_dict=feed_data)
            CE_loss = results[0]
            for ii in xrange(num_layer_reg):
                reg_val[ii] = results[1 + ii]

            cluster_label = results[1 + num_layer_reg:]

            cluster_idx = 0
            for ii, xx in enumerate(param['num_cluster']):
                if xx:
                    tmp_label = cluster_label[cluster_idx]

                    for jj in xrange(tmp_label.shape[0]):
                        hist_label[ii][tmp_label[jj]] += 1

                    cluster_idx += 1

            # run clustering every iteration
            for iter_clustering in xrange(param['clustering_iter']):
                sess.run(model_ops['clustering_ops'], feed_dict=feed_data)

            if (train_iter + 1) % epoch_iter == 0:
                for ii in xrange(len(hist_label)):
                    if hist_label[ii] is not None:
                        hist_label[ii].fill(0)

            # run optimization
            sess.run(model_ops['train_step'], feed_dict=feed_data)

        # display statistic
        if (train_iter + 1) % param['disp_iter'] == 0 or train_iter == 0:
            disp_str = 'Train Step = {:06d} || CE loss = {:e}'.format(
                train_iter + 1, CE_loss)

            if param['model_name'] == 'parsimonious':
                disp_str += ' || Clustering '
                for ii in xrange(num_layer_reg):
                    disp_str += 'Reg_{:d} = {:e} '.format(ii + 1, reg_val[ii])

            print disp_str

        # valid model
        if (train_iter + 1) % param['valid_iter'] == 0 or train_iter == 0:
            num_correct = 0.0

            if param['resume_training']:
                print 'Resume Exp ID = {}'.format(param['exp_id'])
            else:
                print 'Exp ID = {}'.format(param['exp_id'])

            for val_iter in xrange(max_val_iter):
                idx_val_bat = val_iterator.get_batch()

                bat_imgs = (input_data['test_img'][
                    idx_val_bat, :, :, :].astype(np.float32) -
                            input_data['mean_img']) / param['denom_const']
                bat_labels = input_data['test_label'][idx_val_bat].astype(
                    np.int32)

                feed_data[model_ops['input_images']] = bat_imgs
                feed_data[model_ops['input_labels']] = bat_labels

                results = sess.run(val_ops, feed_dict=feed_data)

                val_results = {}
                for res, name in zip(results, val_op_names):
                    val_results[name] = res

                pred_label = np.argmax(val_results['scaled_logits'], axis=1)
                num_correct += np.sum(
                    np.equal(pred_label, bat_labels).astype(np.float32))

            val_acc = (num_correct / num_val_img)
            print "Val accuracy = {:3f}".format(val_acc * 100)

        # snapshot a model
        if (train_iter + 1) % param['save_iter'] == 0:
            saver.save(
                sess,
                os.path.join(
                    param['save_folder'],
                    '{}_snapshot_{:07d}.ckpt'.format(param['model_name'],
                                                     train_iter + 1)))

    sess.close()
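
Both training loops above draw index batches from MiniBatchIterator, which is defined elsewhere in the repository. A minimal sketch that matches the constructor calls and the get_batch() usage in these scripts (an assumption about the original's behavior):

import numpy as np

class MiniBatchIterator(object):

    def __init__(self, idx_start=0, bat_size=128, num_sample=0,
                 train_phase=True, is_permute=True):
        self.bat_size = bat_size
        self.num_sample = num_sample
        self.train_phase = train_phase
        self.is_permute = is_permute
        self.idx = idx_start
        # visit samples in random order when permutation is requested
        self.order = (np.random.permutation(num_sample)
                      if is_permute else np.arange(num_sample))

    def get_batch(self):
        # wrap around (and reshuffle, if permuting) at the end of a pass
        if self.idx >= self.num_sample:
            self.idx = 0
            if self.is_permute:
                self.order = np.random.permutation(self.num_sample)
        idx_bat = self.order[self.idx:self.idx + self.bat_size]
        self.idx += self.bat_size
        return idx_bat
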

Example #3
def main():
    # get exp parameters
    args = docopt(__doc__)
    param = getattr(cg, args['<exp_id>'])()

    # read data from file
    if param['dataset_name'] == 'CIFAR10':
        input_data = read_CIFAR10(param['data_folder'])
    elif param['dataset_name'] == 'CIFAR100':
        input_data = read_CIFAR100(param['data_folder'])
    else:
        raise ValueError('Unsupported dataset name!')
    print 'Reading data done!'

    # build model
    test_op_names = ['embeddings']

    if param['model_name'] == 'baseline':
        model_ops = baseline_model(param)
    elif param['model_name'] == 'parsimonious':
        model_ops = clustering_model(param)
    elif param['model_name'] == 'distilled':
        with tf.variable_scope('dist'):
            model_ops = distilled_model(param)
    elif param['model_name'] in ['hybrid_spatial', 'hybrid_sample']:
        with tf.variable_scope('hybrid'):
            model_ops = hybrid_model(param)
    else:
        raise ValueError('Unsupported model name!')

    test_ops = [model_ops[i] for i in test_op_names]
    print 'Building model done!'

    # run model
    input_data['train_img'] = np.concatenate(
        [input_data['train_img'], input_data['val_img']], axis=0)
    input_data['train_label'] = np.concatenate(
        [input_data['train_label'], input_data['val_label']])

    num_train_img = input_data['train_img'].shape[0]
    max_test_iter = int(math.ceil(num_train_img / float(param['bat_size'])))
    test_iterator = MiniBatchIterator(idx_start=0,
                                      bat_size=param['bat_size'],
                                      num_sample=num_train_img,
                                      train_phase=False,
                                      is_permute=False)

    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    saver = tf.train.Saver()
    saver.restore(sess,
                  os.path.join(param['test_folder'], param['test_model_name']))
    print 'Graph initialization done!'

    if param['model_name'] == 'parsimonious':
        param['num_layer_cnn'] = len(
            [xx for xx in param['num_cluster_cnn'] if xx])
        param['num_layer_mlp'] = len(
            [xx for xx in param['num_cluster_mlp'] if xx])
        num_layer_reg = param['num_layer_cnn'] + param['num_layer_mlp']

        cluster_center = sess.run(model_ops['cluster_center'])

    else:
        num_layer_cnn = len(param['act_func_cnn'])
        num_layer_mlp = len(param['act_func_mlp'])
        num_layer_reg = num_layer_cnn + num_layer_mlp
        cluster_center = [None] * num_layer_reg

    embeddings = [[] for _ in xrange(num_layer_reg)]

    for test_iter in xrange(max_test_iter):
        idx_bat = test_iterator.get_batch()

        bat_imgs = (input_data['train_img'][idx_bat, :, :, :].astype(
            np.float32) - input_data['mean_img']) / 255.0

        feed_data = {model_ops['input_images']: bat_imgs}

        results = sess.run(test_ops, feed_dict=feed_data)

        test_results = {}
        for res, name in zip(results, test_op_names):
            test_results[name] = res

        # only layers from index 3 onwards are collected and clustered
        for ii, ee in enumerate(test_results['embeddings']):
            if ii < 3:
                continue

            embeddings[ii] += [ee]

    for ii in xrange(num_layer_reg):
        if ii < 3:
            continue

        embeddings[ii] = np.concatenate(embeddings[ii], axis=0)

        # k-means with 100 clusters on the collected embeddings
        centroid, _ = Kmeans(embeddings[ii], 100)
        cluster_center[ii] = centroid

        # assign each sample to its nearest cluster center
        pdist = distance.cdist(cluster_center[ii], embeddings[ii],
                               'sqeuclidean')
        tmp_label = np.argmin(pdist, axis=0)

        NMI = compute_normalized_mutual_information(
            tmp_label, input_data['train_label'].astype(np.int),
            cluster_center[ii].shape[0], param['label_size'])
        print 'NMI = {}'.format(NMI)

    sess.close()
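
compute_normalized_mutual_information(pred, label, num_cluster, num_class) is likewise imported from elsewhere. A self-contained sketch using one common normalization, the arithmetic mean of the two entropies (the repository may normalize differently, e.g. by the geometric mean):

import numpy as np

def compute_normalized_mutual_information(pred, label, num_cluster,
                                          num_class):
    # joint distribution of (cluster assignment, ground-truth class)
    joint = np.zeros((num_cluster, num_class))
    for pp, ll in zip(pred, label):
        joint[pp, ll] += 1.0
    joint /= joint.sum()
    px = joint.sum(axis=1)  # cluster marginal
    py = joint.sum(axis=0)  # class marginal

    nz = joint > 0
    mutual_info = np.sum(
        joint[nz] * np.log(joint[nz] / np.outer(px, py)[nz]))
    h_px = -np.sum(px[px > 0] * np.log(px[px > 0]))
    h_py = -np.sum(py[py > 0] * np.log(py[py > 0]))
    return mutual_info / (0.5 * (h_px + h_py))
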

Example #4
def main():
    # get exp parameters
    args = docopt(__doc__)
    param = getattr(cg, args['<exp_id>'])()

    # read data from file
    param['denom_const'] = 255.0
    if param['dataset_name'] == 'CIFAR10':
        input_data = read_CIFAR10(param['data_folder'])
    elif param['dataset_name'] == 'CIFAR100':
        input_data = read_CIFAR100(param['data_folder'])
    else:
        raise ValueError('Unsupported dataset name!')
    print 'Reading data done!'

    # build model
    test_op_names = ['scaled_logits']

    if param['model_name'] == 'baseline':
        model_ops = baseline_model(param)
    elif param['model_name'] == 'parsimonious':
        model_ops = clustering_model(param)
    elif param['model_name'] == 'distilled':
        with tf.variable_scope('dist'):
            model_ops = distilled_model(param)
    elif param['model_name'] in ['hybrid_spatial', 'hybrid_sample']:
        with tf.variable_scope('hybrid'):
            model_ops = hybrid_model(param)
    else:
        raise ValueError('Unsupported model name!')

    test_ops = [model_ops[i] for i in test_op_names]
    print 'Building model done!'

    # run model
    num_test_img = input_data['test_img'].shape[0]
    max_test_iter = int(math.ceil(num_test_img / float(param['bat_size'])))
    test_iterator = MiniBatchIterator(idx_start=0,
                                      bat_size=param['bat_size'],
                                      num_sample=num_test_img,
                                      train_phase=False,
                                      is_permute=False)

    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)
    saver = tf.train.Saver()
    saver.restore(sess,
                  os.path.join(param['test_folder'], param['test_model_name']))
    print 'Graph initialization done!'

    num_correct = 0.0
    for test_iter in xrange(max_test_iter):
        idx_bat = test_iterator.get_batch()

        bat_imgs = (input_data['test_img'][idx_bat, :, :, :].astype(np.float32)
                    - input_data['mean_img']) / param['denom_const']
        bat_labels = input_data['test_label'][idx_bat].astype(np.int32)

        feed_data = {
            model_ops['input_images']: bat_imgs,
            model_ops['input_labels']: bat_labels
        }

        results = sess.run(test_ops, feed_dict=feed_data)

        test_results = {}
        for res, name in zip(results, test_op_names):
            test_results[name] = res

        pred_label = np.argmax(test_results['scaled_logits'], axis=1)
        num_correct += np.sum(np.equal(pred_label, bat_labels).astype(float))

    test_acc = (num_correct / num_test_img)
    print 'Test accuracy = {:.3f}'.format(test_acc * 100)

    sess.close()
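
All four scripts share one checkpoint convention: the training loops write snapshots named '{}_snapshot_{:07d}.ckpt'.format(model_name, train_iter + 1) into save_folder, and the evaluation scripts restore whichever file param['test_folder'] and param['test_model_name'] point to. For instance, test_model_name = 'parsimonious_snapshot_0050000.ckpt' would select the parsimonious model after 50,000 steps (an illustrative step count, not one taken from the repository).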