Example #1
def train_student(dataset, nb_teachers, shift_dataset, inverse_w=None, weight=True):
  """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :param shift_dataset: dict with the shifted student data ('data'), the ensemble
         predictions used as labels ('pred'), and the true labels ('label')
  :param inverse_w: importance weights (per class, or per example when FLAGS.cov_shift is set)
  :param weight: whether this is an importance weight sampling
  :return: teacher precision and student precision after training
  """
  assert input.create_dir_if_needed(FLAGS.train_dir)

  # Student data and its (noisy ensemble) labels come directly from the shifted dataset
  stdnt_data = shift_dataset['data']
  stdnt_labels = shift_dataset['pred']

  print('number for deep is {}'.format(len(stdnt_labels)))

  if FLAGS.deeper:
    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student_deeper.ckpt' #NOLINT(long-line)
  else:
    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student.ckpt'  # NOLINT(long-line)

  if FLAGS.cov_shift:
    # Under covariate shift the importance weights for the student are precomputed;
    # they may need to be clipped to some bound in case a weight gets too large.
    weights = inverse_w
  else:
    print('len of shift data={}'.format(len(shift_dataset['data'])))
    # Map the per-class importance weight onto each student example via its label
    weights = np.zeros(len(stdnt_data))
    print('len of weights={} len of labels={}'.format(len(weights), len(stdnt_labels)))
    for i in range(len(weights)):
      weights[i] = np.float32(inverse_w[stdnt_labels[i]])

  if weight:
    assert deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path, weights=weights)
  else:
    deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)
  # Compute final checkpoint name for student (with max number of steps)
  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)
  if dataset == 'adult':
    private_data, private_labels = input.ld_adult(test_only = False, train_only= True)
  elif dataset =='mnist':
    private_data, private_labels = input.ld_mnist(test_only = False, train_only = True)
  elif dataset =="svhn":
    private_data, private_labels = input.ld_svhn(test_only=False, train_only=True)
  # Evaluate the trained student on the private (teacher) data and on its own training data
  teacher_preds = deep_cnn.softmax_preds(private_data, ckpt_path_final)
  student_preds = deep_cnn.softmax_preds(stdnt_data, ckpt_path_final)
  # Compute accuracies
  precision_t = metrics.accuracy(teacher_preds, private_labels)
  precision_s = metrics.accuracy(student_preds, stdnt_labels)

  precision_true = metrics.accuracy(student_preds, shift_dataset['label'])
  print('Precision of teacher after training:{} student={} true precision for student {}'.format(precision_t, precision_s,precision_true))

  return precision_t, precision_s
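The per-example weight assignment in the else-branch above can also be written as a single vectorized lookup. A minimal sketch, assuming inverse_w is a per-class weight array and labels holds integer class ids (the function name here is illustrative):

import numpy as np

def per_example_weights(inverse_w, labels):
    # inverse_w: shape (num_classes,), one importance weight per class
    # labels:    shape (num_examples,), integer class ids
    # Fancy indexing replaces the explicit Python loop over examples.
    return np.asarray(inverse_w, dtype=np.float32)[np.asarray(labels)]

# Example: 3 classes, 5 samples
print(per_example_weights([0.5, 1.0, 2.0], [0, 2, 1, 2, 0]))  # [0.5 2.  1.  2.  0.5]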
Example #2
def train_teacher(FLAGS, dataset, nb_teachers, teacher_id):
  """
  This function trains a teacher (teacher id) among an ensemble of nb_teachers
  models for the dataset specified.
  :param dataset: string corresponding to dataset (svhn, cifar10)
  :param nb_teachers: total number of teachers in the ensemble
  :param teacher_id: id of the teacher being trained
  :return: True if everything went well
  """
  # If working directories do not exist, create them
  assert input.create_dir_if_needed(FLAGS.data_dir)
  assert input.create_dir_if_needed(FLAGS.train_dir)

  # Load the dataset
  if dataset == 'svhn':
    train_data,train_labels,test_data,test_labels = input.ld_svhn(extended=True)
  elif dataset == 'cifar10':
    train_data, train_labels, test_data, test_labels = input.ld_cifar10()
  elif dataset == 'mnist':
    train_data, train_labels, test_data, test_labels = input.ld_mnist()
  else:
    print("Check value of dataset flag")
    return False
  if FLAGS.cov_shift:
    # Under covariate shift, replace the raw data with the PCA-projected versions
    teacher_file_name = FLAGS.data + 'PCA_teacher' + FLAGS.dataset + '.pkl'
    student_file_name = FLAGS.data + 'PCA_student' + FLAGS.dataset + '.pkl'
    with open(teacher_file_name, 'rb') as f:
      train_data = pickle.load(f)
    with open(student_file_name, 'rb') as f:
      test_data = pickle.load(f)
  # Retrieve subset of data for this teacher
  data, labels = input.partition_dataset(train_data,
                                         train_labels,
                                         nb_teachers,
                                         teacher_id)

  print("Length of training data: " + str(len(labels)))

  # Define teacher checkpoint filename and full path
  if FLAGS.deeper:
    filename = str(nb_teachers) + 'pca_teachers_' + str(teacher_id) + '_deep.ckpt'
  else:
    filename = str(nb_teachers) + 'pca_teachers_' + str(teacher_id) + '.ckpt'
  ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename

  # Perform teacher training
  assert deep_cnn.train(data, labels, ckpt_path)

  # Append final step value to checkpoint for evaluation
  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

  # Retrieve teacher probability estimates on the test data
  teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)

  # Compute teacher accuracy
  precision = metrics.accuracy(teacher_preds, test_labels)
  print('Precision of teacher after training: ' + str(precision))

  return True
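The PCA_teacher.../PCA_student... pickles loaded above are assumed to hold PCA-projected versions of the teacher and student data. A minimal sketch of how such files could be produced with scikit-learn; the file names, component count, and pickle layout are illustrative assumptions (the project's actual files may store richer dicts, e.g. with a 'label' key as read in Example #13):

import pickle
import numpy as np
from sklearn.decomposition import PCA

def dump_pca_features(train_data, test_data, n_components=60,
                      teacher_file='PCA_teacher_mnist.pkl',
                      student_file='PCA_student_mnist.pkl'):
    # Flatten images to vectors, fit PCA on the teacher (training) split only,
    # then project both splits into the same low-dimensional space.
    tr = np.reshape(train_data, (len(train_data), -1))
    ts = np.reshape(test_data, (len(test_data), -1))
    pca = PCA(n_components=n_components).fit(tr)
    with open(teacher_file, 'wb') as f:
        pickle.dump(pca.transform(tr), f)
    with open(student_file, 'wb') as f:
        pickle.dump(pca.transform(ts), f)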
def train_teacher(dataset, nb_teachers, teacher_id):
  """
  This function trains a teacher (teacher id) among an ensemble of nb_teachers
  models for the dataset specified.
  :param dataset: string corresponding to dataset (svhn, cifar10)
  :param nb_teachers: total number of teachers in the ensemble
  :param teacher_id: id of the teacher being trained
  :return: True if everything went well
  """
  # If working directories do not exist, create them
  assert input.create_dir_if_needed(FLAGS.data_dir)
  assert input.create_dir_if_needed(FLAGS.train_dir)

  # Load the dataset
  if dataset == 'svhn':
    train_data,train_labels,test_data,test_labels = input.ld_svhn(extended=True)
  elif dataset == 'cifar10':
    train_data, train_labels, test_data, test_labels = input.ld_cifar10()
  elif dataset == 'mnist':
    train_data, train_labels, test_data, test_labels = input.ld_mnist()
  else:
    print("Check value of dataset flag")
    return False
    
  # Retrieve subset of data for this teacher
  data, labels = input.partition_dataset(train_data, 
                                         train_labels, 
                                         nb_teachers, 
                                         teacher_id)

  print("Length of training data: " + str(len(labels)))

  # Define teacher checkpoint filename and full path
  if FLAGS.deeper:
    filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '_deep.ckpt'
  else:
    filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt'
  ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename

  # Perform teacher training
  assert deep_cnn.train(data, labels, ckpt_path)

  # Append final step value to checkpoint for evaluation
  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

  # Retrieve teacher probability estimates on the test data
  teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)

  # Compute teacher accuracy
  precision = metrics.accuracy(teacher_preds, test_labels)
  print('Precision of teacher after training: ' + str(precision))

  return True
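Both teacher trainers rely on input.partition_dataset to hand each teacher a disjoint slice of the training set. A minimal sketch of what such a partition helper might do (an assumption about its behavior, not the project's actual implementation):

def partition_dataset(data, labels, nb_teachers, teacher_id):
    # Split the training set into nb_teachers contiguous, equally sized,
    # non-overlapping shards and return the shard belonging to this teacher.
    assert 0 <= teacher_id < nb_teachers
    batch_len = len(data) // nb_teachers
    start = teacher_id * batch_len
    end = start + batch_len
    return data[start:end], labels[start:end]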
Example #4
def train_student(dataset, nb_teachers):
    """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :return: True if student training went well
  """
    assert input.create_dir_if_needed(FLAGS.train_dir)

    # Call helper function to prepare student data using teacher predictions
    stdnt_dataset = prepare_student_data(dataset, nb_teachers, save=True)

    # Unpack the student dataset
    stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset
    print('stdnt_test_data.shape', stdnt_test_data.shape)
    if dataset == 'cifar10':
        stdnt_data = stdnt_data.reshape([-1, 32, 32, 3])
        stdnt_test_data = stdnt_test_data.reshape([-1, 32, 32, 3])
    elif dataset == 'mnist':
        stdnt_data = stdnt_data.reshape([-1, 28, 28, 1])
        stdnt_test_data = stdnt_test_data.reshape([-1, 28, 28, 1])
    elif dataset == 'svhn':
        stdnt_data = stdnt_data.reshape([-1, 32, 32, 3])
        stdnt_test_data = stdnt_test_data.reshape([-1, 32, 32, 3])
    # Prepare checkpoint filename and path
    if FLAGS.deeper:
        ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(
            nb_teachers) + '_student_deeper.ckpt'  #NOLINT(long-line)
    else:
        ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(
            nb_teachers) + '_student.ckpt'  # NOLINT(long-line)

    # Start student training
    assert deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)

    # Compute final checkpoint name for student (with max number of steps)
    ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

    # Compute student label predictions on remaining chunk of test set
    student_preds = deep_cnn.softmax_preds(stdnt_test_data, ckpt_path_final)

    # Compute student accuracy
    precision = metrics.accuracy(student_preds, stdnt_test_labels)
    print('Precision of student after training: ' + str(precision))

    return True
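The per-dataset reshapes above can be made table-driven. A small sketch using the shapes already present in the code (assuming NHWC image layout; the helper name is illustrative):

import numpy as np

# Image shapes used by the snippet above.
DATASET_SHAPES = {
    'mnist':   (28, 28, 1),
    'cifar10': (32, 32, 3),
    'svhn':    (32, 32, 3),
}

def to_images(flat_data, dataset):
    # Reshape flat (N, H*W*C) data into (N, H, W, C) for the given dataset.
    h, w, c = DATASET_SHAPES[dataset]
    return np.reshape(flat_data, (-1, h, w, c))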
Example #5
def start_train(train_data, train_labels, test_data, test_labels, ckpt, ckpt_final, only_rpt=False):
    if not only_rpt:
        assert deep_cnn.train(train_data, train_labels, ckpt)

    preds_tr = deep_cnn.softmax_preds(train_data, ckpt_final)  # softmax probability vectors
    preds_ts = deep_cnn.softmax_preds(test_data, ckpt_final)

    logging.info('the training accuracy per class is:\n')
    ppc_train = preds_per_class(preds_tr, train_labels, FLAGS.P_per_class, FLAGS.P_all_classes)  # a 10-element list, one accuracy per class
    logging.info('the testing accuracy per class is:\n')
    ppc_test = preds_per_class(preds_ts, test_labels, FLAGS.P_per_class, FLAGS.P_all_classes)  # a 10-element list, one accuracy per class

    precision_ts = accuracy(preds_ts, test_labels)  # overall accuracy over all 10 classes
    precision_tr = accuracy(preds_tr, train_labels)
    logging.info('Acc_tr:{:.3f}   Acc_ts: {:.3f}'.format(precision_tr, precision_ts))

    return precision_tr, precision_ts, ppc_train, ppc_test, preds_tr
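preds_per_class above reports one accuracy per class from the softmax outputs. A self-contained sketch of that computation (the project's helper also writes results to the FLAGS file paths, which is omitted here):

import numpy as np

def per_class_accuracy(softmax_preds, labels, nb_classes=10):
    # softmax_preds: (N, nb_classes) probability vectors; labels: (N,) integer class ids.
    predicted = np.argmax(softmax_preds, axis=1)
    labels = np.asarray(labels)
    accs = []
    for c in range(nb_classes):
        mask = labels == c
        # Guard against classes that do not appear in this split.
        accs.append(float(np.mean(predicted[mask] == c)) if mask.any() else float('nan'))
    return accs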
Example #6
def start_train(train_data, train_labels, test_data, test_labels, ckpt_path, ckpt_path_final):
    assert deep_cnn.train(train_data, train_labels, ckpt_path)
    print('np.max(train_data) before preds: ', np.max(train_data))

    preds_tr = deep_cnn.softmax_preds(train_data, ckpt_path_final)  # softmax probability vectors
    preds_ts = deep_cnn.softmax_preds(test_data, ckpt_path_final)
    print('in start_train fun, the shape of preds_tr is ', preds_tr.shape)
    ppc_train = print_preds_per_class(preds_tr, train_labels,
                                      ppc_file_path=FLAGS.P_per_class,
                                      pac_file_path=FLAGS.P_all_classes)  # a 10-element list, one accuracy per class
    ppc_test = print_preds_per_class(preds_ts, test_labels,
                                     ppc_file_path=FLAGS.P_per_class,
                                     pac_file_path=FLAGS.P_all_classes)  # a 10-element list, one accuracy per class
    precision_ts = metrics.accuracy(preds_ts, test_labels)  # overall accuracy over all 10 classes
    precision_tr = metrics.accuracy(preds_tr, train_labels)
    print('precision_tr:%.3f \nprecision_ts: %.3f' % (precision_tr, precision_ts))
    # training, prediction, and result reporting are all done here
    return precision_tr, precision_ts, ppc_train, ppc_test, preds_tr
def start_train_data(train_data, train_labels, test_data, test_labels,
                     ckpt_path, ckpt_path_final):
    assert deep_cnn.train(train_data, train_labels, ckpt_path)
    preds_tr = deep_cnn.softmax_preds(train_data, ckpt_path_final)  # softmax probability vectors
    preds_ts = deep_cnn.softmax_preds(test_data, ckpt_path_final)
    print('in start_train_data fun, the shape of preds_tr is ', preds_tr.shape)
    ppc_train = utils.print_preds_per_class(
        preds_tr,
        train_labels,
        ppc_file_path=FLAGS.P_per_class,
        pac_file_path=FLAGS.P_all_classes)  # a 10-element list, one accuracy per class
    ppc_test = utils.print_preds_per_class(
        preds_ts, test_labels)  # print per-class accuracy computed from the full test-set probability vectors

    precision_ts = metrics.accuracy(preds_ts, test_labels)  # overall accuracy over all 10 classes
    precision_tr = metrics.accuracy(preds_tr, train_labels)
    print('precision_tr:', precision_tr, 'precision_ts:', precision_ts)
    # training, prediction, and result reporting are all done here
    return precision_tr, precision_ts, ppc_train, ppc_test, preds_tr
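metrics.accuracy is used throughout these snippets to score softmax predictions against labels. A minimal stand-in with the same argmax-and-compare behavior (the real helper may also accept one-hot labels):

import numpy as np

def accuracy(softmax_preds, labels):
    # Fraction of examples whose most probable class matches the label.
    predicted = np.argmax(softmax_preds, axis=1)
    labels = np.asarray(labels)
    if labels.ndim == 2:  # tolerate one-hot labels
        labels = np.argmax(labels, axis=1)
    return float(np.mean(predicted == labels))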
Example #8
def train_teacher(dataset, nb_teachers, teacher_id):
    """
    训练指定ID的教师模型
    :param dataset: 数据集名称
    :param nb_teachers: 老师数量
    :param teacher_id: 老师ID
    :return:
    """
    # 如果目录不存在就创建对应的目录
    assert Input.create_dir_if_needed(FLAGS.data_dir)
    assert Input.create_dir_if_needed(FLAGS.train_dir)
    # 加载对应的数据集
    if dataset == 'mnist':
        train_data, train_labels, test_data, test_labels = Input.load_mnist()
    else:
        print("没有对应的数据集")
        return False

    # 给对应的老师分配对应的数据
    data, labels = Input.partition_dataset(train_data, train_labels,
                                           nb_teachers, teacher_id)
    print("Length of training data: " + str(len(labels)))

    filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt'
    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename

    # Train the teacher and save the checkpoint
    assert deep_cnn.train(data, labels, ckpt_path)

    # Append the final step value to get the path of the trained checkpoint
    ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

    # Load the trained teacher and evaluate it on the test data
    teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)
    # Compute the teacher's accuracy
    precision = analysis.accuracy(teacher_preds, test_labels)
    print('Precision of teacher after training: ' + str(precision))

    return True
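A typical driver trains the whole ensemble by calling this function once per teacher id. A minimal sketch, assuming the snippet's modules and flags are importable; the loop and the ensemble size are illustrative, not part of the project:

dataset = 'mnist'
nb_teachers = 10  # assumed ensemble size

for teacher_id in range(nb_teachers):
    # Each teacher trains on its own disjoint partition of the training data
    # and saves its checkpoint under FLAGS.train_dir.
    assert train_teacher(dataset, nb_teachers, teacher_id)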
Example #9
def train_student(dataset, nb_teachers):

    assert Input.create_dir_if_needed(FLAGS.train_dir)

    # Prepare the student data from the teacher ensemble's predictions
    student_dataset = prepare_student_data(dataset, nb_teachers, save=True)
    # Unpack the student dataset
    stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = student_dataset

    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student.ckpt'
    # Train the student
    assert deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)

    ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

    # Predict on the student's held-out test split
    student_preds = deep_cnn.softmax_preds(stdnt_test_data, ckpt_path_final)

    precision = analysis.accuracy(student_preds, stdnt_test_labels)
    print('Precision of student after training: ' + str(precision))

    return True
def train_student(dataset, nb_teachers):
  """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :return: True if student training went well
  """
  assert input.create_dir_if_needed(FLAGS.train_dir)

  # Call helper function to prepare student data using teacher predictions
  stdnt_dataset = prepare_student_data(dataset, nb_teachers, save=True)

  # Unpack the student dataset
  stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset

  # Prepare checkpoint filename and path
  if FLAGS.deeper:
    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student_deeper.ckpt' #NOLINT(long-line)
  else:
    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student.ckpt'  # NOLINT(long-line)

  # Start student training
  assert deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)

  # Compute final checkpoint name for student (with max number of steps)
  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

  # Compute student label predictions on remaining chunk of test set
  student_preds = deep_cnn.softmax_preds(stdnt_test_data, ckpt_path_final)

  # Compute student accuracy
  precision = metrics.accuracy(student_preds, stdnt_test_labels)
  print('Precision of student after training: ' + str(precision))

  return True
Example #11
def train_teacher(dataset, nb_teachers, teacher_id):
    """
  This function trains a teacher (teacher id) among an ensemble of nb_teachers
  models for the dataset specified.
  :param dataset: string corresponding to dataset (svhn, cifar10)
  :param nb_teachers: total number of teachers in the ensemble
  :param teacher_id: id of the teacher being trained
  :return: True if everything went well
  """
    # If working directories do not exist, create them
    assert input.create_dir_if_needed(FLAGS.data_dir)
    assert input.create_dir_if_needed(FLAGS.train_dir)
    print("teacher {}:".format(teacher_id))
    # Load the dataset
    if dataset == 'svhn':
        train_data, train_labels, test_data, test_labels = input.ld_svhn(
            extended=True)
    elif dataset == 'cifar10':
        train_data, train_labels, test_data, test_labels = input.ld_cifar10()
    elif dataset == 'mnist':
        train_data, train_labels, test_data, test_labels = input.ld_mnist()
    else:
        print("Check value of dataset flag")
        return False

    path = os.path.abspath('.')

    path1 = path + '\\plts_nodisturb\\'

    # Perturb the labels
    import copy
    train_labels1 = copy.copy(train_labels)
    train_labels2 = disturb(train_labels, 0.1)
    disturb(test_labels, 0.1)
    #path1 = path + '\\plts_withdisturb\\'

    # Retrieve subset of data for this teacher
    # before perturbation
    data, labels = input.partition_dataset(train_data, train_labels,
                                           nb_teachers, teacher_id)

    from pca import K_S
    import operator
    print(operator.eq(train_labels1, train_labels2))
    print("before perturbation: ", K_S.tst_norm(train_labels1))
    print("after perturbation: ", K_S.tst_norm(train_labels2))
    print(K_S.tst_samp(train_labels1, train_labels2))

    print("Length of training data: " + str(len(labels)))

    # Define teacher checkpoint filename and full path
    if FLAGS.deeper:
        filename = str(nb_teachers) + '_teachers_' + str(
            teacher_id) + '_deep.ckpt'
    else:
        filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt'
    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename

    # Perform teacher training
    losses = deep_cnn.train(data, labels, ckpt_path)

    # Append final step value to checkpoint for evaluation
    ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

    # Retrieve teacher probability estimates on the test data
    teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)

    # Compute teacher accuracy
    precision = metrics.accuracy(teacher_preds, test_labels)
    print('Precision of teacher after training: ' + str(precision))
    print("each n step loss: ", losses)

    #x = list(range(1, len(losses)+1))
    #plt.plot(x, losses, 'bo-', markersize=20)
    #plt.savefig(path1 + 'loss' + str(teacher_id) + '.jpg')
    #plt.show()
    #print("x: ",x)
    #print("loss: ", losses)

    return True
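disturb is called above but not shown. A minimal sketch of what such a label-perturbation helper might look like; the name, signature, and in-place behavior are assumptions inferred from the call sites:

import numpy as np

def disturb(labels, ratio, nb_classes=10, seed=None):
    # Randomly reassign a `ratio` fraction of the labels in place and return them.
    rng = np.random.RandomState(seed)
    labels = np.asarray(labels)
    n_flip = int(ratio * len(labels))
    idx = rng.choice(len(labels), size=n_flip, replace=False)
    labels[idx] = rng.randint(0, nb_classes, size=n_flip)
    return labels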
Example #12
def train_student(dataset,
                  nb_teachers,
                  knock,
                  weight=True,
                  inverse_w=None,
                  shift_dataset=None):
    """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :return: True if student training went well
  """
    assert input.create_dir_if_needed(FLAGS.train_dir)
    print('len of shift data={}'.format(len(shift_dataset['data'])))
    # Call helper function to prepare student data using teacher predictions
    stdnt_data, stdnt_labels = prepare_student_data(dataset,
                                                    nb_teachers,
                                                    save=True,
                                                    shift_data=shift_dataset)

    # Unpack the student dataset, here stdnt_labels are already the ensemble noisy version
    # Prepare checkpoint filename and path
    if FLAGS.deeper:
        ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(
            nb_teachers) + '_student_deeper.ckpt'  #NOLINT(long-line)
    else:
        ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(
            nb_teachers) + '_student.ckpt'  # NOLINT(long-line)

    # Start student training
    # Map the per-class importance weight onto each example via its label
    weights = np.zeros(len(stdnt_data))
    print('len of weights={} len of labels={}'.format(len(weights), len(stdnt_labels)))
    for i in range(len(weights)):
        weights[i] = np.float32(inverse_w[stdnt_labels[i]])
    if weight:
        assert deep_cnn.train(stdnt_data,
                              stdnt_labels,
                              ckpt_path,
                              weights=weights)
    else:
        deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)
    # Compute final checkpoint name for student (with max number of steps)
    ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)
    private_data, private_labels = input.ld_mnist(test_only=False,
                                                  train_only=True)
    # Evaluate the trained student on the private (teacher) data and on its own training data
    teacher_preds = deep_cnn.softmax_preds(private_data, ckpt_path_final)
    student_preds = deep_cnn.softmax_preds(stdnt_data, ckpt_path_final)
    # Compute accuracies
    precision_t = metrics.accuracy(teacher_preds, private_labels)
    precision_s = metrics.accuracy(student_preds, stdnt_labels)
    if knock:
        print(
            'weight is {} shift_ratio={} Precision of teacher after training:{} student={}'
            .format(weight, shift_dataset['shift_ratio'], precision_t,
                    precision_s))
    else:
        print(
            'weight is {} shift_ratio={} Precision of teacher after training:{} student={}'
            .format(weight, shift_dataset['alpha'], precision_t, precision_s))

    return True
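deep_cnn.train(..., weights=weights) above presumably scales each example's contribution to the loss by its importance weight. A minimal numpy sketch of a weighted cross-entropy under that assumption (not the project's actual training loss):

import numpy as np

def weighted_cross_entropy(softmax_preds, labels, weights, eps=1e-12):
    # softmax_preds: (N, C) probabilities; labels: (N,) int ids; weights: (N,) importance weights.
    n = len(labels)
    per_example_loss = -np.log(softmax_preds[np.arange(n), labels] + eps)
    # Scale each example's loss by its weight, then normalize by the total weight.
    return float(np.sum(weights * per_example_loss) / np.sum(weights))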
Example #13
def train_student(dataset,
                  nb_teachers,
                  weight=True,
                  inverse_w=None,
                  shift_dataset=None):
    """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :param weight: whether this is an importance weight sampling
  :return: True if student training went well
  """
    assert input.create_dir_if_needed(FLAGS.train_dir)

    # Call helper function to prepare student data using teacher predictions
    if shift_dataset is not None:
        stdnt_data, stdnt_labels = prepare_student_data(
            dataset, nb_teachers, save=True, shift_data=shift_dataset)
    else:
        if FLAGS.PATE2:
            keep_idx, stdnt_data, stdnt_labels = prepare_student_data(
                dataset, nb_teachers, save=True)
        else:
            stdnt_data, stdnt_labels = prepare_student_data(dataset,
                                                            nb_teachers,
                                                            save=True)
    rng = np.random.RandomState(FLAGS.dataset_seed)
    rand_ix = rng.permutation(len(stdnt_labels))
    stdnt_data = stdnt_data[rand_ix]
    stdnt_labels = stdnt_labels[rand_ix]
    print('number for deep is {}'.format(len(stdnt_labels)))
    # Unpack the student dataset, here stdnt_labels are already the ensemble noisy version
    # Prepare checkpoint filename and path
    if FLAGS.deeper:
        ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(
            nb_teachers) + '_student_deeper.ckpt'  #NOLINT(long-line)
    else:
        ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(
            nb_teachers) + '_student.ckpt'  # NOLINT(long-line)

    # Start student training
    if FLAGS.cov_shift:
        # Under covariate shift the importance weights for the student are precomputed;
        # they may need to be clipped to some bound in case a weight gets too large.
        weights = inverse_w

        #y_s = np.expand_dims(y_s, axis=1)

    else:
        print('len of shift data={}'.format(len(shift_dataset['data'])))
        # Map the per-class importance weight onto each example via its label
        weights = np.zeros(len(stdnt_data))
        print('len of weights={} len of labels={}'.format(len(weights), len(stdnt_labels)))
        for i in range(len(weights)):
            weights[i] = np.float32(inverse_w[stdnt_labels[i]])

    if weight:
        if FLAGS.PATE2:
            assert deep_cnn.train(stdnt_data,
                                  stdnt_labels,
                                  ckpt_path,
                                  weights=weights[keep_idx])
        else:
            assert deep_cnn.train(stdnt_data,
                                  stdnt_labels,
                                  ckpt_path,
                                  weights=weights)
    else:
        deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)
    # Compute final checkpoint name for student (with max number of steps)
    ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)
    if dataset == 'adult':
        private_data, private_labels = input.ld_adult(test_only=False,
                                                      train_only=True)
    elif dataset == 'mnist':
        private_data, private_labels = input.ld_mnist(test_only=False,
                                                      train_only=True)
    elif dataset == "svhn":
        private_data, private_labels = input.ld_svhn(test_only=False,
                                                     train_only=True)
    # Evaluate the trained student on the private (teacher) data and on its own training data
    teacher_preds = deep_cnn.softmax_preds(private_data, ckpt_path_final)
    student_preds = deep_cnn.softmax_preds(stdnt_data, ckpt_path_final)
    # Compute accuracies
    precision_t = metrics.accuracy(teacher_preds, private_labels)
    precision_s = metrics.accuracy(student_preds, stdnt_labels)
    if FLAGS.cov_shift:
        student_file_name = FLAGS.data + 'PCA_student' + FLAGS.dataset + '.pkl'
        with open(student_file_name, 'rb') as f:
            test = pickle.load(f)
        if FLAGS.PATE2:
            test_labels = test['label'][keep_idx]
        else:
            test_labels = test['label']
    precision_true = metrics.accuracy(student_preds, test_labels)
    print(
        'Precision of teacher after training:{} student={} true precision for student {}'
        .format(precision_t, precision_s, precision_true))

    return len(test_labels), precision_t, precision_s
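The comment about curving the weight "into some bound, in case the weight is too large" suggests clipping oversized importance weights before training. A minimal sketch of such a bound; the clipping range and renormalization are illustrative assumptions, not values from the project:

import numpy as np

def clip_importance_weights(weights, lower=0.1, upper=10.0):
    # Bound the weights so that no single example dominates the weighted loss.
    w = np.clip(np.asarray(weights, dtype=np.float32), lower, upper)
    # Renormalize so the average weight stays at 1.
    return w * (len(w) / np.sum(w))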