Пример #1
0
def naive_ensembling(args, networks, test_loader):
    """Build a network whose weights are a weighted average of ``networks``.

    The networks are combined with weights
    ``(1 - args.ensemble_step, args.ensemble_step)`` via
    ``get_avg_parameters``. The averaged tensors are copied, in order, into a
    freshly constructed ``get_model_from_name(args)`` model, which is then
    evaluated with ``routines.test``.

    Args:
        args: run configuration; this function reads ``width_ratio``,
            ``ensemble_step`` and ``gpu_id`` and forwards ``args`` to helpers.
        networks: models whose parameters are averaged.
        test_loader: loader passed through to ``routines.test``.

    Returns:
        ``(test_result, ensemble_network)``, or ``(-1, None)`` when
        ``args.width_ratio != 1``.
    """
    # Element-wise averaging is only meaningful for identically shaped models.
    if args.width_ratio != 1:
        print(
            "Unfortunately naive ensembling can't work if models are not of same shape!"
        )
        return -1, None
    weights = [(1 - args.ensemble_step), args.ensemble_step]
    avg_pars = get_avg_parameters(networks, weights)
    ensemble_network = get_model_from_name(args)
    # put on GPU
    if args.gpu_id != -1:
        ensemble_network = ensemble_network.cuda(args.gpu_id)

    # check the test performance of the freshly built network (the returned
    # value is discarded; only the "after" result below is returned)
    log_dict = {'test_losses': []}
    routines.test(args, ensemble_network, test_loader, log_dict)

    # Set the weights of the ensembled network. state_dict() builds a new
    # dict on every call, so fetch it once; its tensors reference the
    # module's parameters/buffers, so copy_ writes through to the network.
    state = ensemble_network.state_dict()
    for idx, name in enumerate(state):
        state[name].copy_(avg_pars[idx].data)

    # check the test performance of the method after ensembling
    log_dict = {'test_losses': []}
    return routines.test(args, ensemble_network, test_loader,
                         log_dict), ensemble_network
Пример #2
0
def distillation(args, teachers, student, train_loader, test_loader, device):
    """Train ``student`` by knowledge distillation from the mean of ``teachers``.

    Inspiration: https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/evaluate.py

    For every batch, the teachers' outputs (called with
    ``disable_logits=True``) are averaged and passed, together with the
    student's output and the labels, to ``loss_fn_kd``. The student is then
    updated with SGD.

    Args:
        args: run configuration; reads ``learning_rate``, ``momentum``,
            ``dist_epochs``, ``gpu_id`` and ``log_interval`` here and is
            forwarded to ``loss_fn_kd`` and ``routines.test``.
        teachers: models providing the distillation target; they are put in
            eval mode and are not updated.
        student: model being trained (its parameters are updated in place).
        train_loader: iterable of ``(data, labels)`` batches.
        test_loader: passed to ``routines.test`` before training and after
            every epoch.
        device: where batches are moved when ``args.gpu_id != -1``.

    Returns:
        ``(student, accuracies)`` where ``accuracies`` holds the
        ``routines.test`` result before training followed by one per epoch.
    """
    for teacher in teachers:
        teacher.eval()

    optimizer = optim.SGD(student.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum)

    log_dict = {
        'train_losses': [],
        'train_counter': [],
        'test_losses': [],
    }

    accuracies = [routines.test(args, student, test_loader, log_dict)]
    for epoch_idx in range(args.dist_epochs):
        student.train()

        for batch_idx, (data_batch, labels_batch) in enumerate(train_loader):

            # move to GPU if available
            if args.gpu_id != -1:
                data_batch = data_batch.to(device)
                labels_batch = labels_batch.to(device)

            # Mean teacher output. Teachers are fixed targets: running them
            # under no_grad avoids building an autograd graph for them and
            # keeps loss.backward() from accumulating (never-zeroed) gradients
            # into teacher parameters. Student gradients are unaffected.
            with torch.no_grad():
                teacher_outputs = torch.stack([
                    teacher(data_batch, disable_logits=True)
                    for teacher in teachers
                ]).mean(dim=0)

            optimizer.zero_grad()
            # get student output
            student_output = student(data_batch, disable_logits=True)

            # knowledge distillation loss
            loss = loss_fn_kd(student_output, labels_batch, teacher_outputs,
                              args)
            loss.backward()
            # update student
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                loss_value = loss.item()
                samples_seen = batch_idx * len(data_batch)
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch_idx, samples_seen,
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss_value))
                log_dict['train_losses'].append(loss_value)
                # epoch_idx is 0-based, so the epoch offset is epoch_idx
                # (epoch_idx - 1 produced negative counters), and the batch
                # size comes from the data instead of a hard-coded 64.
                log_dict['train_counter'].append(
                    samples_seen + epoch_idx * len(train_loader.dataset))

        accuracies.append(routines.test(args, student, test_loader, log_dict))

    return student, accuracies
Пример #3
0
def update_model(args, model, new_params, test=False, test_loader=None, reversed=False, idx=-1):
    """Load ``new_params`` into a freshly built network shaped like ``model``.

    A new network is created with ``get_model_from_name(args, idx=idx)``.
    The entries of ``model.state_dict()`` are replaced positionally by the
    tensors in ``new_params``; replacement stops when either runs out, and
    remaining entries keep ``model``'s values. The resulting dict is loaded
    into the new network. Only the dict returned by ``state_dict()`` is
    rebound; ``model`` itself is not modified.

    Args:
        args: run configuration; reads ``gpu_id`` here.
        model: network supplying the state-dict keys and any entries beyond
            ``len(new_params)``.
        new_params: tensors assigned to the state-dict entries in order.
        test: if true, evaluate the updated network with ``routines.test``.
        test_loader: loader for the evaluation (only used when ``test``).
        reversed: unused; kept for backward compatibility.
        idx: forwarded to ``get_model_from_name``.

    Returns:
        ``(updated_model, final_acc)``; ``final_acc`` is ``None`` unless
        ``test`` is true.
    """
    updated_model = get_model_from_name(args, idx=idx)
    if args.gpu_id != -1:
        updated_model = updated_model.cuda(args.gpu_id)

    model_state_dict = model.state_dict()

    print("len of model_state_dict is ", len(model_state_dict.items()))
    print("len of new_params is ", len(new_params))

    # Pair keys with new tensors positionally. zip stops at the shorter of
    # the two, so an empty/short new_params no longer raises IndexError
    # (the old loop indexed before checking its break condition).
    for key, new_param in zip(list(model_state_dict), new_params):
        print("updated parameters for layer ", key)
        model_state_dict[key] = new_param

    updated_model.load_state_dict(model_state_dict)

    if test:
        log_dict = {'test_losses': []}
        final_acc = routines.test(args, updated_model, test_loader, log_dict)
        print("accuracy after update is ", final_acc)
    else:
        final_acc = None

    return updated_model, final_acc
Пример #4
0
def recheck_accuracy(args, models, test_loader):
    """Re-evaluate every model on ``test_loader`` and print the results.

    Runs only when ``args.recheck_cifar`` (legacy flag) or
    ``args.recheck_acc`` is truthy. Each model is tested with its own fresh
    log dict. Nothing is returned.
    """
    # recheck_acc supplements the legacy recheck_cifar flag; either enables it.
    if not (args.recheck_cifar or args.recheck_acc):
        return
    rechecked = [
        routines.test(args, candidate, test_loader, {'test_losses': []})
        for candidate in models
    ]
    print("Rechecked accuracies are ", rechecked)
Пример #5
0
def get_network_from_param_list(args, param_list, test_loader):
    """Build a new network and load ``param_list`` into its parameters.

    The network comes from ``get_model_from_name(args, idx=1)``.
    ``param_list`` must hold one tensor per entry of
    ``new_network.parameters()``, in the same order. The network is evaluated
    with ``routines.test`` before loading (result discarded) and after.

    Args:
        args: run configuration; reads ``gpu_id`` here.
        param_list: tensors aligned with ``new_network.parameters()``.
        test_loader: loader passed to ``routines.test``.

    Returns:
        ``(acc, new_network)`` where ``acc`` is the post-load test result.

    Raises:
        AssertionError: if ``param_list`` length differs from the number of
            network parameters.
    """
    print("using independent method")
    new_network = get_model_from_name(args, idx=1)
    if args.gpu_id != -1:
        new_network = new_network.cuda(args.gpu_id)

    # check the test performance of the network before (result discarded)
    log_dict = {'test_losses': []}
    routines.test(args, new_network, test_loader, log_dict)

    # set the weights of the new network
    param_names = [name for name, _ in new_network.named_parameters()]
    print("len of model parameters and avg aligned layers is ", len(param_names),
          len(param_list))
    assert len(param_names) == len(param_list), (
        "param_list has {} tensors but the network has {} parameters".format(
            len(param_list), len(param_names)))

    model_state_dict = new_network.state_dict()

    print("len of model_state_dict is ", len(model_state_dict.items()))
    print("len of param_list is ", len(param_list))

    # param_list is aligned with parameters(), but state_dict() may also hold
    # buffers (e.g. BatchNorm running statistics) interleaved with them.
    # Assigning by parameter name keeps buffers intact instead of shifting
    # tensors onto the wrong keys or overrunning param_list. For buffer-free
    # models this is the same assignment as a positional walk.
    for name, new_param in zip(param_names, param_list):
        model_state_dict[name] = new_param

    new_network.load_state_dict(model_state_dict)

    # check the test performance of the network after
    log_dict = {'test_losses': []}
    acc = routines.test(args, new_network, test_loader, log_dict)

    return acc, new_network
Пример #6
0
def test_model(args, model, test_loader):
    """Evaluate ``model`` via ``routines.test`` and return its result.

    ``args`` and ``test_loader`` are passed through unchanged; the log dict
    handed to ``routines.test`` starts with an empty ``'test_losses'`` list.
    """
    fresh_log = dict(test_losses=[])
    return routines.test(args, model, test_loader, fresh_log)
Пример #7
0
            else:
                model, accuracy, local_accuracy = routines.get_pretrained_model(
                        args, os.path.join(ensemble_dir, 'model_{}/{}.checkpoint'.format(idx, args.ckpt_type)), data_separated=True, idx = idx
                )
            models.append(model)
            accuracies.append(accuracy)
            local_accuracies.append(local_accuracy)
        print("Done loading all the models")

        # Additional flag of recheck_acc to supplement the legacy flag recheck_cifar
        if args.recheck_cifar or args.recheck_acc:
            recheck_accuracies = []
            for model in models:
                log_dict = {}
                log_dict['test_losses'] = []
                recheck_accuracies.append(routines.test(args, model, test_loader, log_dict))
            print("Rechecked accuracies are ", recheck_accuracies)

        # print('checking named modules of model0 for use in compute_activations!', list(models[0].named_modules()))

    else:
        # get dataloaders
        print("------- Obtain dataloaders -------")
        train_loader, test_loader = get_dataloader(args)

        if args.partition_type == 'labels':
            print("------- Split dataloaders by labels -------")
            choice = [int(x) for x in args.choice.split()]
            (trailo_a, teslo_a), (trailo_b, teslo_b), other = partition.split_mnist_by_labels(args, train_loader, test_loader, choice=choice)
            print("------- Training independent models -------")
            models, accuracies, local_accuracies = routines.train_data_separated_models(args, [trailo_a, trailo_b],
Пример #8
0
                        ensemble_dir,
                        'model_{}/{}.checkpoint'.format(idx, args.ckpt_type)),
                    idx=idx)

            models.append(model)
            accuracies.append(accuracy)
        print("Done loading all the models")

        # Additional flag of recheck_acc to supplement the legacy flag recheck_cifar
        if args.recheck_cifar or args.recheck_acc:
            recheck_accuracies = []
            for model in models:
                log_dict = {}
                log_dict['test_losses'] = []
                recheck_accuracies.append(
                    routines.test(args, model, test_loader, log_dict))
            print("Rechecked accuracies are ", recheck_accuracies)

        # print('checking named modules of model0 for use in compute_activations!', list(models[0].named_modules()))

        # print('what about named parameters of model0 for use in compute_activations!', [tupl[0] for tupl in list(models[0].named_parameters())])

    else:
        # get dataloaders
        print("------- Obtain dataloaders -------")
        train_loader, test_loader = get_dataloader(args)
        retrain_loader, _ = get_dataloader(
            args, no_randomness=args.no_random_trainloaders)

        print("------- Training independent models -------")
        models, accuracies = routines.train_models(args, train_loader,