Пример #1
0
def test(epoch, net, testloader, device, num_samples, best_loss):
    """Evaluate the model on ``testloader``, checkpoint it and save a sample grid.

    For every test batch, 64 images are drawn with ``sample`` and the loss
    ``mean(net(samples)) - mean(net(batch))`` is folded into a running average.
    If that average is below ``best_loss`` the weights are written to
    ``ckpts/best_ebm.pth.tar``. A freshly drawn set of samples is always saved
    as a grid to ``ebm_samples/epoch_<epoch>.png``.

    Args:
        epoch: epoch index, stored in the checkpoint and used in the file name.
        net: model called as ``net(images)``; also passed to ``sample``.
        testloader: data loader yielding ``(images, labels)``; labels are ignored.
        device: device the test images are moved to.
        num_samples: only used to choose the grid's row length
            (``int(num_samples ** 0.5)``).
        best_loss: best loss seen so far.

    Returns:
        The updated best loss (unchanged if this epoch did not improve).
    """
    net.eval()
    meter = util.AverageMeter()
    with tqdm(total=len(testloader.dataset)) as pbar:
        for batch, _unused_labels in testloader:
            batch = batch.to(device)
            batch_size = batch.size(0)
            generated = sample(net, m=64, n_ch=3, im_w=32, im_h=32, K=100, device=device)
            batch_loss = net(generated).mean() - net(batch).mean()
            meter.update(batch_loss.item(), batch_size)
            # NOTE(review): bits_per_dim is applied to this contrastive loss,
            # not to an NLL; treat the displayed bpd with care.
            pbar.set_postfix(nll=meter.avg,
                             bpd=util.bits_per_dim(batch, meter.avg))
            pbar.update(batch_size)

    # Checkpoint only on improvement.
    improved = meter.avg < best_loss
    if improved:
        print('Saving...')
        checkpoint = {
            'net': net.state_dict(),
            'test_loss': meter.avg,
            'epoch': epoch,
        }
        os.makedirs('ckpts', exist_ok=True)
        torch.save(checkpoint, 'ckpts/best_ebm.pth.tar')
        best_loss = meter.avg

    # NOTE(review): 64 samples are drawn, but the grid's row length is derived
    # from num_samples; confirm the two are meant to agree.
    drawn = sample(net, m=64, n_ch=3, im_w=32, im_h=32, K=100, device=device)
    os.makedirs('ebm_samples', exist_ok=True)
    grid = torchvision.utils.make_grid(drawn,
                                       nrow=int(num_samples ** 0.5),
                                       padding=2,
                                       pad_value=255)
    torchvision.utils.save_image(grid, 'ebm_samples/epoch_{}.png'.format(epoch))

    return best_loss
Пример #2
0
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):
    """Train a flow model for one epoch.

    The model is called as ``net(x, reverse=False)`` and must return
    ``(z, sldj)``; ``loss_fn(z, sldj)`` gives the scalar loss. Gradients are
    clipped with ``util.clip_grad_norm`` when ``max_grad_norm > 0``. After each
    batch ``scheduler.step(global_step)`` is called with the step count *before*
    the batch, then the module-level ``global_step`` grows by the batch size.
    """
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as pbar:
        for images, _labels in trainloader:
            images = images.to(device)
            n = images.size(0)

            optimizer.zero_grad()
            z, sldj = net(images, reverse=False)
            batch_loss = loss_fn(z, sldj)
            meter.update(batch_loss.item(), n)
            batch_loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            pbar.set_postfix(nll=meter.avg,
                             bpd=util.bits_per_dim(images, meter.avg),
                             lr=optimizer.param_groups[0]['lr'])
            pbar.update(n)

            scheduler.step(global_step)
            global_step += n
Пример #3
0
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):
    """Train a conditional flow model for one epoch and save a checkpoint.

    ``trainloader`` yields ``(x, cond_x)`` pairs; the model is called as
    ``net(x, cond_x, reverse=False)`` and must return ``(z, sldj)``.
    Gradients are clipped when ``max_grad_norm > 0``. ``scheduler.step`` is
    called with the module-level ``global_step`` after every optimizer step.

    At the end of the epoch ``{'net', 'optimizer', 'epoch'}`` state dicts are
    saved to ``savemodel/cINN/checkpoint_<epoch>.tar``; the directory is
    created if it does not exist.
    """
    global global_step
    import os  # local import keeps this block self-contained

    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, cond_x in trainloader:
            x, cond_x = x.to(device), cond_x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, cond_x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)

    print('Saving...')
    state = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    # torch.save does not create missing directories.
    ckpt_dir = 'savemodel/cINN'
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(state, ckpt_dir + '/checkpoint_' + str(epoch) + '.tar')
Пример #4
0
def evaluate(test_loader,
             model,
             criterion,
             n_iter=-1,
             verbose=False,
             device='cuda'):
    """
    Standard evaluation loop returning bits per dimension.

    Args:
        test_loader: iterable of ``(x, target)`` batches; ``target`` is ignored.
        model: flow model called as ``model(x, reverse=False)`` that returns
            ``(z, sldj)``.
        criterion: callable ``criterion(z, sldj)`` returning a scalar loss.
        n_iter: maximum number of batches to evaluate. A non-positive value
            (the default) keeps the historical cap of 100 batches.
        verbose: unused; kept for interface compatibility.
        device: device the inputs are moved to (was hard-coded to 'cuda',
            which is still the default).

    Returns:
        ``util.bits_per_dim`` computed from the last evaluated batch and the
        running average loss over all evaluated batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    loss_meter = util.AverageMeter()

    # NOTE(review): the model is put in *train* mode although this is an
    # evaluation loop. Kept as-is because layers such as ActNorm/BatchNorm may
    # rely on it here; confirm whether model.eval() was intended.
    model.train()

    max_batches = n_iter if n_iter > 0 else 100
    bpd = None
    with torch.no_grad():
        for i, (x, _target) in enumerate(test_loader):
            # early stop
            if i >= max_batches:
                break

            x = x.to(device)

            z, sldj = model(x, reverse=False)
            loss = criterion(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            bpd = util.bits_per_dim(x, loss_meter.avg)

    if bpd is None:
        raise ValueError('evaluate() received an empty test_loader')
    return bpd
Пример #5
0
def test(epoch, net, testloader, device, loss_fn, num_samples, save_dir):
    """Evaluate a flow model, checkpoint on improvement and save a sample grid.

    The model is called as ``net(x, reverse=False)`` returning ``(z, sldj)``
    and scored with ``loss_fn(z, sldj)``. If the average test loss beats the
    module-level ``best_loss``, the weights are saved to ``save/best.pth.tar``
    and ``best_loss`` is updated. ``num_samples`` images from
    ``sample(net, num_samples, device)`` are written as a grid to
    ``<save_dir>/epoch_<epoch>.png``.
    """
    global best_loss
    net.eval()
    loss_meter = util.AverageMeter()
    # Evaluation needs no autograd graph; disabling it saves memory.
    with torch.no_grad(), tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, _ in testloader:
            x = x.to(device)
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

    # Save checkpoint
    if loss_meter.avg < best_loss:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'test_loss': loss_meter.avg,
            'epoch': epoch,
        }
        os.makedirs('save', exist_ok=True)
        torch.save(state, 'save/best.pth.tar')
        best_loss = loss_meter.avg

    # Save samples and data
    images = sample(net, num_samples, device)
    os.makedirs(save_dir, exist_ok=True)
    images_concat = torchvision.utils.make_grid(images, nrow=int(num_samples ** 0.5), padding=2, pad_value=255)
    torchvision.utils.save_image(images_concat, os.path.join(save_dir, 'epoch_{}.png'.format(epoch)))
Пример #6
0
def test(epoch, net, testloader, device, loss_fn, num_samples,
         experiment_folder):
    """Evaluate a two-head (class + domain) classifier on the test set.

    ``testloader`` yields ``(x, y, d, yd)``; ``y`` and ``d`` are 2-D target
    tensors reduced with ``argmax(dim=1)`` (presumably one-hot — confirm).
    ``net(x)`` must return ``(class_logits, domain_logits)``.

    If the running-average loss beats the module-level ``best_loss``, the
    weights are saved to ``<experiment_folder>ckpts/best.pth.tar``. Note that
    the folder is joined by plain string concatenation, so
    ``experiment_folder`` is expected to end with a path separator.

    Returns:
        ``(accuracy_class, accuracy_domain, save, kappa_class, kappa_domain)``:
        accuracies in percent over the dataset, whether a checkpoint was
        written, and Cohen's kappa for both heads over the *whole* test set
        (``None`` if the loader is empty).
    """
    global best_loss
    net.eval()
    loss_meter = util.AverageMeter()
    correct_class = 0
    correct_domain = 0
    save = False
    # Targets/predictions of every batch, so kappa covers the whole test set
    # instead of only the last batch.
    class_targets, class_preds = [], []
    domain_targets, domain_preds = [], []

    with torch.no_grad(), tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, y, d, yd in testloader:
            x, y, d, yd = x.to(device), y.to(device), d.to(device), yd.to(
                device)
            y_idx = y.argmax(dim=1)
            d_idx = d.argmax(dim=1)
            z1, z2 = net(x)
            loss2 = loss_fn(z2, d_idx)
            loss1 = loss_fn(z1, y_idx)
            loss_meter.update(loss2.item(), x.size(0))
            loss_meter.update(loss1.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
            _, pred_class = torch.max(z1, 1)
            _, pred_domain = torch.max(z2, 1)

            correct_class += pred_class.eq(y_idx).sum().item()
            correct_domain += pred_domain.eq(d_idx).sum().item()

            class_targets.append(y_idx.cpu())
            class_preds.append(pred_class.cpu())
            domain_targets.append(d_idx.cpu())
            domain_preds.append(pred_domain.cpu())

    kappa_class = None
    kappa_domain = None
    if class_targets:
        kappa_class = cohen_kappa_score(torch.cat(class_targets).numpy(),
                                        torch.cat(class_preds).numpy(),
                                        labels=None,
                                        weights=None)
        kappa_domain = cohen_kappa_score(torch.cat(domain_targets).numpy(),
                                         torch.cat(domain_preds).numpy(),
                                         labels=None,
                                         weights=None)
        print("kappa class and domain", kappa_class, kappa_domain)

    # Save checkpoint
    if loss_meter.avg < best_loss:
        save = True
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'test_loss': loss_meter.avg,
            'epoch': epoch,
        }
        os.makedirs(experiment_folder + 'ckpts', exist_ok=True)
        torch.save(state, experiment_folder + 'ckpts/best.pth.tar')
        best_loss = loss_meter.avg

    accuracy_class = correct_class * 100. / len(testloader.dataset)
    accuracy_domain = correct_domain * 100. / len(testloader.dataset)

    print('test accuracy class', accuracy_class)
    print('test accuracy domain', accuracy_domain)

    return accuracy_class, accuracy_domain, save, kappa_class, kappa_domain
Пример #7
0
def test(epoch, net, testloader, device, loss_fn, mode, experiment_folder):
    """Evaluate a single-head classifier against the class targets ``y``.

    ``testloader`` yields ``(x, y, d, yd)``; only ``x`` and ``y`` are used and
    ``y`` is reduced with ``argmax(dim=1)``. If the running-average loss beats
    the module-level ``best_loss``, the weights are saved under a folder named
    from the *global* ``args.mode``/``args.temperature``.

    Returns:
        ``(accuracy, update)``: accuracy in percent over the dataset and whether
        a checkpoint was written.
    """
    global best_loss
    net.eval()
    loss_meter = util.AverageMeter()
    correct = 0
    # NOTE(review): these arrays are never filled and are printed as zeros.
    correct_array = torch.zeros(10)
    wrong_array = torch.zeros(10)
    update = False

    with torch.no_grad(), tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, y, d, yd in testloader:
            x, y = x.to(device), y.to(device)
            labels = y.argmax(dim=1)
            z = net(x)
            loss = loss_fn(z, labels)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
            _, pred = torch.max(z, 1)
            correct += pred.eq(labels).sum().item()

    # Save checkpoint
    if loss_meter.avg < best_loss:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'test_loss': loss_meter.avg,
            'epoch': epoch,
        }
        # NOTE(review): the folder name uses the global ``args`` rather than
        # the ``mode`` parameter.
        ckpt_dir = (experiment_folder + args.mode + "_" +
                    str(args.temperature) + "_" + 'ckpts')
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, ckpt_dir + '/best.pth.tar')
        best_loss = loss_meter.avg
        update = True

    accuracy = correct * 100. / len(testloader.dataset)
    print(correct)
    print('val accuracy', accuracy)
    print(correct_array)
    print(wrong_array)

    return accuracy, update
Пример #8
0
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm):
    """Train a flow model for one epoch and print the epoch's loss and bpd.

    The model is called as ``net(x, reverse=False)`` returning ``(z, sldj)``
    and trained on ``loss_fn(z, sldj)``. Gradients are clipped with
    ``util.clip_grad_norm`` only when ``max_grad_norm > 0``. A non-positive
    value means "no clipping", the same convention the other training loops
    in this file use.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()

            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

        print('\ntrain loss = ', loss_meter.avg)
        print('train pbd = ', util.bits_per_dim(x, loss_meter.avg), '\n')
Пример #9
0
def test(epoch, net, testloader, device, loss_fn, mode='color'):
    """Evaluate a conditional flow, checkpoint on improvement and save image grids.

    ``testloader`` yields ``(x, x_cond)``; the model is called as
    ``net(x, x_cond, reverse=False)`` returning ``(z, sldj)``. If the average
    loss beats the module-level ``best_loss``, the weights are saved to
    ``ckpts/best.pth.tar``.

    Then one batch is taken from a fresh iterator over ``testloader`` and three
    grids are written:
      - samples generated from its conditioning images: ``samples/epoch_<epoch>.png``;
      - the originals: ``ref_pics/origin_<epoch>.png``;
      - the conditioning images: ``ref_pics/gray_<epoch>.png``.
    In ``'sketch'`` mode the saved conditioning grid is inverted (nonzero -> 0,
    zero -> 1); this happens after sampling, so it does not affect the samples.
    """
    global best_loss
    net.eval()
    loss_meter = util.AverageMeter()

    # Evaluation needs no autograd graph; disabling it saves memory.
    with torch.no_grad(), tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, x_cond in testloader:
            x, x_cond = x.to(device), x_cond.to(device)
            z, sldj = net(x, x_cond, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

    # Save checkpoint
    if loss_meter.avg < best_loss:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'test_loss': loss_meter.avg,
            'epoch': epoch,
        }
        os.makedirs('ckpts', exist_ok=True)
        torch.save(state, 'ckpts/best.pth.tar')
        best_loss = loss_meter.avg

    origin_img, gray_img = next(iter(testloader))
    nrow = int(gray_img.shape[0]**0.5)
    # Save samples and data
    images = sample(net, gray_img, device)
    os.makedirs('samples', exist_ok=True)
    os.makedirs('ref_pics', exist_ok=True)
    if mode == 'sketch':
        gray_img = (~gray_img.type(torch.bool)).type(torch.float)
    images_concat = torchvision.utils.make_grid(images,
                                                nrow=nrow,
                                                padding=2,
                                                pad_value=255)
    origin_concat = torchvision.utils.make_grid(origin_img,
                                                nrow=nrow,
                                                padding=2,
                                                pad_value=255)
    gray_concat = torchvision.utils.make_grid(gray_img,
                                              nrow=nrow,
                                              padding=2,
                                              pad_value=255)

    torchvision.utils.save_image(images_concat,
                                 'samples/epoch_{}.png'.format(epoch))
    torchvision.utils.save_image(origin_concat,
                                 'ref_pics/origin_{}.png'.format(epoch))
    torchvision.utils.save_image(gray_concat,
                                 'ref_pics/gray_{}.png'.format(epoch))
Пример #10
0
def test(epoch, net, testloader, device, loss_fn, num_samples, label):
    """Evaluate a flow model on one split and report results tagged with ``label``.

    The model is called as ``net(x, reverse=False)`` returning ``(z, sldj)``,
    scored with ``loss_fn(z, sldj)`` under ``torch.no_grad()``. The split's
    loss and bpd are printed with ``label`` as a prefix.

    Only when ``label == 'val'`` and the loss beats the module-level
    ``best_loss`` are the weights saved to ``ckpts/best.pth.tar``. A grid of
    ``num_samples`` generated images is written to
    ``samples/epoch_<epoch>.png`` on every call, whatever the label.
    """
    global best_loss
    net.eval()
    meter = util.AverageMeter()
    with torch.no_grad():
        with tqdm(total=len(testloader.dataset)) as pbar:
            for batch, _labels in testloader:
                batch = batch.to(device)
                z, sldj = net(batch, reverse=False)
                batch_loss = loss_fn(z, sldj)
                meter.update(batch_loss.item(), batch.size(0))
                pbar.set_postfix(loss=meter.avg,
                                 bpd=util.bits_per_dim(batch, meter.avg))
                pbar.update(batch.size(0))
            # Uses the last batch's shape for the bpd conversion.
            print('\n' + label + ' loss = ', meter.avg)
            print(label + ' pbd = ', util.bits_per_dim(batch, meter.avg), '\n')

    # Checkpoint only on an improved validation loss.
    if label == 'val' and meter.avg < best_loss:
        print('Saving...')
        checkpoint = {
            'net': net.state_dict(),
            'val_loss': meter.avg,
            'epoch': epoch,
        }
        os.makedirs('ckpts', exist_ok=True)
        torch.save(checkpoint, 'ckpts/best.pth.tar')
        best_loss = meter.avg

    generated = sample(net, num_samples, device)
    os.makedirs('samples', exist_ok=True)
    grid = torchvision.utils.make_grid(generated,
                                       nrow=int(num_samples ** 0.5),
                                       padding=2,
                                       pad_value=255)
    torchvision.utils.save_image(grid, 'samples/epoch_{}.png'.format(epoch))
Пример #11
0
def test(epoch, net, testloader, device, loss_fn, num_samples, mode,
         experiment_folder, full):
    """Evaluate a binary classifier and report overall and per-class accuracy.

    ``testloader`` yields ``(x, y, d, yd)``; only ``x`` and ``y`` are used and
    ``y`` is reduced with ``argmax(dim=1)``. Class 0 is reported as
    "Uninfected" and every other class as "Parasitized".

    ``num_samples``, ``mode`` and ``experiment_folder`` are unused; they are
    kept for interface compatibility.

    Args:
        full: when true, ROC AUC is computed once over the whole test set.
            Previously it was recomputed per batch and only the last batch's
            value was returned.

    Returns:
        ``(accuracy, auc)``: accuracy in percent over the dataset, and the
        AUC (``None`` unless ``full`` and the loader is non-empty).
    """
    net.eval()
    loss_meter = util.AverageMeter()
    correct = 0
    correct_0 = 0
    correct_1 = 0
    all_labels, all_preds = [], []

    with torch.no_grad(), tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, y, d, yd in testloader:
            x, y = x.to(device), y.to(device)
            labels = y.argmax(dim=1)
            z = net(x)

            loss = loss_fn(z, labels)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
            _, pred = torch.max(z, 1)

            # Vectorized per-class counts (replaces a per-image .item() loop).
            matches = pred.eq(labels)
            correct += matches.sum().item()
            correct_0 += (matches & (labels == 0)).sum().item()
            correct_1 += (matches & (labels != 0)).sum().item()
            if full:
                all_labels.append(labels.cpu())
                all_preds.append(pred.cpu())

    auc = None
    if full and all_labels:
        # NOTE(review): as before, AUC is computed from hard predictions;
        # class-1 probabilities would give a more informative curve.
        auc = roc_auc_score(torch.cat(all_labels).numpy(),
                            torch.cat(all_preds).numpy(),
                            labels=[0, 1])

    print("total", len(testloader.dataset))
    print(correct)
    accuracy_from_best_model = correct * 100. / len(testloader.dataset)
    print(accuracy_from_best_model)

    print(correct_0)
    print(correct_1)
    # NOTE(review): assumes the test set is split evenly between the classes.
    total_per_class = len(testloader.dataset) / 2
    print("Uninfected accuracy", correct_0 * 100. / total_per_class)
    print("Parasitized accuracy", correct_1 * 100. / total_per_class)
    print("testing on", len(testloader.dataset))
    return accuracy_from_best_model, auc
Пример #12
0
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    """Train a two-head (class + domain) classifier for one epoch.

    ``trainloader`` yields ``(x, y, d, yd)``. ``y`` and ``d`` are 2-D target
    tensors reduced with ``argmax(dim=1)`` (presumably one-hot — confirm).
    ``net(x)`` must return ``(class_logits, domain_logits)``. The optimised
    loss is the sum of both head losses. The progress bar shows the running
    mean of the individual losses. ``scheduler.step`` receives the
    module-level ``global_step``, which is advanced by each batch size.

    Returns:
        ``(accuracy_class, accuracy_domain)`` in percent over the dataset.
    """
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    meter = util.AverageMeter()
    class_hits = 0
    domain_hits = 0

    with tqdm(total=len(trainloader.dataset)) as pbar:
        for x, y, d, yd in trainloader:
            x, y, d, yd = (t.to(device) for t in (x, y, d, yd))
            optimizer.zero_grad()
            class_logits, domain_logits = net(x)
            domain_loss = loss_fn(domain_logits, d.argmax(dim=1))
            class_loss = loss_fn(class_logits, y.argmax(dim=1))
            meter.update(domain_loss.item(), x.size(0))
            meter.update(class_loss.item(), x.size(0))
            total_loss = class_loss + domain_loss
            total_loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            pbar.set_postfix(nll=meter.avg,
                             bpd=util.bits_per_dim(x, meter.avg),
                             lr=optimizer.param_groups[0]['lr'])
            pbar.update(x.size(0))
            global_step += x.size(0)

            _, class_pred = torch.max(class_logits, 1)
            _, domain_pred = torch.max(domain_logits, 1)
            class_hits += class_pred.eq(y.argmax(dim=1)).sum().item()
            domain_hits += domain_pred.eq(d.argmax(dim=1)).sum().item()

    dataset_size = len(trainloader.dataset)
    accuracy_class = class_hits * 100. / dataset_size
    accuracy_domain = domain_hits * 100. / dataset_size

    print('train accuracy class', accuracy_class)
    print('train accuracy domain', accuracy_domain)

    return accuracy_class, accuracy_domain
Пример #13
0
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm, mode):
    """Train a single-head classifier for one epoch on labels or domains.

    ``trainloader`` yields ``(x, y, d, yd)``. When ``mode == 'domain'`` the
    model is trained and scored against ``d.argmax(dim=1)``; for any other
    mode ``y.argmax(dim=1)`` is used. Previously the loss used ``y`` for every
    mode except 'domain', while the accuracy used ``d`` for every mode except
    'label', so other mode values mixed the two.

    Gradients are clipped when ``max_grad_norm > 0``. ``scheduler.step``
    receives the module-level ``global_step``, which is advanced by each batch
    size.

    Returns:
        Training accuracy in percent over ``len(trainloader.dataset)``.
    """
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    correct = 0
    use_domain = mode == 'domain'

    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y, d, yd in trainloader:
            x, y, d = x.to(device), y.to(device), d.to(device)
            # Same target for the loss and for the accuracy.
            target = (d if use_domain else y).argmax(dim=1)
            optimizer.zero_grad()
            z = net(x)
            loss = loss_fn(z, target)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
            _, pred = torch.max(z, 1)
            correct += pred.eq(target).sum().item()

    accuracy = correct * 100. / len(trainloader.dataset)

    print('train accuracy', accuracy)

    return accuracy
Пример #14
0
def finetune_centroids(train_loader,
                       student,
                       teacher,
                       criterion,
                       optimizer,
                       epoch=0,
                       n_iter=-1,
                       verbose=False):
    """
    Student/teacher distillation training loop.

    The student is trained to match the teacher's softmax distribution by
    minimising ``F.kl_div(log_softmax(student), softmax(teacher),
    reduction='batchmean')``. Gradients flow only into the student.

    Args:
        train_loader: iterable of ``(input, target)`` batches; ``target`` is
            ignored. Inputs are moved to the GPU with ``.cuda()``.
        student: model whose forward returns a 2-tuple; the first element is
            used as logits (assumed from usage — confirm).
        teacher: same calling convention as ``student``.
        criterion: unused; kept for interface compatibility.
        optimizer: optimizer over the student's parameters.
        epoch: unused; kept for interface compatibility.
        n_iter: maximum number of batches. A negative value (the default)
            processes the whole loader. Previously the default stopped before
            the first batch.
        verbose: unused; kept for interface compatibility.

    Returns:
        Running average of the distillation loss over the processed batches.

    Remarks:
        - The student has to be in train() mode as this function will not
          automatically switch to it for finetuning purposes
    """
    import torch  # local import: needed for torch.no_grad below

    losses = util.AverageMeter()

    for i, (inputs, _target) in enumerate(train_loader):
        # early stop (a negative n_iter means "no limit")
        if 0 <= n_iter <= i:
            break

        print('Epoch {}starts'.format(i))

        # cuda
        inputs = inputs.cuda()

        # No gradients are propagated into the teacher.
        with torch.no_grad():
            teacher_logits, _ = teacher(inputs)
        student_logits, _ = student(inputs)
        # F.kl_div expects log-probabilities as input and (by default)
        # probabilities as target.
        student_log_probs = F.log_softmax(student_logits, dim=1)
        teacher_probs = F.softmax(teacher_logits, dim=1)
        loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        losses.update(loss.item(), inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses.avg
Пример #15
0
def train_full(epoch, net, trainloader, device, optimizer, scheduler):
    """Train the model for one epoch with a sample-vs-data contrastive loss.

    For each batch, 64 images are drawn with ``sample`` and the loss
    ``mean(net(samples)) - mean(net(batch))`` is minimised. ``scheduler`` is
    optional; when given, ``scheduler.step()`` is called after every optimizer
    step.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            x_q = sample(net, m=64, n_ch=3, im_w=32, im_h=32, K=100, device=device)
            # Clear gradients *after* sampling. If ``sample`` back-propagates
            # through ``net`` (not visible here — confirm), the parameter
            # gradients it accumulates must not leak into this update.
            optimizer.zero_grad()
            loss = net(x_q).mean() - net(x).mean()
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()

            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
Пример #16
0
def train(epochs, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm):
    """Train a conditional flow for ``epochs`` epochs, evaluating after each one.

    Each epoch trains on ``trainloader``, which yields ``(x, cond_x)``. Then
    images are generated with ``sample`` for the conditions in ``test.json``
    and ``new_test.json`` and scored with ``EvaluationModel().eval``.
    Whenever a score improves, the current weights are saved under
    ``weightings/test`` or ``weightings/new_test``. Every epoch's generated
    images are saved under ``results/test`` and ``results/new_test``. The loss
    and scores are logged to ``wandb``.

    Args:
        epochs: number of epochs to run (numbered from 1).
        net: model called as ``net(x, cond_x, reverse=False)`` returning
            ``(z, sldj)``.
        trainloader: training data loader.
        device: device for inputs and test conditions.
        optimizer: optimizer for ``net``.
        scheduler: currently unused (its step call is disabled).
        loss_fn: callable ``loss_fn(z, sldj)`` returning a scalar loss.
        max_grad_norm: gradients are clipped when this is positive.
    """
    global global_step

    # NOTE(review): the meter is not reset per epoch, so the displayed nll is
    # averaged over all epochs so far.
    loss_meter = util.AverageMeter()
    evaluator = EvaluationModel()
    test_conditions = get_test_conditions(os.path.join('test.json')).to(device)
    new_test_conditions = get_test_conditions(
        os.path.join('new_test.json')).to(device)
    best_score = 0
    new_best_score = 0

    # torch.save / save_image do not create missing directories.
    for out_dir in ('weightings/test', 'weightings/new_test',
                    'results/test', 'results/new_test'):
        os.makedirs(out_dir, exist_ok=True)

    for epoch in range(1, epochs + 1):
        print('\nEpoch: ', epoch)
        # Evaluation below switches to eval mode; switch back every epoch.
        net.train()
        with tqdm(total=len(trainloader.dataset)) as progress_bar:
            for x, cond_x in trainloader:
                x, cond_x = x.to(device, dtype=torch.float), cond_x.to(
                    device, dtype=torch.float)
                optimizer.zero_grad()
                z, sldj = net(x, cond_x, reverse=False)
                loss = loss_fn(z, sldj)
                wandb.log({'loss': loss})
                loss_meter.update(loss.item(), x.size(0))
                loss.backward()
                if max_grad_norm > 0:
                    util.clip_grad_norm(optimizer, max_grad_norm)
                optimizer.step()

                progress_bar.set_postfix(nll=loss_meter.avg,
                                         bpd=util.bits_per_dim(
                                             x, loss_meter.avg),
                                         lr=optimizer.param_groups[0]['lr'])
                progress_bar.update(x.size(0))
                global_step += x.size(0)

        net.eval()

        # Score on the original test conditions.
        with torch.no_grad():
            gen_imgs = sample(net, test_conditions, device)
        score = evaluator.eval(gen_imgs, test_conditions)
        wandb.log({'score': score})
        if score > best_score:
            best_score = score
            torch.save(
                net.state_dict(),
                os.path.join('weightings/test',
                             f'epoch{epoch}_score{score:.2f}.pt'))

        # Score on the new test conditions (tracked independently).
        with torch.no_grad():
            new_gen_imgs = sample(net, new_test_conditions, device)
        new_score = evaluator.eval(new_gen_imgs, new_test_conditions)
        wandb.log({'new_score': new_score})
        if new_score > new_best_score:
            new_best_score = new_score
            torch.save(
                net.state_dict(),
                os.path.join('weightings/new_test',
                             f'epoch{epoch}_score{new_score:.2f}.pt'))

        save_image(gen_imgs,
                   os.path.join('results/test', f'epoch{epoch}.png'),
                   nrow=8,
                   normalize=True)
        save_image(new_gen_imgs,
                   os.path.join('results/new_test', f'epoch{epoch}.png'),
                   nrow=8,
                   normalize=True)
Пример #17
0
def test(epoch, net, testloader, device, loss_fn, num_samples, in_channels, base_path):
    """Evaluate a flow model, checkpoint on improvement and record sample statistics.

    The model is called as ``net(x, reverse=False)`` returning ``(z, sldj)``
    and scored with ``loss_fn(x, z, sldj)`` under ``torch.no_grad()``.
    ``base_path`` is combined with ``/`` and has ``.mkdir``/``.open`` called on
    its children, so it is expected to be a ``pathlib.Path``.

    Files written under ``base_path``:
      - ``ckpts/best.pth.tar`` when the average loss beats the global ``best_loss``;
      - ``hists.pkl``: the global ``hists`` list, after appending a 100-bin
        ``np.histogram`` of this epoch's sample values;
      - ``samples/epoch_<epoch>.png``: a grid of generated samples;
      - ``mean_conds.npy``: the global ``mean_conds`` as an array. Nothing in
        this function appends to it.
    """
    global best_loss
    global mean_conds
    global hists
    net.eval()
    meter = util.AverageMeter()
    with tqdm(total=len(testloader.dataset)) as pbar:
        for batch, _labels in testloader:
            batch = batch.to(device)
            with torch.no_grad():
                z, sldj = net(batch, reverse=False)
                batch_loss = loss_fn(batch, z, sldj)
                meter.update(batch_loss.item(), batch.size(0))
                pbar.set_postfix(loss=meter.avg,
                                 bpd=util.bits_per_dim(batch, meter.avg))
                pbar.update(batch.size(0))

    # Checkpoint only on improvement.
    if meter.avg < best_loss:
        print('Saving...')
        checkpoint = {
            'net': net.state_dict(),
            'test_loss': meter.avg,
            'epoch': epoch,
        }
        ckpt_dir = base_path / 'ckpts'
        ckpt_dir.mkdir(exist_ok=True)
        torch.save(checkpoint, ckpt_dir / 'best.pth.tar')
        best_loss = meter.avg

    generated = sample(net, num_samples, device, in_channels)
    # Keep only the first half of the channels for 2- or 6-channel outputs.
    n_channels = generated.shape[1]
    if n_channels == 2:
        generated = generated[:, :1, :, :]
    elif n_channels == 6:
        generated = generated[:, :3, :, :]

    pixel_values = generated.detach().cpu().numpy().flatten()
    hists.append(np.histogram(pixel_values, bins=100))
    with (base_path / 'hists.pkl').open('wb') as fh:
        pickle.dump(hists, fh)

    samples_dir = base_path / 'samples'
    samples_dir.mkdir(exist_ok=True)
    grid = torchvision.utils.make_grid(generated,
                                       nrow=int(num_samples ** 0.5),
                                       padding=2,
                                       pad_value=255)
    torchvision.utils.save_image(grid, samples_dir / f'epoch_{epoch}.png')
    np.save(base_path / 'mean_conds.npy', np.array(mean_conds))
def test(epoch, net, testloader, device, loss_fn, num_samples, mode,
         experiment_folder):
    """Evaluate a single-head classifier on labels or domains.

    ``testloader`` yields ``(x, y, d, yd)``. When ``mode == 'domain'`` the
    model is scored against ``d.argmax(dim=1)``; otherwise against
    ``y.argmax(dim=1)``. The loss and the accuracy now always use the same
    target; previously they disagreed for modes other than 'label'/'domain'.
    In ``'label'`` mode, per-class accuracy is also printed, with class 0 shown
    as "Uninfected" and every other class as "Parasitized".

    ``num_samples`` and ``experiment_folder`` are unused; they are kept for
    interface compatibility.

    Returns:
        ``(accuracy, kappa)``: accuracy in percent over the dataset, and
        Cohen's kappa over the *whole* test set (``None`` for an empty loader).
        Previously kappa was that of the last batch only.
    """
    net.eval()
    loss_meter = util.AverageMeter()
    correct = 0
    correct_0 = 0
    correct_1 = 0
    use_domain = mode == 'domain'
    all_targets, all_preds = [], []

    with torch.no_grad(), tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, y, d, yd in testloader:
            x, y, d = x.to(device), y.to(device), d.to(device)
            target = (d if use_domain else y).argmax(dim=1)
            z = net(x)
            loss = loss_fn(z, target)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
            _, pred = torch.max(z, 1)

            matches = pred.eq(target)
            correct += matches.sum().item()
            if mode == 'label':
                # Vectorized per-class counts (replaces a per-image loop).
                correct_0 += (matches & (target == 0)).sum().item()
                correct_1 += (matches & (target != 0)).sum().item()
            all_targets.append(target.cpu())
            all_preds.append(pred.cpu())

    kappa = None
    if all_targets:
        kappa = cohen_kappa_score(torch.cat(all_targets).numpy(),
                                  torch.cat(all_preds).numpy(),
                                  labels=None,
                                  weights=None)
        print("kappa", kappa)

    print("total", len(testloader.dataset))
    print(correct)
    accuracy_from_best_model = correct * 100. / len(testloader.dataset)
    print(accuracy_from_best_model)

    if mode == 'label':
        print(correct_0)
        print(correct_1)
        # NOTE(review): assumes the test set is split evenly between the classes.
        total_per_class = len(testloader.dataset) / 2
        print("Uninfected accuracy", correct_0 * 100. / total_per_class)
        print("Parasitized accuracy", correct_1 * 100. / total_per_class)
    return accuracy_from_best_model, kappa
Пример #19
0
    #testset = imagenet_val(transform)
    testloader = data.DataLoader(testset,
                                 batch_size=64,
                                 shuffle=False,
                                 num_workers=8)
    loss_fn = util.NLLLoss().to(device)
    loss_meter = util.AverageMeter()
    bpd_sum = 0
    n = 0
    for x, _ in testloader:
        #x = x.to(device)
        z, sldj = net(x, reverse=False)
        loss = loss_fn(z, sldj)
        loss_meter.update(loss.item(), x.size(0))
        n += 1
        bpd_sum += util.bits_per_dim(x, loss_meter.avg)
        #print(util.bits_per_dim(x, loss_meter.avg))
        #print(bpd_sum/n)
    print(bpd_sum / n)

for i in range(3):
    net = Glow(num_channels=512, num_levels=3, num_steps=16)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'
    #net.to(device)
    if i == 0:
        net.load_state_dict({
            k.replace('module.', ''): v
            for k, v in torch.load("ckpts/-2.pth.tar")['net'].items()
        })
    if i == 1: