Code example #1
    def _load_models_from_disk(self):
        # load the pre-trained models
        sub_net_list = []
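        # One substitute model is loaded per (checkpoint name, architecture) pair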
        for n_chk, chk_name in enumerate(self.subs_chk_name):
            for snet in self.substitute_nets:
                if self.subs_dp[n_chk] > 0.0:
                    net = load_pretrained_net(snet, chk_name, model_chk_path=self.model_resume_path,
                                              test_dp=self.subs_dp[n_chk])
                elif self.subs_dp[n_chk] == 0.0:
                    net = load_pretrained_net(snet, chk_name, model_chk_path=self.model_resume_path)
                else:
                    assert False, "subs_dp values must be non-negative"
                sub_net_list.append(net)

        self.models = sub_net_list
Code example #2
                args.train_data_path,
                subset='others',
                transform=transform_train,
                num_per_label=args.num_per_class,
                poison_tuple_list=poison_tuple_list,
                poison_indices=base_idx_list,
                subset_group=args.subset_group)
            print("Poisoned dataset created")
            res[ite] = get_stats(
                '{}/{}/log.txt'.format(args.eval_poisons_root, target_idx),
                state_dict, ite - 1)

            res[ite]['victims'] = {}
            for victim_name in args.target_net:
                print(victim_name)
                victim_net = load_pretrained_net(
                    victim_name,
                    args.test_chk_name,
                    model_chk_path=args.model_resume_path,
                    device=args.device)
                res[ite]['victims'][victim_name] = \
                    train_network_with_poison(victim_net, target, target_camera_tuple, poison_tuple_list,
                                              poisoned_dset, base_idx_list, args, testset,
                                              savemodel='{}/{}-poison-ites-{}'.format(models_dir, victim_name, ite))
        # Persist the evaluation results unless saving has been disabled
        if not no_save:
            all_res['targets'][target_idx] = res
            all_res['poison_idx_list'] = base_idx_list

            with open(json_res_path, 'w') as f:
                json.dump(all_res, f)
Code example #3
    parser.add_argument('--device', default='cuda', type=str)
    args = parser.parse_args()

    # Set visible CUDA devices
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    cudnn.benchmark = True

    # load the pre-trained models
    sub_net_list = []
    for n_chk, chk_name in enumerate(args.subs_chk_name):
        for snet in args.substitute_nets:
            if args.subs_dp[n_chk] > 0.0:
                net = load_pretrained_net(
                    snet,
                    chk_name,
                    model_chk_path=args.model_resume_path,
                    test_dp=args.subs_dp[n_chk])
            elif args.subs_dp[n_chk] == 0.0:
                net = load_pretrained_net(
                    snet, chk_name, model_chk_path=args.model_resume_path)
            else:
                assert False, "subs_dp values must be non-negative"
            sub_net_list.append(net)

    print("subs nets, effective num: {}".format(len(sub_net_list)))

    print("Loading the victims networks")
    targets_net = []
    for tnet in args.target_net:
        target_net = load_pretrained_net(tnet,
Code example #4
def train_model(model, dataloader, criteria, optimizers, schedulers,
                num_epochs, params):

    # Note the time
    since = time.time()

    # Unpack parameters
    writer = params['writer']
    board = writer is not None
    txt_file = params['txt_file']
    pretrained = params['model_files'][1]
    pretrain = params['pretrain']
    print_freq = params['print_freq']
    dataset_size = params['dataset_size']
    device = params['device']
    batch = params['batch']
    pretrain_epochs = params['pretrain_epochs']
    gamma = params['gamma']
    update_interval = params['update_interval']
    tol = params['tol']

    dl = dataloader

    # Pretrain or load weights
    if pretrain:
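        # Retry pretraining, re-initialising the model's layers after each failed attempt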
        while True:
            pretrained_model = pretraining(model, copy.deepcopy(dl),
                                           criteria[0], optimizers[1],
                                           schedulers[1], pretrain_epochs,
                                           params)
            if pretrained_model:
                break
            else:
                for layer in model.children():
                    if hasattr(layer, 'reset_parameters'):
                        layer.reset_parameters()
        model = pretrained_model
    else:
        try:
            utils.load_pretrained_net(model, pretrained)
            utils.print_both(
                txt_file,
                'Pretrained weights loaded from file: ' + str(pretrained))
        except Exception:
            print("Couldn't load pretrained weights")

    # Initialise clusters
    if params['train_init_clusters'] or pretrain:
        init_clusters(txt_file, model, dl, params)
    utils.print_both(txt_file, '\nBegin clusters training')

    # Prep variables for weights and accuracy of the best model
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 10000.0

    # Initial target distribution
    utils.print_both(txt_file, '\nUpdating target distribution')
    output_distribution, labels, preds_prev, embedding, label_img = calculate_predictions(
        model, copy.deepcopy(dl), params)
    if board:
        writer.add_embedding(embedding,
                             metadata=labels,
                             global_step=0,
                             label_img=label_img,
                             tag='embedding_layer')
        writer.add_embedding(output_distribution,
                             metadata=labels,
                             global_step=0,
                             label_img=label_img,
                             tag='clustering_output')
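    # Compute the auxiliary target distribution from the current soft cluster assignments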
    target_distribution = target(output_distribution)
    if params['class_dependent_metrics']:
        nmi = utils.metrics.nmi(labels, preds_prev)
        ari = utils.metrics.ari(labels, preds_prev)
        acc = utils.metrics.acc(labels, preds_prev)
        utils.print_both(
            txt_file,
            'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\n'.format(nmi, ari, acc))

        if board:
            niter = 0
            writer.add_scalar('/NMI', nmi, niter)
            writer.add_scalar('/ARI', ari, niter)
            writer.add_scalar('/Acc', acc, niter)

    update_iter = 1
    finished = False

    # Go through all epochs
    for epoch in range(num_epochs):
        if epoch < params['zero_gamma_epochs']:
            gamma = 0
        else:
            gamma = params['gamma']

        utils.print_both(txt_file, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
        utils.print_both(txt_file, '-' * 10)

        model.train(True)  # Set model to training mode

        running_loss = 0.0
        running_loss_rec = 0.0
        running_loss_clust = 0.0

        # Keep the batch number for inter-phase statistics
        batch_num = 1
        img_counter = 0

        # Iterate over data.
        for data in dataloader:
            # Get the inputs and labels
            inputs, provided_labels = data

            inputs = inputs.to(device)

            # Update the target distribution, check and print performance
            if (batch_num - 1) % update_interval == 0 and not (batch_num == 1
                                                               and epoch == 0):
                utils.print_both(txt_file, '\nUpdating target distribution:')
                output_distribution, labels, preds, embedding, label_img = calculate_predictions(
                    model, dataloader, params)
                if (board and batch_num == 1
                        and epoch % params['embedding_interval'] == 0):
                    writer.add_embedding(embedding,
                                         metadata=labels,
                                         global_step=epoch,
                                         label_img=label_img,
                                         tag='embedding_layer')
                    writer.add_embedding(output_distribution,
                                         metadata=labels,
                                         global_step=epoch,
                                         label_img=label_img,
                                         tag='clustering_output')

                target_distribution = target(output_distribution)
                if params['class_dependent_metrics']:
                    nmi = utils.metrics.nmi(labels, preds)
                    ari = utils.metrics.ari(labels, preds)
                    acc = utils.metrics.acc(labels, preds)
                    utils.print_both(
                        txt_file,
                        'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\t'.format(
                            nmi, ari, acc))
                    if board:
                        niter = update_iter
                        writer.add_scalar('/NMI', nmi, niter)
                        writer.add_scalar('/ARI', ari, niter)
                        writer.add_scalar('/Acc', acc, niter)
                update_iter += 1

                # check stop criterion
                delta_label = np.sum(preds != preds_prev).astype(
                    np.float32) / preds.shape[0]
                preds_prev = np.copy(preds)
                if delta_label < tol:
                    utils.print_both(
                        txt_file, 'Label divergence ' + str(delta_label) +
                        ' < tol ' + str(tol))
                    utils.print_both(
                        txt_file,
                        'Reached tolerance threshold. Stopping training.')
                    finished = True
                    break

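            # Slice out the rows of the target distribution that correspond to the current batch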
            tar_dist = target_distribution[((batch_num - 1) *
                                            batch):(batch_num * batch), :]
            tar_dist = torch.from_numpy(tar_dist).to(device)
            # print(tar_dist)

            # zero the parameter gradients
            optimizers[0].zero_grad()

            # Calculate losses and backpropagate
            with torch.set_grad_enabled(True):
                outputs, clusters, _ = model(inputs)
                loss_rec = criteria[0](outputs, inputs)
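                # criteria[1] is expected to be a KL-divergence loss comparing log soft assignments with the batch target distribution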
                loss_clust = criteria[1](torch.log(clusters), tar_dist) / batch
                if (params['DEC']):
                    loss = gamma * loss_clust
                else:
                    loss = loss_rec + gamma * loss_clust
                loss.backward()
                optimizers[0].step()

            # For keeping statistics
            running_loss += loss.item() * inputs.size(0)
            running_loss_rec += loss_rec.item() * inputs.size(0)
            running_loss_clust += loss_clust.item() * inputs.size(0)

            # Some current stats
            loss_batch = loss.item()
            loss_batch_rec = loss_rec.item()
            loss_batch_clust = loss_clust.item()
            loss_accum = running_loss / (
                (batch_num - 1) * batch + inputs.size(0))
            loss_accum_rec = running_loss_rec / (
                (batch_num - 1) * batch + inputs.size(0))
            loss_accum_clust = running_loss_clust / (
                (batch_num - 1) * batch + inputs.size(0))

            if batch_num % print_freq == 0:
                utils.print_both(
                    txt_file, 'Epoch: [{0}][{1}/{2}]\t'
                    'Loss {3:.4f} ({4:.4f})\t'
                    'Loss_recovery {5:.4f} ({6:.4f})\t'
                    'Loss clustering {7:.4f} ({8:.4f})\t'.format(
                        epoch + 1, batch_num, len(dataloader), loss_batch,
                        loss_accum, loss_batch_rec, loss_accum_rec,
                        loss_batch_clust, loss_accum_clust))
                if board:
                    niter = epoch * len(dataloader) + batch_num
                    writer.add_scalar('/Loss', loss_accum, niter)
                    writer.add_scalar('/Loss_recovery', loss_accum_rec, niter)
                    writer.add_scalar('/Loss_clustering', loss_accum_clust,
                                      niter)
            batch_num = batch_num + 1

            # Print image to tensorboard
            if batch_num == len(dataloader) and (epoch + 1) % 5:
                inp = utils.tensor2img(inputs)
                out = utils.tensor2img(outputs)
                if board:
                    img = np.concatenate((inp, out), axis=1)
                    writer.add_image(
                        'Clustering/Epoch_' + str(epoch + 1).zfill(3) +
                        '/Sample_' + str(img_counter).zfill(2), img)
                    img_counter += 1

        schedulers[0].step()
        if finished: break

        epoch_loss = running_loss / dataset_size
        epoch_loss_rec = running_loss_rec / dataset_size
        epoch_loss_clust = running_loss_clust / dataset_size

        if board:
            writer.add_scalar('/Loss' + '/Epoch', epoch_loss, epoch + 1)
            writer.add_scalar('/Loss_rec' + '/Epoch', epoch_loss_rec,
                              epoch + 1)
            writer.add_scalar('/Loss_clust' + '/Epoch', epoch_loss_clust,
                              epoch + 1)

        utils.print_both(
            txt_file,
            'Loss: {0:.4f}\tLoss_recovery: {1:.4f}\tLoss_clustering: {2:.4f}'.
            format(epoch_loss, epoch_loss_rec, epoch_loss_clust))

        # Placeholder for a model-selection criterion; the condition below is always true for now
        if epoch_loss < best_loss or epoch_loss >= best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        utils.print_both(txt_file, '')

    model.eval()
    output_distribution, labels, preds, embedding, label_img = calculate_predictions(
        model, dataloader, params)
    if params['class_dependent_metrics']:
        nmi = utils.metrics.nmi(labels, preds)
        ari = utils.metrics.ari(labels, preds)
        acc = utils.metrics.acc(labels, preds)
        utils.print_both(
            txt_file,
            'NMI: {0:.5f}\tARI: {1:.5f}\tAcc {2:.5f}\t'.format(nmi, ari, acc))
        if board:
            niter = update_iter
            writer.add_scalar('/NMI', nmi, niter)
            writer.add_scalar('/ARI', ari, niter)
            writer.add_scalar('/Acc', acc, niter)
    if board:
        writer.add_embedding(embedding,
                             metadata=labels,
                             global_step=num_epochs,
                             label_img=label_img,
                             tag='embedding_layer')
        writer.add_embedding(output_distribution,
                             metadata=labels,
                             global_step=num_epochs,
                             label_img=label_img,
                             tag='clustering_output')

    log_func = lambda x: utils.print_both(txt_file, x)

    if params['use_ssim']:
        ssim_metrics = utils.matrix_metrics(params['ssim_matrix'],
                                            params['num_clusters'], preds,
                                            labels)
        utils.log_matrix_metrics(log_func, ssim_metrics, 'SSIM')

    if params['use_mse']:
        mse_metrics = utils.matrix_metrics(params['mse_matrix'],
                                           params['num_clusters'], preds,
                                           labels)
        utils.log_matrix_metrics(log_func, mse_metrics, 'MSE')

    #if params['use_ssim']:
    #    # masks out self-pairs -- including these would increase the
    #    # average ssim for in-cluster pairs.
    #    self_pair_mask = 1 - np.identity(len(preds))
    #    total_sum_ssim_in = 0
    #    total_num_pairs_in = 0
    #    total_sum_ssim_out = 0
    #    total_num_pairs_out = 0
    #
    #    # x and y are used to select pairs from the matrix. Each element of
    #    # x and y is a mask representing the images in (in the case of x) or out
    #    # (in the case of y) of a cluster. These masks are 2-dimensional so that
    #    # the transpose operation can be used.
    #    #
    #    # this will be 1 for each element predicted in class, 0 otherwise
    #    x = np.zeros((params['num_clusters'], 1, len(preds)))
    #    # this will be 0 for each element predicted in class, 1 otherwise
    #    y = np.ones((params['num_clusters'], 1, len(preds)))

    #    encountered_predictions = np.zeros(params['num_clusters'])
    #    for i in range(len(preds)):
    #        encountered_predictions[preds[i]] += 1
    #        predicted_class_index = preds[i]
    #        x[predicted_class_index][0][labels[i]] = 1
    #        y[predicted_class_index][0][labels[i]] = 0
    #    utils.print_both(txt_file, f'Predictions per cluster: {encountered_predictions}')
    #    for i in range(params['num_clusters']):
    #        if encountered_predictions[i] == 0:
    #            utils.print_both(txt_file, f'WARNING: No inputs predicted to exist within cluster {i}.')
    #        # select in-cluster pairs
    #        pairs_in_mask = x[i] * x[i].transpose() * self_pair_mask
    #        pairs_in = params['ssim_matrix'] * pairs_in_mask
    #        num_pairs_in = sum(sum(pairs_in_mask > 0))
    #        # select pairs with one image in the cluster and one image not in the cluster
    #        pairs_out_mask = x[i] * y[i].transpose() + x[i].transpose() * y[i]
    #        pairs_out = params['ssim_matrix'] * pairs_out_mask
    #        num_pairs_out = sum(sum(pairs_out_mask > 0))

    #        sum_ssim_in = sum(sum(pairs_in))
    #        sum_ssim_out = sum(sum(pairs_out))
    #        avg_ssim_in = sum_ssim_in/num_pairs_in
    #        avg_ssim_out = sum_ssim_out/num_pairs_out

    #        total_sum_ssim_in += sum_ssim_in
    #        total_sum_ssim_out += sum_ssim_out
    #        total_num_pairs_in += num_pairs_in
    #        total_num_pairs_out += num_pairs_out

    #        utils.print_both(txt_file, f'Cluster {i}: Average SSIM (in cluster): {avg_ssim_in}')
    #        utils.print_both(txt_file, f'Cluster {i}: Average SSIM (out cluster): {avg_ssim_out}')
    #    total_avg_ssim_in = total_sum_ssim_in / total_num_pairs_in
    #    total_avg_ssim_out = total_sum_ssim_out / total_num_pairs_out
    #    utils.print_both(txt_file, f'SSIM (in cluster): {total_avg_ssim_in}')
    #    utils.print_both(txt_file, f'SSIM (out cluster): {total_avg_ssim_out}')

    update_iter += 1

    time_elapsed = time.time() - since
    utils.print_both(
        txt_file,
        'Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,
                                                      time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
Code example #5
                             "CP loss to multiple layers! ")
    parser.add_argument('--device', default='cuda', type=str)
    args = parser.parse_args()

    if args.retrain_subs_nets:
        assert args.end2end

    # Set visible CUDA devices
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    cudnn.benchmark = True

    nets = []
    for n_chk, chk_name in enumerate(args.subs_chk_name):
        for snet in args.substitute_nets:
            net = load_pretrained_net(snet, chk_name, model_chk_path=args.model_resume_path)
            nets.append(net)

    subs_nets = SubstituteNets(args.model_resume_path, args.subs_chk_name, args.substitute_nets, args.subs_dp)

    print("Loading the victims networks")
    targets_net = []
    for tnet in args.target_net:
        target_net = load_pretrained_net(tnet, args.test_chk_name, model_chk_path=args.model_resume_path)
        targets_net.append(target_net)

    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
Code example #6
        os.environ["CUDA_VISIBLE_DEVICES"] = ""

    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    # load the pre-trained models
    sub_net_list = []
    for n_chk, chk_name in enumerate(args.subs_chk_name):
        for snet in args.substitute_nets:
            net = load_pretrained_net(snet,
                                      chk_name,
                                      model_chk_path=args.model_resume_path,
                                      test_dp=args.subs_dp[n_chk],
                                      device=args.device)
            sub_net_list.append(net)

    for target_idx in range(0, 80):
        print("computing the coeffs for target {}".format(target_idx))

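        # Fetch the target example for this index from the 'others' subset of the training data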
        target = fetch_target(args.target_label,
                              target_idx,
                              50,
                              subset='others',
                              path=args.train_data_path,
                              transforms=transform_test)
        json_res_path = '{}/{}/eval-retrained-for-{}epochs.json'.format(
            args.eval_poisons_root, target_idx, args.retrain_epochs)
Code example #7
                                    gamma=sched_gamma)
    scheduler_pretrain = lr_scheduler.StepLR(optimizer_pretrain,
                                             step_size=sched_step_pretrain,
                                             gamma=sched_gamma_pretrain)

    schedulers = [scheduler, scheduler_pretrain]

    utils.print_both(f, 'Mode: {}\n'.format(args.mode))

    if args.mode == 'train_full':
        model = training_functions.train_model(model, dataloader, criteria,
                                               optimizers, schedulers, epochs,
                                               params)
    elif args.mode == 'pretrain':
        model = training_functions.pretraining(model, dataloader, criteria[0],
                                               optimizers[1], schedulers[1],
                                               epochs, params)
        training_functions.init_clusters(f, model, dataloader, params)
    elif args.mode == 'init_clusters':
        utils.load_pretrained_net(model, args.pretrained_net)
        training_functions.init_clusters(f, model, dataloader, params)

    # Save final model
    torch.save(model.state_dict(), name_net)
    print('Saved model to {}'.format(name_net))

    # Close files
    f.close()
    if board:
        writer.close()