示例#1
0
def train(args):
    """Train a WideResNet without a teacher for every configured seed.

    Reads the JSON config at ``args.config``. Its ``training`` section provides
    ``wrn_depth``, ``wrn_width``, ``M``, ``dataset``, ``seeds``, ``log`` and
    ``checkpoint``. For each seed the RNG is seeded, data loaders are built and
    a fresh WideResNet is trained with ``_train_seed_no_teacher``. The value
    it returns is collected as that seed's best test-set accuracy. When
    logging is enabled, per-seed results and the mean/std over seeds are
    appended to a text log in the current directory.

    Raises:
        ValueError: if ``dataset`` is neither CIFAR10 nor SVHN, or if
            ``M == 0`` (no training data, so there is nothing to train on
            without a teacher).
    """
    def _as_bool(value):
        # Config flags may be JSON booleans or strings such as "True"/"false".
        # bool("False") is True, so strings are compared explicitly.
        if isinstance(value, str):
            return value.strip().lower() in ('true', '1', 'yes')
        return bool(value)

    json_options = json_file_to_pyobj(args.config)
    no_teacher_configurations = json_options.training

    wrn_depth = no_teacher_configurations.wrn_depth
    wrn_width = no_teacher_configurations.wrn_width

    M = no_teacher_configurations.M

    dataset = no_teacher_configurations.dataset
    seeds = [int(seed) for seed in no_teacher_configurations.seeds]
    # Previously ``.lower() == 'True'`` was used, which can never match,
    # so logging was always disabled.
    log = _as_bool(no_teacher_configurations.log)

    if M == 0:
        # The M == 0 path used to build only a test loader and then crash with
        # UnboundLocalError on ``loaders``. Fail early with a clear message.
        raise ValueError('M must be positive when training without a teacher')

    if log:
        net_str = "WideResNet-{}-{}".format(wrn_depth, wrn_width)
        logfile = "No_Teacher-{}-{}-M-{}.txt".format(
            net_str, no_teacher_configurations.dataset, M)
        with open(os.path.join('./', logfile), "w") as temp:
            temp.write('No teacher {} in {} with M={}\n'.format(
                net_str, no_teacher_configurations.dataset, M))
    else:
        logfile = ''

    def _append_log(text):
        # Append ``text`` to the log file; no-op when logging is disabled.
        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write(text)

    checkpoint = _as_bool(no_teacher_configurations.checkpoint)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    test_set_accuracies = []
    strides = [1, 1, 2, 2]

    for seed in seeds:

        set_seed(seed)

        dataset_key = dataset.lower()
        if dataset_key == 'cifar10':
            # Full data
            if M == 5000:
                from utils import cifar10loaders
                loaders = cifar10loaders()
            else:
                from utils import cifar10loadersM
                loaders = cifar10loadersM(M)

        elif dataset_key == 'svhn':
            # Full data
            if M == 5000:
                from utils import svhnLoaders
                loaders = svhnLoaders()
            else:
                from utils import svhnloadersM
                loaders = svhnloadersM(M)

        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        _append_log(
            '------------------- SEED {} -------------------\n'.format(seed))

        net = WideResNet(d=wrn_depth,
                         k=wrn_width,
                         n_classes=10,
                         input_features=3,
                         output_features=16,
                         strides=strides)
        net = net.to(device)

        checkpointFile = 'No_teacher_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(
            wrn_depth, wrn_width, M, seed, dataset) if checkpoint else ''

        best_test_set_accuracy = _train_seed_no_teacher(
            net, M, loaders, device, dataset, log, checkpoint, logfile,
            checkpointFile)

        test_set_accuracies.append(best_test_set_accuracy)
        # Written once per seed; the original wrote this line twice.
        _append_log('Best test set accuracy of seed {} is {}\n'.format(
            seed, best_test_set_accuracy))

    mean_test_set_accuracy = np.mean(test_set_accuracies)
    std_test_set_accuracy = np.std(test_set_accuracies)

    _append_log(
        'Mean test set accuracy is {} with standard deviation equal to {}\n'
        .format(mean_test_set_accuracy, std_test_set_accuracy))
def train(args):
    """Train a WideResNet student from a pretrained WideResNet teacher (KD_ATT).

    Reads the JSON config at ``args.config``. Its ``training`` section provides
    the teacher/student depths and widths, plus ``M``, ``dataset``, ``seeds``,
    ``log`` and ``checkpoint``. For each seed it does the following:

    * seeds the RNG and builds data loaders;
    * loads the teacher weights from ``./PreTrainedModels/PreTrainedScratches``;
    * builds a fresh student.

    When ``M != 0`` the student is trained with ``_train_seed_kd_att``. When
    ``M == 0`` the untrained student is only evaluated on the test set with
    ``_test_set_eval``. Per-seed results and the mean/std over seeds are
    optionally appended to a text log in the current directory.

    Raises:
        ValueError: if ``dataset`` is neither CIFAR10 nor SVHN.
    """
    def _as_bool(value):
        # Config flags may be JSON booleans or strings such as "True"/"false".
        # bool("False") is True, so strings are compared explicitly.
        if isinstance(value, str):
            return value.strip().lower() in ('true', '1', 'yes')
        return bool(value)

    json_options = json_file_to_pyobj(args.config)
    kd_att_configurations = json_options.training

    wrn_depth_teacher = kd_att_configurations.wrn_depth_teacher
    wrn_width_teacher = kd_att_configurations.wrn_width_teacher
    wrn_depth_student = kd_att_configurations.wrn_depth_student
    wrn_width_student = kd_att_configurations.wrn_width_student

    M = kd_att_configurations.M

    dataset = kd_att_configurations.dataset
    seeds = [int(seed) for seed in kd_att_configurations.seeds]
    # Previously ``.lower() == 'True'`` was used, which can never match,
    # so logging was always disabled.
    log = _as_bool(kd_att_configurations.log)

    if log:
        teacher_str = "WideResNet-{}-{}".format(wrn_depth_teacher,
                                                wrn_width_teacher)
        student_str = "WideResNet-{}-{}".format(wrn_depth_student,
                                                wrn_width_student)
        logfile = "Teacher-{}-Student-{}-{}-M-{}-seeds-1-2.txt".format(
            teacher_str, student_str, kd_att_configurations.dataset, M)
        print(logfile)
        with open(os.path.join('./', logfile), "w") as temp:
            temp.write(
                'KD_ATT with teacher {} and student {} in {} with M={}\n'.
                format(teacher_str, student_str, kd_att_configurations.dataset,
                       M))
    else:
        logfile = ''

    def _append_log(text):
        # Append ``text`` to the log file; no-op when logging is disabled.
        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write(text)

    checkpoint = _as_bool(kd_att_configurations.checkpoint)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    test_set_accuracies = []
    strides = [1, 1, 2, 2]

    for seed in seeds:

        set_seed(seed)

        dataset_key = dataset.lower()
        if dataset_key == 'cifar10':

            # Full data
            if M == 5000:
                from utils import cifar10loaders
                loaders = cifar10loaders()
            # No data
            elif M == 0:
                from utils import cifar10loaders
                _, test_loader = cifar10loaders()
            else:
                from utils import cifar10loadersM
                loaders = cifar10loadersM(M)

        elif dataset_key == 'svhn':

            # Full data
            if M == 5000:
                from utils import svhnLoaders
                loaders = svhnLoaders()
            # No data
            elif M == 0:
                from utils import svhnLoaders
                _, test_loader = svhnLoaders()
            else:
                from utils import svhnloadersM
                loaders = svhnloadersM(M)

        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        _append_log(
            '------------------- SEED {} -------------------\n'.format(seed))

        teacher_net = WideResNet(d=wrn_depth_teacher,
                                 k=wrn_width_teacher,
                                 n_classes=10,
                                 input_features=3,
                                 output_features=16,
                                 strides=strides)
        teacher_net = teacher_net.to(device)
        if dataset_key == 'cifar10':
            torch_checkpoint = torch.load(
                './PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher, seed),
                map_location=device)
        else:
            torch_checkpoint = torch.load(
                './PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher, seed),
                map_location=device)

        teacher_net.load_state_dict(torch_checkpoint)

        student_net = WideResNet(d=wrn_depth_student,
                                 k=wrn_width_student,
                                 n_classes=10,
                                 input_features=3,
                                 output_features=16,
                                 strides=strides)
        student_net = student_net.to(device)

        checkpointFile = 'kd_att_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(
            wrn_depth_teacher, wrn_width_teacher, wrn_depth_student,
            wrn_width_student, M, seed, dataset) if checkpoint else ''

        if M != 0:
            best_test_set_accuracy = _train_seed_kd_att(
                teacher_net, student_net, M, loaders, device, dataset, log,
                checkpoint, logfile, checkpointFile)
        else:
            # No training data: report the freshly built (untrained) student's
            # test-set accuracy.
            best_test_set_accuracy = _test_set_eval(student_net, device,
                                                    test_loader)

        test_set_accuracies.append(best_test_set_accuracy)
        # Logged once per seed for both paths; the original wrote this line
        # twice when M != 0 and never when M == 0.
        _append_log('Best test set accuracy of seed {} is {}\n'.format(
            seed, best_test_set_accuracy))

    mean_test_set_accuracy = np.mean(test_set_accuracies)
    std_test_set_accuracy = np.std(test_set_accuracies)

    _append_log(
        'Mean test set accuracy is {} with standard deviation equal to {}\n'
        .format(mean_test_set_accuracy, std_test_set_accuracy))
示例#3
0
def train(args):
    """Train a WideResNet from scratch for each configured seed.

    Reads the JSON config at ``args.config``. Its ``training`` section provides
    ``wrn_depth``, ``wrn_width``, ``dataset``, ``seeds``, ``checkpoint``,
    ``logfile`` and (optionally) ``log``.

    CIFAR10 loaders come from ``tf_utils.cifar10loaders(seed=...)``; SVHN
    loaders come from ``utils.svhnLoaders``. Each seed's model is trained with
    ``_train_seed``, and the value it returns is collected as the seed's best
    test-set accuracy. Per-seed results and the mean/std over seeds are
    optionally appended to ``logfile``.

    Raises:
        ValueError: if ``dataset`` is neither CIFAR10 nor SVHN.
    """
    def _as_bool(value):
        # Config flags may be JSON booleans or strings such as "True"/"false".
        # bool("False") is True, so strings are compared explicitly.
        if isinstance(value, str):
            return value.strip().lower() in ('true', '1', 'yes')
        return bool(value)

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()
    seeds = [int(seed) for seed in training_configurations.seeds]
    # Logging was previously keyed off ``checkpoint`` by mistake. Read the
    # ``log`` flag, and fall back to ``checkpoint`` (the old behaviour) only
    # for configs that have no ``log`` field.
    log = _as_bool(getattr(training_configurations, 'log',
                           training_configurations.checkpoint))

    if log:
        logfile = training_configurations.logfile
        with open(logfile, 'w') as temp:
            temp.write('WideResNet-{}-{} scratch in {}\n'.format(
                wrn_depth, wrn_width, training_configurations.dataset))
    else:
        logfile = ''

    def _append_log(text):
        # Append ``text`` to the log file; no-op when logging is disabled.
        if log:
            with open(logfile, 'a') as temp:
                temp.write(text)

    checkpoint = _as_bool(training_configurations.checkpoint)

    test_set_accuracies = []
    strides = [1, 1, 2, 2]

    for seed in seeds:

        if dataset == 'cifar10':
            from tf_utils import cifar10loaders
            loaders = cifar10loaders(seed=seed)
        elif dataset == 'svhn':
            from utils import svhnLoaders
            loaders = svhnLoaders()
        else:
            # Previously the exception was built but never raised, which led
            # to a NameError on ``loaders`` further down.
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        set_seed(seed)

        _append_log(
            '------------------- SEED {} -------------------\n'.format(seed))

        model = WideResNet(d=wrn_depth,
                           k=wrn_width,
                           n_classes=10,
                           output_features=16,
                           strides=strides)

        # NOTE: the name embeds the dataset before the seed
        # ("...-seed-<dataset>-<seed>-..."), matching the
        # "seed-svhn-<seed>" teacher files loaded elsewhere; keep the order.
        checkpointFile = '_wrn-{}-{}-seed-{}-{}-dict.pth'.format(
            wrn_depth, wrn_width, dataset, seed) if checkpoint else ''
        best_test_set_accuracy = _train_seed(model, loaders, log, checkpoint,
                                             logfile, checkpointFile)

        test_set_accuracies.append(best_test_set_accuracy)
        # Written once per seed; the original wrote this line twice.
        _append_log('Best test set accuracy of seed {} is {}\n'.format(
            seed, best_test_set_accuracy))

    mean_test_set_accuracy = np.mean(test_set_accuracies)
    std_test_set_accuracy = np.std(test_set_accuracies)

    _append_log(
        'Mean test set accuracy is {} with standard deviation equal to {}\n'
        .format(mean_test_set_accuracy, std_test_set_accuracy))
def train(args):
    """Run zero-shot distillation from a pretrained teacher to a student.

    Reads the JSON config at ``args.config``. Its ``training`` section provides
    the teacher/student depths and widths, plus ``M``, ``dataset``, ``seeds``,
    ``log`` and ``checkpoint``. For each seed it does the following:

    * seeds the RNG and builds only the test loader;
    * loads the teacher weights from ``./PreTrainedModels/PreTrainedScratches``;
    * builds a fresh student and a ``Generator``;
    * calls ``_train_seed_zero_shot``, whose return value is collected as the
      seed's best test-set accuracy.

    Per-seed results and the mean/std over seeds are optionally appended to a
    text log.

    Raises:
        ValueError: if ``dataset`` is neither CIFAR10 nor SVHN.
    """
    def _as_bool(value):
        # Config flags may be JSON booleans or strings such as "True"/"false".
        # bool("False") is True, so strings are compared explicitly.
        if isinstance(value, str):
            return value.strip().lower() in ('true', '1', 'yes')
        return bool(value)

    json_options = json_file_to_pyobj(args.config)
    modified_zero_shot_configurations = json_options.training

    wrn_depth_teacher = modified_zero_shot_configurations.wrn_depth_teacher
    wrn_width_teacher = modified_zero_shot_configurations.wrn_width_teacher
    wrn_depth_student = modified_zero_shot_configurations.wrn_depth_student
    wrn_width_student = modified_zero_shot_configurations.wrn_width_student

    M = modified_zero_shot_configurations.M

    dataset = modified_zero_shot_configurations.dataset
    seeds = [int(seed) for seed in modified_zero_shot_configurations.seeds]
    # Previously ``.lower() == 'True'`` was used, which can never match,
    # so logging was always disabled.
    log = _as_bool(modified_zero_shot_configurations.log)

    if log:
        teacher_str = 'WideResNet-{}-{}'.format(wrn_depth_teacher, wrn_width_teacher)
        student_str = 'WideResNet-{}-{}'.format(wrn_depth_student, wrn_width_student)
        logfile = 'Teacher-{}-Student-{}-{}-M-{}-Zero-Shot.txt'.format(
            teacher_str, student_str, modified_zero_shot_configurations.dataset, M)
        with open(logfile, 'w') as temp:
            temp.write('Zero-Shot with teacher {} and student {} in {} with M-{}\n'.format(
                teacher_str, student_str, modified_zero_shot_configurations.dataset, M))
    else:
        logfile = ''

    def _append_log(text):
        # Append ``text`` to the log file; no-op when logging is disabled.
        if log:
            with open(logfile, 'a') as temp:
                temp.write(text)

    checkpoint = _as_bool(modified_zero_shot_configurations.checkpoint)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    test_set_accuracies = []
    strides = [1, 1, 2, 2]

    for seed in seeds:

        set_seed(seed)

        # Select the test loader and the matching pretrained-teacher file in
        # a single dataset check.
        dataset_key = dataset.lower()
        if dataset_key == 'cifar10':
            from utils import cifar10loaders
            _, test_loader = cifar10loaders()
            teacher_path = './PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'.format(
                wrn_depth_teacher, wrn_width_teacher, seed)
        elif dataset_key == 'svhn':
            from utils import svhnLoaders
            _, test_loader = svhnLoaders()
            teacher_path = './PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'.format(
                wrn_depth_teacher, wrn_width_teacher, seed)
        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        _append_log('------------------- SEED {} -------------------\n'.format(seed))

        teacher_net = WideResNet(d=wrn_depth_teacher, k=wrn_width_teacher, n_classes=10,
                                 input_features=3, output_features=16, strides=strides)
        teacher_net = teacher_net.to(device)
        teacher_net.load_state_dict(torch.load(teacher_path, map_location=device))

        student_net = WideResNet(d=wrn_depth_student, k=wrn_width_student, n_classes=10,
                                 input_features=3, output_features=16, strides=strides)
        student_net = student_net.to(device)

        generator_net = Generator()
        generator_net = generator_net.to(device)

        # The three checkpoint names share one stem and differ only in suffix.
        stem = 'zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}'.format(
            wrn_depth_teacher, wrn_width_teacher, wrn_depth_student,
            wrn_width_student, M, seed, dataset)
        if checkpoint:
            checkpointFile = stem + '-dict.pth'
            finalCheckpointFile = stem + '-final-dict.pth'
            genCheckpointFile = stem + '-generator-dict.pth'
        else:
            checkpointFile = finalCheckpointFile = genCheckpointFile = ''

        best_test_set_accuracy = _train_seed_zero_shot(
            teacher_net, student_net, generator_net, test_loader, device, log,
            checkpoint, logfile, checkpointFile, finalCheckpointFile,
            genCheckpointFile)

        test_set_accuracies.append(best_test_set_accuracy)
        # Written once per seed; the original wrote this line twice.
        _append_log('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

    mean_test_set_accuracy = np.mean(test_set_accuracies)
    std_test_set_accuracy = np.std(test_set_accuracies)

    _append_log('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(
        mean_test_set_accuracy, std_test_set_accuracy))
def train(args):
    """Train a WideResNet from scratch for each configured seed.

    Reads the JSON config at ``args.config``. Its ``training`` section provides
    ``wrn_depth``, ``wrn_width``, ``dataset``, ``seeds``, ``log``, ``logfile``
    and ``checkpoint``. The data loaders are built once, before the seed loop.
    Each seed's network is trained with ``_train_seed``, and the value it
    returns is collected as the seed's best test-set accuracy. Per-seed
    results and the mean/std over seeds are optionally appended to
    ``logfile``.

    Raises:
        ValueError: if ``dataset`` is neither CIFAR10 nor SVHN.
    """
    def _as_bool(value):
        # Config flags may be JSON booleans or strings such as "True"/"false".
        # A JSON boolean has no ``.lower()``, so handle both forms.
        if isinstance(value, str):
            return value.strip().lower() in ('true', '1', 'yes')
        return bool(value)

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()
    seeds = [int(seed) for seed in training_configurations.seeds]
    # Previously ``.lower() == 'True'`` was used, which can never match,
    # so logging was always disabled.
    log = _as_bool(training_configurations.log)

    if log:
        logfile = training_configurations.logfile
        with open(logfile, 'w') as temp:
            temp.write('WideResNet-{}-{} scratch in {}\n'.format(
                wrn_depth, wrn_width, training_configurations.dataset))
    else:
        logfile = ''

    def _append_log(text):
        # Append ``text`` to the log file; no-op when logging is disabled.
        if log:
            with open(logfile, 'a') as temp:
                temp.write(text)

    checkpoint = _as_bool(training_configurations.checkpoint)

    # ``dataset`` is already lower-cased above.
    if dataset == 'cifar10':
        from utils import cifar10loaders
        loaders = cifar10loaders()
    elif dataset == 'svhn':
        from utils import svhnLoaders
        loaders = svhnLoaders()
    else:
        # Previously the exception was built but never raised, which led to a
        # NameError on ``loaders`` inside the seed loop.
        raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    test_set_accuracies = []
    strides = [1, 1, 2, 2]

    for seed in seeds:
        set_seed(seed)

        _append_log(
            '------------------- SEED {} -------------------\n'.format(seed))

        net = WideResNet(d=wrn_depth,
                         k=wrn_width,
                         n_classes=10,
                         input_features=3,
                         output_features=16,
                         strides=strides)
        net = net.to(device)

        # NOTE: the name embeds the dataset before the seed
        # ("wrn-D-W-seed-<dataset>-<seed>-dict.pth"), which matches the
        # "seed-svhn-<seed>" teacher files loaded elsewhere; keep the order.
        checkpointFile = 'wrn-{}-{}-seed-{}-{}-dict.pth'.format(
            wrn_depth, wrn_width, dataset, seed) if checkpoint else ''
        best_test_set_accuracy = _train_seed(net, loaders, device, dataset,
                                             log, checkpoint, logfile,
                                             checkpointFile)

        test_set_accuracies.append(best_test_set_accuracy)
        # Written once per seed; the original wrote this line twice.
        _append_log('Best test set accuracy of seed {} is {}\n'.format(
            seed, best_test_set_accuracy))

    mean_test_set_accuracy = np.mean(test_set_accuracies)
    std_test_set_accuracy = np.std(test_set_accuracies)

    _append_log(
        'Mean test set accuracy is {} with standard deviation equal to {}\n'
        .format(mean_test_set_accuracy, std_test_set_accuracy))