Exemple #1
0
def test(net, criterion, logfile, loader, device):
    """Evaluate ``net`` on ``loader`` and append the result to ``logfile``.

    Args:
        net: model to evaluate; it is switched to eval mode by this call.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        logfile: path of the text log the summary lines are appended to.
        loader: sized iterable yielding ``(inputs, targets)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy in percent, computed as ``100. * correct / total``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0

    # Evaluation never calls backward(), so do not record an autograd graph.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            # index of the largest logit along dim 1 is the predicted class
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).cpu().sum()
            num_batches = batch_idx + 1

            progress_bar(
                batch_idx, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (test_loss / num_batches, 100. * correct / total, correct,
                 total))

    # Fail with a clear message before touching the log file, instead of a
    # NameError on the loop variable after a partial write.
    if num_batches == 0:
        raise ValueError('test(): loader yielded no batches')

    with open(logfile, 'a') as f:
        f.write('Test results:\n')
        f.write('Loss: %.3f | Acc: %.3f%% (%d/%d)\n' %
                (test_loss / num_batches, 100. * correct / total, correct,
                 total))
    # return the acc.
    return 100. * correct / total
Exemple #2
0
def train_teacher(epoch, net, criterion, optimizer, use_cuda, logfile, loader,
                  wmloader):
    """Train ``net`` for one epoch, optionally appending watermark batches.

    Args:
        epoch: epoch number, used only for printing and logging.
        net: model to train; it is switched to train mode by this call.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        use_cuda: when truthy, every batch is moved with ``.cuda()``.
        logfile: path of the text log the epoch summary is appended to.
        loader: sized iterable yielding ``(inputs, targets)`` batches.
        wmloader: falsy to disable watermarking, otherwise an iterable of
            ``(inputs, targets)`` watermark batches.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # Cache every watermark batch up front so they can be cycled through.
    wminputs, wmtargets = [], []
    if wmloader:
        for wminput, wmtarget in wmloader:
            if use_cuda:
                wminput, wmtarget = wminput.cuda(), wmtarget.cuda()
            wminputs.append(wminput)
            wmtargets.append(wmtarget)
        # the wm_idx to start from
        wm_idx = np.random.randint(len(wminputs))

    for batch_idx, (inputs, targets) in enumerate(loader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        if wmloader:
            # add wmimages and targets, rotating through the cached batches
            wm_slot = (wm_idx + batch_idx) % len(wminputs)
            inputs = torch.cat([inputs, wminputs[wm_slot]], dim=0)
            targets = torch.cat([targets, wmtargets[wm_slot]], dim=0)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # keep the running count a Python int so the percentage below is a
        # plain float division
        correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))

    with open(logfile, 'a') as f:
        f.write('Epoch: %d\n' % epoch)
        f.write(
            'Loss: %.3f | Acc: %.3f%% (%d/%d)\n' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))
Exemple #3
0
def test_path(path):
    """Load the checkpoint at ``path`` and evaluate its network.

    Reads the module-level names ``device`` and ``loader``; progress
    (running loss and accuracy) is reported through ``progress_bar`` only,
    nothing is returned.

    Args:
        path: filesystem path of a checkpoint whose ``'net'`` entry holds
            the model object.
    """
    assert os.path.exists(path), 'Error: no checkpoint found!'
    print('==> Resuming from checkpoint..')
    # NOTE(security): torch.load unpickles arbitrary objects, so only
    # checkpoints from a trusted source should be passed in here.
    checkpoint = torch.load(path)
    net = checkpoint['net']

    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net,
                                    device_ids=range(
                                        torch.cuda.device_count()))
        cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss()

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    # Evaluation never calls backward(), so do not record an autograd graph.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)

            total += targets.size(0)
            correct += predicted.eq(targets).cpu().sum()

            progress_bar(
                batch_idx, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (test_loss /
                 (batch_idx + 1), 100. * correct / total, correct, total))
Exemple #4
0
def train(epoch,
          net,
          criterion,
          optimizer,
          logfile,
          loader,
          device,
          wmloader=False,
          tune_all=True):
    """Train ``net`` for one epoch, optionally appending watermark batches.

    Args:
        epoch: epoch number, used only for printing and logging.
        net: model to train; it is switched to train mode by this call.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        logfile: path of the text log the epoch summary is appended to.
        loader: sized iterable yielding ``(inputs, targets)`` batches.
        device: device every batch is moved to.
        wmloader: falsy to disable watermarking, otherwise an iterable of
            ``(inputs, targets)`` watermark batches.
        tune_all: when false, ``freeze_hidden_layers()`` is called on the
            model (on ``net.module`` if it is wrapped in ``DataParallel``).
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # update only the last layer
    if not tune_all:
        if isinstance(net, torch.nn.DataParallel):
            net.module.freeze_hidden_layers()
        else:
            net.freeze_hidden_layers()

    # get the watermark images, cached on the target device
    wminputs, wmtargets = [], []
    if wmloader:
        for wminput, wmtarget in wmloader:
            wminputs.append(wminput.to(device))
            wmtargets.append(wmtarget.to(device))

        # the wm_idx to start from
        wm_idx = np.random.randint(len(wminputs))

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # add wmimages and targets, rotating through the cached batches
        if wmloader:
            wm_slot = (wm_idx + batch_idx) % len(wminputs)
            inputs = torch.cat([inputs, wminputs[wm_slot]], dim=0)
            targets = torch.cat([targets, wmtargets[wm_slot]], dim=0)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # keep the running count a Python int so the percentage below is a
        # plain float division rather than tensor arithmetic
        correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))

    with open(logfile, 'a') as f:
        f.write('Epoch: %d\n' % epoch)
        f.write('Loss: %.3f | Acc: %.3f%% (%d/%d)\n' %
                (train_loss /
                 (batch_idx + 1), 100. * correct / total, correct, total))
Exemple #5
0
def _class_input_gradients(logits, inputs, n_classes):
    """Stack d(sum of logit column c)/d(inputs) for c < n_classes on dim 1."""
    per_class = [
        # create_graph=True keeps the gradients differentiable so a loss
        # built from them can be back-propagated
        torch.autograd.grad(logits[:, c].sum(), inputs, create_graph=True)[0]
        for c in range(n_classes)
    ]
    return torch.stack(per_class, dim=1)


def train_steal(epoch,
                net,
                parent,
                optimizer,
                logfile,
                loader,
                device,
                grad_query=True):
    """Train ``net`` for one epoch to imitate ``parent``.

    The loss is a cross-entropy term against the parent's arg-max labels
    plus, when ``grad_query`` is true, a summed squared error between the
    parent's and the child's per-class input gradients (divided by the
    batch size). The gradient term is always computed and reported, even
    when it is weighted by zero.

    Args:
        epoch: epoch number, used only for printing and logging.
        net: child model being trained; switched to train mode.
        parent: model being imitated; switched to eval mode.
        optimizer: optimizer stepped once per batch.
        logfile: path of the text log the epoch summary is appended to.
        loader: sized iterable yielding ``(inputs, targets)`` batches; the
            true ``targets`` are used only for the reported accuracy.
        device: device every batch is moved to.
        grad_query: whether the gradient-matching term contributes to the
            optimised loss.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    parent.eval()
    train_loss = 0
    progress_mem_loss = 0
    progress_grad_loss = 0
    correct = 0
    total = 0

    pseudo_label_criterion = torch.nn.CrossEntropyLoss()
    # reduction='sum' is the documented replacement for size_average=False
    mse_criterion = torch.nn.MSELoss(reduction='sum')
    lam = 1 if grad_query else 0

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs = inputs.to(device)
        # input gradients are queried below, so the batch must track grad
        inputs.requires_grad = True
        targets = targets.to(device)
        batch_size = inputs.shape[0]

        # Parent Computations
        target_logits = parent(inputs)
        _, pseudo_labels = torch.max(target_logits.data, 1)
        # number of classes comes from the parent's output width
        n_logits = target_logits.shape[1]
        # NOTE(review): the parent's gradients act as a fixed target, so
        # create_graph may be unnecessary for them - confirm before changing.
        parent_gradients = _class_input_gradients(target_logits, inputs,
                                                  n_logits)

        # Child Computations
        outputs = net(inputs)
        child_gradients = _class_input_gradients(outputs, inputs, n_logits)

        membership_loss = pseudo_label_criterion(outputs, pseudo_labels)
        gradient_loss = mse_criterion(parent_gradients,
                                      child_gradients) / batch_size
        loss = membership_loss + lam * gradient_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_mem_loss += membership_loss.item()
        progress_grad_loss += gradient_loss.item()
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        # keep the running count a Python int so the percentage below is a
        # plain float division
        correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(loader),
            'Gradient Loss: %.3f | Membership Loss: %.3f | True Train Acc: %.3f%% (%d/%d)'
            % (progress_grad_loss / (batch_idx + 1), progress_mem_loss /
               (batch_idx + 1), 100. * correct / total, correct, total))

    with open(logfile, 'a') as f:
        f.write('Epoch: %d\n' % epoch)
        f.write('Loss: %.3f | True Acc: %.3f%% (%d/%d)\n' %
                (train_loss /
                 (batch_idx + 1), 100. * correct / total, correct, total))
Exemple #6
0
def train(epoch,
          net,
          criterion,
          optimizer,
          logfile,
          loader,
          device,
          wmloader=False,
          tune_all=True,
          ex_datas=[],
          ex_net=None,
          wm2_loader=None,
          n_classes=None,
          EWC_coef=0.,
          Fisher=None,
          init_params=None,
          EWC_immune=[],
          afs_bsize=0,
          extra_only=False):
    """Train ``net`` for one epoch with optional extra data and an EWC penalty.

    Each batch from ``loader`` may be extended with a watermark batch, a
    slice of ``net.afs_inputs``/``net.afs_targets``, and one batch from every
    iterator in ``ex_datas`` before the forward pass.

    Args:
        epoch: epoch number, used only for printing and logging.
        net: model to train; it is switched to train mode by this call.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        logfile: path of the text log the epoch summary is appended to.
        loader: sized iterable yielding ``(inputs, targets)`` batches.
        device: device every batch is moved to.
        wmloader: falsy to disable watermarking, otherwise an iterable of
            ``(inputs, targets)`` watermark batches.
        tune_all: when false, ``freeze_hidden_layers()`` is called on the
            model (on ``net.module`` if it is wrapped in ``DataParallel``).
        ex_datas: iterators advanced with ``next()`` once per batch, each
            yielding ``(inputs, targets)``. Only read here, never mutated.
        ex_net: model used to relabel extra batches (see the loop body).
        wm2_loader: accepted for interface compatibility; not used here.
        n_classes: modulus applied when shifting ``ex_net`` predictions.
        EWC_coef: weight of the penalty; the penalty is skipped unless > 0.
        Fisher: per-parameter weights, zipped with ``net.parameters()``.
        init_params: per-parameter reference values for the penalty.
        EWC_immune: passed to ``IsInside`` to exclude parameters from the
            penalty. Only read here, never mutated.
        afs_bsize: number of ``net.afs_*`` samples appended per batch;
            0 disables this.
        extra_only: when true, the first extra batch replaces the batch
            built so far instead of being appended to it.
    """
    print('\nEpoch: %d' % epoch)

    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # update only the last layer
    if not tune_all:
        if isinstance(net, torch.nn.DataParallel):
            net.module.freeze_hidden_layers()
        else:
            net.freeze_hidden_layers()

    # get the watermark images, cached on the target device
    wminputs, wmtargets = [], []
    if wmloader:
        for wminput, wmtarget in wmloader:
            wminputs.append(wminput.to(device))
            wmtargets.append(wmtarget.to(device))

        # the wm_idx to start from
        wm_idx = np.random.randint(len(wminputs))

    if afs_bsize > 0:
        afs_idx = 0

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # add wmimages and targets, rotating through the cached batches
        if wmloader:
            wm_slot = (wm_idx + batch_idx) % len(wminputs)
            inputs = torch.cat([inputs, wminputs[wm_slot]], dim=0)
            targets = torch.cat([targets, wmtargets[wm_slot]], dim=0)

        if afs_bsize > 0:
            afs_end = afs_idx + afs_bsize
            inputs = torch.cat([inputs, net.afs_inputs[afs_idx:afs_end]],
                               dim=0)
            targets = torch.cat([targets, net.afs_targets[afs_idx:afs_end]],
                                dim=0)
            # advance the window, wrapping at the end of the stored samples
            afs_idx = afs_end % net.afs_inputs.size(0)

        # add data from extra sources
        extra_only_tag = True
        for _loader in ex_datas:
            _input, _target = next(_loader)
            _input, _target = _input.to(device), _target.to(device)
            # the first label of the batch selects how the batch is labelled
            first_label = _target[0].item()
            if first_label < -1:
                # labels below -1: shift ex_net's prediction by the stored
                # label, modulo n_classes
                # NOTE(review): the +20000 presumably keeps the sum
                # non-negative before the modulo - confirm
                with torch.no_grad():
                    _, __target = torch.max(ex_net(_input).data, 1)
                    _target = (__target + _target + 20000) % n_classes
            elif first_label == -1 or ex_net is not None:
                # otherwise relabel the batch with ex_net's own prediction
                with torch.no_grad():
                    _output = ex_net(_input)

                    _, _target = torch.max(_output.data, 1)
                    _target = _target.to(device)

            if extra_only and extra_only_tag:
                # the first extra batch replaces everything gathered so far
                inputs = _input
                targets = _target
                extra_only_tag = False
            else:
                inputs = torch.cat([inputs, _input], dim=0)
                targets = torch.cat([targets, _target], dim=0)

        outputs = net(inputs)
        loss = criterion(outputs, targets)

        if EWC_coef > 0:
            # upper bound applied to every Fisher entry; it depends only on
            # the learning rate and EWC_coef, so compute it once per batch
            fisher_cap = 1. / optimizer.param_groups[0]['lr'] / EWC_coef
            for param, fisher, init_param in zip(net.parameters(), Fisher,
                                                 init_params):
                if IsInside(param, EWC_immune):
                    continue
                loss = loss + (0.5 * EWC_coef * fisher.clamp(max=fisher_cap) *
                               ((param - init_param)**2)).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)

        # keep the running count a Python int so the percentage below is a
        # plain float division
        correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))

    with open(logfile, 'a') as f:
        f.write('Epoch: %d\n' % epoch)
        f.write(
            'Loss: %.3f | Acc: %.3f%% (%d/%d)\n' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))
# Script-level evaluation pass. `checkpoint`, `net`, `device` and `loader`
# are bound outside the visible part of this script.
start_epoch = checkpoint['epoch']

net = net.to(device)
if device == 'cuda':
    # spread the model over every visible GPU
    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True
criterion = nn.CrossEntropyLoss()

net.eval()
# running totals over all batches
test_loss = 0
correct = 0
total = 0
# NOTE(review): the loop runs with autograd enabled; since nothing calls
# backward(), wrapping it in torch.no_grad() looks safe - confirm.
for batch_idx, (inputs, targets) in enumerate(loader):
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = net(inputs)

    #     print ("targets", targets)
    loss = criterion(outputs, targets)

    test_loss += loss.item()
    # index of the largest output along dim 1 is the predicted class
    _, predicted = torch.max(outputs.data, 1)
    #     print ("outputs", predicted)

    total += targets.size(0)
    correct += predicted.eq(targets.data).cpu().sum()

    # report mean loss per batch and running accuracy so far
    progress_bar(
        batch_idx, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
        (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
Exemple #8
0
    def Evaluate(self, steps=250):
        """Run validation batches, log summaries and save improved models.

        Args:
            steps: number of batches to evaluate; a value <= 0 means use
                ``self.valid_iterator.steps``.

        After the loop the averaged metrics are printed. For each tracked
        Dice metric that beats the value stored in ``self.numKeeper.counts``,
        the stored value is updated and the model is saved under
        ``<output_dir>/<run_name>/<subdir>/latest.ckpt``.
        """
        #set mode and wait for the threads to populate the queue
        self.valid_iterator.Reset()
        self.summary_manager.mode = 'valid'
        time.sleep(5.0)

        feed = None
        valid_ops = [
            self.model.logits, self.model.summary_ops['1step'],
            self.model.inference_ops, self.model.accumulate_ops
        ]

        #iterate the validation step, until count = steps
        count = 0
        step_sec = 0
        if steps <= 0:
            steps = self.valid_iterator.steps
        while count < steps:
            start_time = time.time()
            time.sleep(0.4)

            (input_batch, target_batch, weight_batch, batch_class_weights,
             _) = self.valid_iterator.GetNextBatch()

            # only count iterations that actually delivered an array batch
            if isinstance(input_batch, np.ndarray):
                feed = {
                    self.model.inputs: input_batch,
                    self.model.targets: target_batch,
                    self.model.weight_maps: weight_batch,
                    self.model.batch_class_weights: batch_class_weights,
                    self.model.is_training: False
                }

                # only the per-step summary is needed; the other fetches are
                # run for their side effects and dropped immediately
                _, summary, _, _ = self.sess.run(valid_ops, feed_dict=feed)
                self.summary_manager.AddSummary(summary, "valid", "per_step")
                progress_bar(count % steps + 1, steps, step_sec)

                # add summaries regularly for every 100 steps or no. of steps to finish an epoch
                if self.valid_iterator.steps <= 100:
                    # clamp to 1: fewer than 2 steps would otherwise give a
                    # zero modulus and a ZeroDivisionError below
                    save_step = max(1, self.valid_iterator.steps // 2)
                else:
                    save_step = 100

                if (count + 1) % save_step == 0:
                    summary = self.sess.run(self.model.summary_ops['100steps'],
                                            feed_dict=feed)
                    self.summary_manager.AddSummary(summary, "valid",
                                                    "per_100_steps")

                if (count + 1) % 250 == 0:
                    print('Avg metrics : ')
                    pprint.pprint(self.sess.run(self.model.stats_ops), width=1)

                count = count + 1
            stop_time = time.time()
            step_sec = stop_time - start_time
            if self.valid_iterator.iter_over:
                # iterator exhausted: restart it and give it time to refill
                self.valid_iterator.Reset()
                time.sleep(5)

        print('\nAvg metrics for epoch : ')
        metrics = self.sess.run(self.model.stats_ops)
        pprint.pprint(metrics, width=1)

        # (metric key, checkpoint sub-directory, label used in the message)
        best_targets = (
            ('avgDice_score', 'best_model', 'all classes'),
            ('Dice_class_1', 'best_model_class1', 'class 1'),
            ('Dice_class_2', 'best_model_class2', 'class 2'),
            ('Dice_class_3', 'best_model_class3', 'class 3'),
        )
        for metric_key, subdir, label in best_targets:
            if metrics[metric_key] > self.numKeeper.counts[metric_key]:
                self.numKeeper.counts[metric_key] = metrics[metric_key]
                self.numKeeper.UpdateCounts(self.summary_manager.counts)
                print('Saving best model for %s!' % label)
                self.SaveModel(
                    os.path.join(self.conf.output_dir, self.conf.run_name,
                                 subdir, 'latest.ckpt'))

        print('Current best average Dice: ' +
              str(self.numKeeper.counts['avgDice_score']))
        # the epoch summary is fetched before the accumulators are reset
        summary = self.sess.run(self.model.summary_ops['1epoch'],
                                feed_dict=feed)
        self.sess.run(self.model.reset_ops)
        self.summary_manager.AddSummary(summary, "valid", "per_epoch")
Exemple #9
0
    def Fit(self, steps=1000):
        """Run training batches and log per-step, periodic and epoch summaries.

        Args:
            steps: number of batches to train on; a value <= 0 means use
                ``self.train_iterator.steps``.
        """
        self.train_iterator.Reset()
        self.summary_manager.mode = 'train'
        time.sleep(5.0)
        feed = None

        train_ops = [
            self.model.logits,
            self.model.summary_ops['1step'],
            self.model.inference_ops,
            self.model.accumulate_ops,
            self.model.train_op,
        ]

        count = 0
        step_sec = 0
        if steps <= 0:
            steps = self.train_iterator.steps
        while count < steps:
            start_time = time.time()

            # fetch inputs batches and verify if they are numpy.ndarray and run all the ops
            (input_batch, target_batch, weight_batch, batch_class_weights,
             _) = self.train_iterator.GetNextBatch()

            # only count iterations that actually delivered an array batch
            if isinstance(input_batch, np.ndarray):
                feed = {
                    self.model.inputs: input_batch,
                    self.model.targets: target_batch,
                    self.model.weight_maps: weight_batch,
                    self.model.batch_class_weights: batch_class_weights,
                    self.model.is_training: True
                }

                # only the per-step summary is needed; the other fetches are
                # run for their side effects and dropped immediately
                _, summary, _, _, _ = self.sess.run(train_ops,
                                                    feed_dict=feed)
                self.summary_manager.AddSummary(summary, "train", "per_step")
                progress_bar(count % steps + 1, steps, step_sec)

                # add summaries regularly for every 100 steps or no. of steps to finish an epoch
                if self.train_iterator.steps <= 100:
                    # clamp to 1: fewer than 2 steps would otherwise give a
                    # zero modulus and a ZeroDivisionError below
                    save_step = max(1, self.train_iterator.steps // 2)
                else:
                    save_step = 100

                if (count + 1) % save_step == 0:
                    summary = self.sess.run(self.model.summary_ops['100steps'],
                                            feed_dict=feed)
                    self.summary_manager.AddSummary(summary, "train",
                                                    "per_100_steps")

                if (count + 1) % 250 == 0:
                    print('Avg metrics : ')
                    pprint.pprint(self.sess.run(self.model.stats_ops), width=1)
                count = count + 1

            stop_time = time.time()
            step_sec = stop_time - start_time
            if self.train_iterator.iter_over:
                # iterator exhausted: restart it and give it time to refill
                self.train_iterator.Reset()
                time.sleep(4)

        print('\nAvg metrics for epoch : ')
        pprint.pprint(self.sess.run(self.model.stats_ops), width=1)
        # the epoch summary is fetched before the accumulators are reset
        summary = self.sess.run(self.model.summary_ops['1epoch'],
                                feed_dict=feed)
        self.sess.run(self.model.reset_ops)
        self.summary_manager.AddSummary(summary, "train", "per_epoch")