Example #1
def train_student(nb_teachers):
    """
  This function trains a student using predictions made by an ensemble of
  teachers. The student and teacher models are trained using the same
  neural network architecture.
  :param dataset: string corresponding to mnist, cifar10, or svhn
  :param nb_teachers: number of teachers (in the ensemble) to learn from
  :return: True if student training went well
  """

    # Call helper function to prepare student data using teacher predictions
    stdnt_dataset = prepare_student_data(nb_teachers, save=True)

    # Unpack the student dataset
    stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset

    # Build the checkpoint path; use the configured architecture name for
    # non-ResNet models so that `filename` is always defined before use.
    dir_path = os.path.join(config.save_model, config.dataset)
    dir_path = os.path.join(dir_path,
                            'pate_num_teacher_' + str(config.nb_teachers))
    utils.mkdir_if_missing(dir_path)
    arch_name = 'resnet' if config.resnet else config.arch
    filename = os.path.join(dir_path,
                            'student_' + arch_name + '.checkpoint.pth.tar')

    print('Student labels used for training:', stdnt_labels.shape)
    network.train_each_teacher(config.student_epoch, stdnt_data, stdnt_labels,
                               stdnt_test_data, stdnt_test_labels, filename)

    final_preds = network.pred(stdnt_test_data, filename)

    accuracy = hamming_accuracy(final_preds, stdnt_test_labels, torch=False)
    print('Accuracy of student after training: ' + str(accuracy))

    return True
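
Both student examples report test quality through a hamming_accuracy helper that is not shown (config, utils, and network are likewise project modules assumed to be importable). A minimal sketch of what such a metric might look like, assuming NumPy arrays of (multi-)label predictions; only the name and the torch flag come from the call site above, the body is an assumption:

import numpy as np

def hamming_accuracy(preds, labels, torch=False):
    # Fraction of individual label entries predicted correctly,
    # averaged over all samples (and attributes, for multi-label data).
    if torch:
        # Convert torch tensors to NumPy before comparing.
        preds, labels = preds.cpu().numpy(), labels.cpu().numpy()
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    return float((preds == labels).mean())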
Example #2
def train_student(nb_teachers):
    """
    This function trains a student using predictions made by an ensemble of
    teachers. The student and teacher models are trained using the same
    neural network architecture. The dataset is read from config.dataset
    (here celeba).
    :param nb_teachers: number of teachers (in the ensemble) to learn from
    :return: True if student training went well
    """
    # Call helper function to prepare student data using teacher predictions
    stdnt_dataset = prepare_student_data(nb_teachers, save=True)

    # Unpack the student dataset
    stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset
    dir_path = os.path.join(config.save_model, config.dataset)
    dir_path = os.path.join(dir_path,
                            'knn_num_neighbor_' + str(config.nb_teachers))
    utils.mkdir_if_missing(dir_path)
    if config.resnet:
        arch_name = 'resnet'
    else:
        arch_name = config.arch  # ensure `filename` is always defined
    filename = os.path.join(
        dir_path,
        str(config.nb_teachers) + '_stdnt_' + arch_name + '.checkpoint.pth.tar')

    print('Student labels used for training:', stdnt_labels.shape)
    network.train_each_teacher(config.student_epoch, stdnt_data, stdnt_labels,
                               stdnt_test_data, stdnt_test_labels, filename)
    return True
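
Both student examples delegate to prepare_student_data, which is not shown. In a PATE pipeline (the checkpoint paths above are prefixed pate_), its central step is aggregating per-teacher votes on unlabeled student data into noisy labels. A minimal sketch of that aggregation, assuming integer class votes of shape (nb_teachers, n_samples) and a Laplace noise scale gamma; the function name and signature are assumptions, not the repository's API:

import numpy as np

def noisy_aggregate(teacher_votes, nb_classes, gamma):
    # teacher_votes: int array of shape (nb_teachers, n_samples).
    # Returns one noisily aggregated label per sample.
    n_samples = teacher_votes.shape[1]
    labels = np.zeros(n_samples, dtype=int)
    for j in range(n_samples):
        counts = np.bincount(teacher_votes[:, j], minlength=nb_classes)
        # Laplace noise on the vote histogram provides differential privacy.
        counts = counts + np.random.laplace(0.0, 1.0 / gamma, size=nb_classes)
        labels[j] = int(np.argmax(counts))
    return labels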
Example #3
def train_teacher():
    """
    This function trains a teacher (teacher id) among an ensemble of nb_teachers
    models for the dataset specified.
    :param dataset: string corresponding to dataset (svhn, cifar10)
    :param nb_teachers: total number of teachers in the ensemble
    :param teacher_id: id of the teacher being trained
    :return: True if everything went well
    """
    print("Initializing dataset {}".format(config.dataset))

    dataset = data_manager.init_img_dataset(
        root=config.data_dir, name=config.dataset,
    )

    for i in range(0, config.nb_teachers):
        # Retrieve the subset of the data assigned to this teacher
        if config.dataset == 'celeba':
            data, labels = dataset._data_partition(config.nb_teachers, i)
        else:
            # Guard against `data` being referenced while undefined.
            raise ValueError('Unsupported dataset: ' + config.dataset)

        print("Length of training data: " + str(len(data)))

        # Define teacher checkpoint filename and full path
        dir_path = os.path.join(
            config.save_model,
            'pate_' + config.dataset + str(config.nb_teachers))
        utils.mkdir_if_missing(dir_path)
        filename = os.path.join(
            dir_path,
            str(config.nb_teachers) + '_teachers_' + str(i) + '_' +
            config.arch + '.checkpoint.pth.tar')
        print('save path for teacher {} is {}'.format(i, filename))

        network.train_each_teacher(config.teacher_epoch, data, labels,
                                   dataset.test_data, dataset.test_label,
                                   filename)


    return True
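
train_teacher relies on dataset._data_partition to hand each teacher a disjoint slice of the private training set. A minimal sketch of such a method, assuming contiguous equal-size shards as in the manual slicing of Example #4 below; the attribute names train_data and train_label are assumptions:

def _data_partition(self, nb_teachers, teacher_id):
    # Give teacher `teacher_id` the teacher_id-th contiguous shard.
    batch_len = len(self.train_data) // nb_teachers
    start = teacher_id * batch_len
    end = (teacher_id + 1) * batch_len
    return self.train_data[start:end], self.train_label[start:end]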
Example #4
def train_teacher():
    """
    Partition the entire private (training) dataset into config.nb_teachers
    subsets and train one teacher model per subset.
    """
    # Load the dataset
    if config.dataset == 'mnist':
        train_dataset = dataset.MNIST(root=config.data_dir, train=True, download=True)
        test_dataset = dataset.MNIST(root=config.data_dir, train=False, download=True)
        ori_train_data = [img for img, _ in train_dataset]
        ori_test_data = [img for img, _ in test_dataset]
        test_labels = test_dataset.targets
        train_labels = train_dataset.targets
    elif config.dataset =='svhn':
        train_dataset = dataset.SVHN(root=config.data_dir, split='train', download=True)
        extra_dataset = dataset.SVHN(root=config.data_dir, split='extra', download=True)
        test_dataset = dataset.SVHN(root=config.data_dir, split='test', download=True)
        ori_train_data = np.concatenate((train_dataset.data, extra_dataset.data), axis=0)
        print('original train data shape', ori_train_data.shape)
        # SVHN arrays come as (N, C, H, W); convert to (N, H, W, C)
        ori_train_data = np.transpose(ori_train_data, (0, 2, 3, 1))
        print('transposed train data shape', ori_train_data.shape)
        ori_test_data = np.transpose(test_dataset.data, (0, 2, 3, 1))
        test_labels = test_dataset.labels
        # Merge the labels of the train and extra splits
        train_labels = np.concatenate((train_dataset.labels, extra_dataset.labels), axis=0)
    else:
        raise ValueError('Unsupported dataset: ' + config.dataset)
    batch_len = int(len(ori_train_data)/config.nb_teachers)
    # Train one teacher per contiguous data shard
    for i in range(config.nb_teachers):
        dir_path = os.path.join(config.save_model, 'pate_' + str(config.nb_teachers))
        utils.mkdir_if_missing(dir_path)
        filename = os.path.join(
            dir_path,
            str(config.nb_teachers) + '_teachers_' + str(i) + '_' +
            config.arch + '.checkpoint.pth.tar')
        print('save path for teacher {} is {}'.format(i, filename))
        start = i * batch_len
        end = (i + 1) * batch_len
        t_data = ori_train_data[start:end]
        t_labels = train_labels[start:end]
        network.train_each_teacher(config.teacher_epoch, t_data, t_labels,
                                   ori_test_data, test_labels, filename)
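
All four examples funnel into network.train_each_teacher(epochs, data, labels, test_data, test_labels, filename). A minimal sketch of what such a routine might do, assuming PyTorch, array-like numeric inputs, and a hypothetical build_model() factory for config.arch; everything in the body is an assumption inferred from the call sites, not the repository's implementation:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_each_teacher(epochs, data, labels, test_data, test_labels, filename):
    # Fit one model on this shard of data and checkpoint it to `filename`.
    x = torch.as_tensor(np.asarray(data), dtype=torch.float32)
    y = torch.as_tensor(np.asarray(labels), dtype=torch.long)
    loader = DataLoader(TensorDataset(x, y), batch_size=128, shuffle=True)
    model = build_model()  # hypothetical factory returning an nn.Module
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            optimizer.zero_grad()
            criterion(model(xb), yb).backward()
            optimizer.step()
    torch.save({'state_dict': model.state_dict()}, filename)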