Example #1
def train(train_dloader_l,
          train_dloader_u,
          model,
          criterion_l,
          criterion_u,
          optimizer,
          epoch,
          writer,
          alpha,
          zca_mean=None,
          zca_components=None):
    # meters for timing and loss statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_supervised = AverageMeter()
    losses_unsupervised = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    epoch_length = len(train_dloader_u)
    i = 0
    for (image_l, label_l), (image_u, _) in zip(cycle(train_dloader_l),
                                                train_dloader_u):
        if image_l.size(0) != image_u.size(0):
            bt_size = min(image_l.size(0), image_u.size(0))
            image_l = image_l[0:bt_size]
            image_u = image_u[0:bt_size]
            label_l = label_l[0:bt_size]
        else:
            bt_size = image_l.size(0)
        data_time.update(time.time() - end)
        image_l = image_l.float().cuda()
        image_u = image_u.float().cuda()
        label_l = label_l.long().cuda()
        if args.zca:
            image_l = apply_zca(image_l, zca_mean, zca_components)
            image_u = apply_zca(image_u, zca_mean, zca_components)
        if args.mixup:
            mixed_image_l, label_a, label_b, lam = mixup_data(
                image_l, label_l, args.mas)
            cls_result_l = model(mixed_image_l)
            loss_supervised = mixup_criterion(criterion_l, cls_result_l,
                                              label_a, label_b, lam)
            # label_u_approx is computed without gradient tracking
            with torch.no_grad():
                label_u_approx = torch.softmax(model(image_u), dim=1)
            mixed_image_u, label_a_approx, label_b_approx, lam = mixup_data(
                image_u, label_u_approx, args.mau)
            cls_result_u = model(mixed_image_u)
            cls_result_u = torch.log_softmax(cls_result_u, dim=1)
            label_u_approx_mixup = lam * label_a_approx + (
                1 - lam) * label_b_approx
            loss_unsupervised = -1 * torch.mean(
                torch.sum(label_u_approx_mixup * cls_result_u, dim=1))
            loss = loss_supervised + alpha * loss_unsupervised
        elif args.manifold_mixup:
            cls_result_l, label_a, label_b, lam = model(
                image_l,
                mixup_alpha=args.mas,
                label=label_l,
                manifold_mixup=True,
                mixup_layer_list=args.mll)
            loss_supervised = mixup_criterion(criterion_l, cls_result_l,
                                              label_a, label_b, lam)
            # label_u_approx is computed without gradient tracking
            with torch.no_grad():
                label_u_approx = torch.softmax(model(image_u), dim=1)
            cls_result_u, label_a_approx, label_b_approx, lam = model(
                image_u,
                mixup_alpha=args.mas,
                label=label_u_approx,
                manifold_mixup=True,
                mixup_layer_list=args.mll)
            cls_result_u = torch.softmax(cls_result_u, dim=1)
            label_u_approx_mixup = lam * label_a_approx + (
                1 - lam) * label_b_approx
            loss_unsupervised = criterion_u(cls_result_u, label_u_approx_mixup)
            loss = loss_supervised + 10 * alpha * loss_unsupervised
        else:
            cls_result_l = model(image_l)
            loss = criterion_l(cls_result_l, label_l)
            loss_supervised = loss.detach()
            loss_unsupervised = torch.zeros(loss.size())
        loss.backward()
        losses.update(float(loss.item()), bt_size)
        losses_supervised.update(float(loss_supervised.item()), bt_size)
        losses_unsupervised.update(float(loss_unsupervised.item()), bt_size)
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                         'Cls Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})\t' \
                         'Regularization Loss {reg_loss.val:.4f} ({reg_loss.avg:.4f})\t' \
                         'Total Loss {total_loss.val:.4f} ({total_loss.avg:.4f})\t'.format(
                epoch, i + 1, epoch_length, batch_time=batch_time, data_time=data_time,
                cls_loss=losses_supervised, reg_loss=losses_unsupervised, total_loss=losses)
            print(train_text)
        i += 1
    writer.add_scalar(tag="Train/cls_loss",
                      scalar_value=losses_supervised.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/reg_loss",
                      scalar_value=losses_unsupervised.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/total_loss",
                      scalar_value=losses.avg,
                      global_step=epoch + 1)
    return losses.avg
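Example #1 assumes several helpers that are not shown: AverageMeter (the running-average tracker from the canonical PyTorch ImageNet example), cycle (itertools.cycle, so the smaller labeled loader repeats alongside the unlabeled one), and mixup_data/mixup_criterion in the style of the mixup reference implementation (Zhang et al., 2018). A minimal sketch of these assumed helpers; the projects' actual versions may differ:

import numpy as np
import torch


class AverageMeter(object):
    """Tracks the latest value, running sum, count, and average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def mixup_data(x, y, alpha=1.0):
    """Mix a batch with a shuffled copy of itself (mixup, Zhang et al. 2018)."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Convex combination of the loss against both label sets."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)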
Example #2
def test(valid_dloader,
         test_dloader,
         model,
         criterion,
         epoch,
         writer,
         num_classes,
         zca_mean=None,
         zca_components=None):
    model.eval()
    # evaluate on the validation set
    losses = AverageMeter()
    all_score = []
    all_label = []
    for i, (image, label) in enumerate(valid_dloader):
        image = image.float().cuda()
        if args.zca:
            image = apply_zca(image, zca_mean, zca_components)
        label = label.long().cuda()
        with torch.no_grad():
            cls_result = model(image)
        label_onehot = torch.zeros(label.size(0), num_classes).cuda().scatter_(
            1, label.view(-1, 1), 1)
        loss = criterion(cls_result, label)
        losses.update(float(loss.item()), image.size(0))
        # collect per-batch scores and one-hot labels
        all_score.append(torch.softmax(cls_result, dim=1))
        all_label.append(label_onehot)
    writer.add_scalar(tag="Valid/cls_loss",
                      scalar_value=losses.avg,
                      global_step=epoch + 1)
    all_score = torch.cat(all_score, dim=0).detach()
    all_label = torch.cat(all_label, dim=0).detach()
    _, y_true = torch.topk(all_label, k=1, dim=1)
    _, y_pred = torch.topk(all_score, k=5, dim=1)
    # calculate accuracy by hand
    top_1_accuracy = float(
        torch.sum(y_true == y_pred[:, :1]).item()) / y_true.size(0)
    top_5_accuracy = float(torch.sum(y_true == y_pred).item()) / y_true.size(0)
    writer.add_scalar(tag="Valid/top 1 accuracy",
                      scalar_value=top_1_accuracy,
                      global_step=epoch + 1)
    if args.dataset == "Cifar100":
        writer.add_scalar(tag="Valid/top 5 accuracy",
                          scalar_value=top_5_accuracy,
                          global_step=epoch + 1)
    # evaluate on the test set
    losses = AverageMeter()
    all_score = []
    all_label = []
    for i, (image, label) in enumerate(test_dloader):
        image = image.float().cuda()
        if args.zca:
            image = apply_zca(image, zca_mean, zca_components)
        label = label.long().cuda()
        with torch.no_grad():
            cls_result = model(image)
        label_onehot = torch.zeros(label.size(0), num_classes).cuda().scatter_(
            1, label.view(-1, 1), 1)
        loss = criterion(cls_result, label)
        losses.update(float(loss.item()), image.size(0))
        # collect per-batch scores and one-hot labels
        all_score.append(torch.softmax(cls_result, dim=1))
        all_label.append(label_onehot)
    writer.add_scalar(tag="Test/cls_loss",
                      scalar_value=losses.avg,
                      global_step=epoch + 1)
    all_score = torch.cat(all_score, dim=0).detach()
    all_label = torch.cat(all_label, dim=0).detach()
    _, y_true = torch.topk(all_label, k=1, dim=1)
    _, y_pred = torch.topk(all_score, k=5, dim=1)
    # calculate accuracy by hand
    top_1_accuracy = float(
        torch.sum(y_true == y_pred[:, :1]).item()) / y_true.size(0)
    top_5_accuracy = float(torch.sum(y_true == y_pred).item()) / y_true.size(0)
    writer.add_scalar(tag="Test/top 1 accuracy",
                      scalar_value=top_1_accuracy,
                      global_step=epoch + 1)
    if args.dataset == "Cifar100":
        writer.add_scalar(tag="Test/top 5 accuracy",
                          scalar_value=top_5_accuracy,
                          global_step=epoch + 1)

    return losses.avg
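Note how top-1/top-5 accuracy is computed above: y_true has shape (N, 1) and y_pred has shape (N, 5), so y_true == y_pred broadcasts and marks wherever the true class appears among the top five. The same trick, distilled into a self-contained helper (names are illustrative, and it skips the one-hot detour by taking integer labels directly):

import torch


def topk_accuracy(scores, labels, ks=(1, 5)):
    """scores: (N, C) class scores; labels: (N,) integer class indices."""
    _, pred = torch.topk(scores, k=max(ks), dim=1)  # (N, max_k)
    hits = pred == labels.view(-1, 1)               # broadcast to (N, max_k)
    return [float(hits[:, :k].sum().item()) / labels.size(0) for k in ks]


# e.g. top_1, top_5 = topk_accuracy(all_score, label_indices)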
Example #3
def test(test_dloader, model, elbo_criterion, epoch, writer, discrete_latent_dim):
    continuous_kl_losses = AverageMeter()
    discrete_kl_losses = AverageMeter()
    mse_losses = AverageMeter()
    elbo_losses = AverageMeter()
    model.eval()
    all_score = []
    all_label = []

    for i, (image, label) in enumerate(test_dloader):
        image = image.float().cuda()
        label = label.long().cuda()
        label_onehot = torch.zeros(label.size(0), discrete_latent_dim).cuda().scatter_(1, label.view(-1, 1), 1)
        batch_size = image.size(0)
        with torch.no_grad():
            reconstruction, norm_mean, norm_log_sigma, disc_log_alpha, *_ = model(image)
        reconstruct_loss, continuous_kl_loss, discrete_kl_loss = elbo_criterion(image, reconstruction, norm_mean,
                                                                                norm_log_sigma, disc_log_alpha)
        mse_loss = F.mse_loss(torch.sigmoid(reconstruction.detach()), image.detach(),
                              reduction="sum") / (
                           2 * image.size(0) * (args.x_sigma ** 2))
        mse_losses.update(float(mse_loss), image.size(0))
        all_score.append(torch.exp(disc_log_alpha))
        all_label.append(label_onehot)
        continuous_kl_losses.update(float(continuous_kl_loss.item()), batch_size)
        discrete_kl_losses.update(float(discrete_kl_loss.item()), batch_size)
        elbo_losses.update(float(mse_loss + 0.01*(continuous_kl_loss + discrete_kl_loss)), image.size(0))

    writer.add_scalar(tag="Test/KL(q(z|X)||p(z))", scalar_value=continuous_kl_losses.avg, global_step=epoch + 1)
    writer.add_scalar(tag="Test/KL(q(y|X)||p(y))", scalar_value=discrete_kl_losses.avg, global_step=epoch + 1)
    writer.add_scalar(tag="Test/log(p(X|z,y))", scalar_value=mse_losses.avg, global_step=epoch + 1)
    writer.add_scalar(tag="Test/ELBO", scalar_value=elbo_losses.avg, global_step=epoch + 1)
    all_score = torch.cat(all_score, dim=0).detach()
    all_label = torch.cat(all_label, dim=0).detach()
    _, y_true = torch.topk(all_label, k=1, dim=1)
    _, y_pred = torch.topk(all_score, k=5, dim=1)
    # calculate accuracy by hand
    test_top_1_accuracy = float(torch.sum(y_true == y_pred[:, :1]).item()) / y_true.size(0)
    test_top_5_accuracy = float(torch.sum(y_true == y_pred).item()) / y_true.size(0)
    writer.add_scalar(tag="Test/top1 accuracy", scalar_value=test_top_1_accuracy, global_step=epoch + 1)
    if args.dataset == "Cifar100":
        writer.add_scalar(tag="Test/top 5 accuracy", scalar_value=test_top_5_accuracy, global_step=epoch + 1)
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(reconstruction[:4, ...]), nrow=2)
        writer.add_image(tag="Test/Raw_Image", img_tensor=image, global_step=epoch + 1)
        writer.add_image(tag="Test/Reconstruct_Image", img_tensor=reconstruct_image, global_step=epoch + 1)

    return test_top_1_accuracy, test_top_5_accuracy
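elbo_criterion is not shown. From the call signature, the model has a Gaussian continuous latent (norm_mean, norm_log_sigma) and a categorical discrete latent (disc_log_alpha holds log class probabilities, hence the torch.exp above). One common batch-averaged formulation, offered only as a hedged sketch; the project may instead use a Gaussian likelihood scaled by args.x_sigma rather than BCE:

import math

import torch
import torch.nn.functional as F


def elbo_criterion(image, reconstruction, norm_mean, norm_log_sigma,
                   disc_log_alpha):
    """Hedged sketch: returns (reconstruct_loss, continuous_kl, discrete_kl),
    each averaged over the batch. Assumes image values lie in [0, 1]."""
    batch_size = image.size(0)
    # Bernoulli reconstruction likelihood on logits
    reconstruct_loss = F.binary_cross_entropy_with_logits(
        reconstruction, image, reduction="sum") / batch_size
    # KL(N(mu, sigma^2) || N(0, I)) in closed form, with log sigma inputs
    continuous_kl = -0.5 * torch.sum(
        1 + 2 * norm_log_sigma - norm_mean.pow(2)
        - torch.exp(2 * norm_log_sigma)) / batch_size
    # KL(Cat(alpha) || Uniform(K)) = sum alpha * (log alpha + log K)
    num_classes = disc_log_alpha.size(1)
    disc_alpha = torch.exp(disc_log_alpha)
    discrete_kl = torch.sum(
        disc_alpha * (disc_log_alpha + math.log(num_classes))) / batch_size
    return reconstruct_loss, continuous_kl, discrete_kl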
Example #4
File: train.py (project: yuanwei0908/KD3A)
def test(target_domain,
         source_domains,
         test_dloader_list,
         model_list,
         classifier_list,
         epoch,
         writer,
         num_classes=126,
         top_5_accuracy=True):
    source_domain_losses = [AverageMeter() for i in source_domains]
    target_domain_losses = AverageMeter()
    task_criterion = nn.CrossEntropyLoss().cuda()
    for model in model_list:
        model.eval()
    for classifier in classifier_list:
        classifier.eval()
    # calculate loss, accuracy for target domain
    tmp_score = []
    tmp_label = []
    test_dloader_t = test_dloader_list[0]
    for _, (image_t, label_t) in enumerate(test_dloader_t):
        image_t = image_t.cuda()
        label_t = label_t.long().cuda()
        with torch.no_grad():
            output_t = classifier_list[0](model_list[0](image_t))
        label_onehot_t = torch.zeros(label_t.size(0),
                                     num_classes).cuda().scatter_(
                                         1, label_t.view(-1, 1), 1)
        task_loss_t = task_criterion(output_t, label_t)
        target_domain_losses.update(float(task_loss_t.item()), image_t.size(0))
        tmp_score.append(torch.softmax(output_t, dim=1))
        # collect one-hot labels
        tmp_label.append(label_onehot_t)
    writer.add_scalar(tag="Test/target_domain_{}_loss".format(target_domain),
                      scalar_value=target_domain_losses.avg,
                      global_step=epoch + 1)
    tmp_score = torch.cat(tmp_score, dim=0).detach()
    tmp_label = torch.cat(tmp_label, dim=0).detach()
    _, y_true = torch.topk(tmp_label, k=1, dim=1)
    _, y_pred = torch.topk(tmp_score, k=5, dim=1)
    top_1_accuracy_t = float(
        torch.sum(y_true == y_pred[:, :1]).item()) / y_true.size(0)
    writer.add_scalar(tag="Test/target_domain_{}_accuracy_top1".format(
        target_domain).format(target_domain),
                      scalar_value=top_1_accuracy_t,
                      global_step=epoch + 1)
    if top_5_accuracy:
        top_5_accuracy_t = float(
            torch.sum(y_true == y_pred).item()) / y_true.size(0)
        writer.add_scalar(tag="Test/target_domain_{}_accuracy_top5".format(
            target_domain).format(target_domain),
                          scalar_value=top_5_accuracy_t,
                          global_step=epoch + 1)
        print("Target Domain {} Accuracy Top1 :{:.3f} Top5:{:.3f}".format(
            target_domain, top_1_accuracy_t, top_5_accuracy_t))
    else:
        print("Target Domain {} Accuracy {:.3f}".format(
            target_domain, top_1_accuracy_t))
    # calculate loss, accuracy for source domains
    for s_i, domain_s in enumerate(source_domains):
        tmp_score = []
        tmp_label = []
        test_dloader_s = test_dloader_list[s_i + 1]
        for _, (image_s, label_s) in enumerate(test_dloader_s):
            image_s = image_s.cuda()
            label_s = label_s.long().cuda()
            with torch.no_grad():
                output_s = classifier_list[s_i + 1](model_list[s_i +
                                                               1](image_s))
            label_onehot_s = torch.zeros(label_s.size(0),
                                         num_classes).cuda().scatter_(
                                             1, label_s.view(-1, 1), 1)
            task_loss_s = task_criterion(output_s, label_s)
            source_domain_losses[s_i].update(float(task_loss_s.item()),
                                             image_s.size(0))
            tmp_score.append(torch.softmax(output_s, dim=1))
            # collect one-hot labels
            tmp_label.append(label_onehot_s)
        writer.add_scalar(tag="Test/source_domain_{}_loss".format(domain_s),
                          scalar_value=source_domain_losses[s_i].avg,
                          global_step=epoch + 1)
        tmp_score = torch.cat(tmp_score, dim=0).detach()
        tmp_label = torch.cat(tmp_label, dim=0).detach()
        _, y_true = torch.topk(tmp_label, k=1, dim=1)
        _, y_pred = torch.topk(tmp_score, k=5, dim=1)
        top_1_accuracy_s = float(
            torch.sum(y_true == y_pred[:, :1]).item()) / y_true.size(0)
        writer.add_scalar(
            tag="Test/source_domain_{}_accuracy_top1".format(domain_s),
            scalar_value=top_1_accuracy_s,
            global_step=epoch + 1)
        if top_5_accuracy:
            top_5_accuracy_s = float(
                torch.sum(y_true == y_pred).item()) / y_true.size(0)
            writer.add_scalar(
                tag="Test/source_domain_{}_accuracy_top5".format(domain_s),
                scalar_value=top_5_accuracy_s,
                global_step=epoch + 1)
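The list arguments here are positionally aligned: index 0 is the target domain and indices 1..N are the source domains, in the order of source_domains. A hypothetical helper making that convention explicit (names are assumptions, not from the project):

def assemble_test_inputs(target, sources, loaders, models, classifiers):
    """Hypothetical helper: target at index 0, sources at indices 1..N."""
    dloaders = [loaders[target]] + [loaders[s] for s in sources]
    nets = [models[target]] + [models[s] for s in sources]
    clfs = [classifiers[target]] + [classifiers[s] for s in sources]
    return dloaders, nets, clfs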
Example #5
def train(train_dloader, model, criterion, optimizer, epoch, writer, dataset):
    # record per-batch data-loading and compute time, plus the loss value
    batch_time = AverageMeter()
    data_time = AverageMeter()
    reconstruct_losses = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    for i, (image, *_) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        image = image.float().cuda()
        image_reconstructed = model(image)
        reconstruct_loss = criterion(image, image_reconstructed)
        # loss = kl_loss
        reconstruct_loss.backward()
        if args.gd_clip_flag:
            torch.nn.utils.clip_grad_value_(model.parameters(), args.gd_clip_value)
        reconstruct_losses.update(float(reconstruct_loss), image.size(0))
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                         'Reconstruct Loss {reconstruct_loss.val:.4f} ({reconstruct_loss.avg:.4f})\t'.format(
                epoch, i, len(train_dloader), batch_time=batch_time,
                data_time=data_time, reconstruct_loss=reconstruct_losses)
            print(train_text)
    writer.add_scalar(tag="{}_train/reconstruct_loss".format(dataset), scalar_value=reconstruct_losses.avg,
                      global_step=epoch)
    return reconstruct_losses.avg
Example #6
def train(train_dloader_u, train_dloader_l, model, elbo_criterion, cls_criterion, optimizer, epoch, writer,
          discrete_latent_dim):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    kl_inferences = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual information
    cmi = alpha_schedule(epoch, args.akb, args.cmi)
    dmi = alpha_schedule(epoch, args.akb, args.dmi)
    # elbo part weight
    ew = alpha_schedule(epoch, args.aew, args.ewm)
    # KL warm-up weights
    kl_beta_c = alpha_schedule(epoch, args.akb, args.kbmc)
    kl_beta_d = alpha_schedule(epoch, args.akb, args.kbmd)
    for i, ((image_l, label_l), (image_u, label_u)) in enumerate(zip(cycle(train_dloader_l), train_dloader_u)):
        if image_l.size(0) != image_u.size(0):
            batch_size = min(image_l.size(0), image_u.size(0))
            image_l = image_l[0:batch_size]
            label_l = label_l[0:batch_size]
            image_u = image_u[0:batch_size]
            label_u = label_u[0:batch_size]
        else:
            batch_size = image_l.size(0)
        data_time.update(time.time() - end)
        # for the labeled part, do classification and mixup
        image_l = image_l.float().cuda()
        label_l = label_l.long().cuda()
        label_onehot_l = torch.zeros(batch_size, discrete_latent_dim).cuda().scatter_(1, label_l.view(-1, 1), 1)
        reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l = model(image_l, disc_label=label_l)
        reconstruct_loss_l, continuous_prior_kl_loss_l, disc_prior_kl_loss_l = elbo_criterion(image_l, reconstruction_l,
                                                                                              norm_mean_l,
                                                                                              norm_log_sigma_l,
                                                                                              disc_log_alpha_l)
        prior_kl_loss_l = kl_beta_c * torch.abs(continuous_prior_kl_loss_l - cmi) + kl_beta_d * torch.abs(
            disc_prior_kl_loss_l - dmi)
        elbo_loss_l = reconstruct_loss_l + prior_kl_loss_l
        disc_posterior_kl_loss_l = cls_criterion(disc_log_alpha_l, label_onehot_l)
        loss_supervised = ew * elbo_loss_l + disc_posterior_kl_loss_l
        loss_supervised.backward()

        # for the unlabeled part, do classification and mixup
        image_u = image_u.float().cuda()
        label_u = label_u.long().cuda()
        reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u = model(image_u)
        # calculate the KL(q(y|X)||p(y|X))
        with torch.no_grad():
            label_smooth_u = torch.zeros(batch_size, discrete_latent_dim).cuda().scatter_(1, label_u.view(-1, 1),
                                                                                          1 - 0.001 - 0.001 / (
                                                                                                  discrete_latent_dim - 1))
            label_smooth_u = label_smooth_u + torch.ones(label_smooth_u.size()).cuda() * 0.001 / (discrete_latent_dim - 1)
            disc_alpha_u = torch.exp(disc_log_alpha_u)
            inference_kl = disc_alpha_u * disc_log_alpha_u - disc_alpha_u * torch.log(label_smooth_u)
        kl_inferences.update(float(torch.sum(inference_kl) / batch_size), batch_size)
        reconstruct_loss_u, continuous_prior_kl_loss_u, disc_prior_kl_loss_u = elbo_criterion(image_u, reconstruction_u,
                                                                                              norm_mean_u,
                                                                                              norm_log_sigma_u,
                                                                                              disc_log_alpha_u)
        prior_kl_loss_u = kl_beta_c * torch.abs(continuous_prior_kl_loss_u - cmi) + kl_beta_d * torch.abs(
            disc_prior_kl_loss_u - dmi)
        elbo_loss_u = reconstruct_loss_u + prior_kl_loss_u
        loss_unsupervised = ew * elbo_loss_u
        loss_unsupervised.backward()
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(
                epoch, i + 1, len(train_dloader_u), batch_time=batch_time, data_time=data_time)
            print(train_text)
    writer.add_scalar(tag="Train/KL_Inference", scalar_value=kl_inferences.avg, global_step=epoch + 1)
    # every few epochs, log raw and reconstructed images to TensorBoard (first 4 of the batch)
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image_u[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(reconstruction_u[:4, ...]), nrow=2)
        writer.add_image(tag="Train/Raw_Image", img_tensor=image, global_step=epoch + 1)
        writer.add_image(tag="Train/Reconstruct_Image", img_tensor=reconstruct_image, global_step=epoch + 1)
Example #7
def train(train_dloader, model, criterion, optimizer, epoch, writer):
    # meters for timing and loss statistics
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    for i, (image, label) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        image = image.float().cuda()
        label = label.long().cuda()
        cls_result = model(image)
        loss = criterion(cls_result, label)
        loss.backward()
        losses.update(float(loss.item()), image.size(0))
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                         'Cls Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})'.format(
                epoch, i + 1, len(train_dloader), batch_time=batch_time, data_time=data_time,
                cls_loss=losses)
            print(train_text)
    writer.add_scalar(tag="Train/cls_loss",
                      scalar_value=losses.avg,
                      global_step=epoch + 1)
    return losses.avg
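For context, a train function like this is typically driven by an outer epoch loop that owns the model, optimizer, and TensorBoard SummaryWriter. A minimal hypothetical driver with stand-in data (the train function above also reads the module-level args, e.g. args.print_freq, which this sketch does not set up):

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

# stand-in data; a real run would use the project's dataset and transforms
train_dloader = DataLoader(
    TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 10, (256,))),
    batch_size=64, shuffle=True)
model = torchvision.models.resnet18(num_classes=10).cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
writer = SummaryWriter(log_dir="runs/example")
for epoch in range(10):
    avg_loss = train(train_dloader, model, criterion, optimizer, epoch, writer)
writer.close()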
Example #8
def test(valid_dloader, test_dloader, model, criterion, epoch, writer,
         num_classes):
    model.eval()
    # compute metrics on the validation set
    losses = AverageMeter()
    all_score = []
    all_label = []
    for i, (image, label) in enumerate(valid_dloader):
        image = image.float().cuda()
        label = label.long().cuda()
        with torch.no_grad():
            cls_result = model(image)
        label_onehot = torch.zeros(label.size(0), num_classes).cuda().scatter_(
            1, label.view(-1, 1), 1)
        loss = criterion(cls_result, label)
        losses.update(float(loss.item()), image.size(0))
        # collect per-batch scores and one-hot labels
        all_score.append(torch.softmax(cls_result, dim=1))
        all_label.append(label_onehot)
    writer.add_scalar(tag="Valid/cls_loss",
                      scalar_value=losses.avg,
                      global_step=epoch + 1)
    all_score = torch.cat(all_score, dim=0).detach()
    all_label = torch.cat(all_label, dim=0).detach()
    _, y_true = torch.topk(all_label, k=1, dim=1)
    _, y_pred = torch.topk(all_score, k=5, dim=1)
    top_1_accuracy = float(
        torch.sum(y_true == y_pred[:, :1]).item()) / y_true.size(0)
    top_5_accuracy = float(torch.sum(y_true == y_pred).item()) / y_true.size(0)
    writer.add_scalar(tag="Valid/top 1 accuracy",
                      scalar_value=top_1_accuracy,
                      global_step=epoch + 1)
    if args.dataset == "Cifar100":
        writer.add_scalar(tag="Valid/top 5 accuracy",
                          scalar_value=top_5_accuracy,
                          global_step=epoch + 1)
    # compute metrics on the test set
    losses = AverageMeter()
    all_score = []
    all_label = []
    # don't use roc
    # roc_list = []
    for i, (image, label) in enumerate(test_dloader):
        image = image.float().cuda()
        label = label.long().cuda()
        with torch.no_grad():
            cls_result = model(image)
        label_onehot = torch.zeros(label.size(0), num_classes).cuda().scatter_(
            1, label.view(-1, 1), 1)
        loss = criterion(cls_result, label)
        losses.update(float(loss.item()), image.size(0))
        # collect per-batch scores and one-hot labels
        all_score.append(torch.softmax(cls_result, dim=1))
        all_label.append(label_onehot)
    writer.add_scalar(tag="Test/cls_loss",
                      scalar_value=losses.avg,
                      global_step=epoch + 1)
    all_score = torch.cat(all_score, dim=0).detach()
    all_label = torch.cat(all_label, dim=0).detach()
    _, y_true = torch.topk(all_label, k=1, dim=1)
    _, y_pred = torch.topk(all_score, k=5, dim=1)
    # don't use roc auc
    # all_score = all_score.cpu().numpy()
    # all_label = all_label.cpu().numpy()
    # for i in range(num_classes):
    #     roc_list.append(roc_auc_score(all_label[:, i], all_score[:, i]))
    # ap_list.append(average_precision_score(all_label[:, i], all_score[:, i]))
    # calculate accuracy by hand
    top_1_accuracy = float(
        torch.sum(y_true == y_pred[:, :1]).item()) / y_true.size(0)
    top_5_accuracy = float(torch.sum(y_true == y_pred).item()) / y_true.size(0)
    writer.add_scalar(tag="Test/top 1 accuracy",
                      scalar_value=top_1_accuracy,
                      global_step=epoch + 1)
    if args.dataset == "Cifar100":
        writer.add_scalar(tag="Test/top 5 accuracy",
                          scalar_value=top_5_accuracy,
                          global_step=epoch + 1)
    # writer.add_scalar(tag="Test/mean RoC", scalar_value=mean(roc_list), global_step=epoch + 1)
    return top_1_accuracy
Example #9
def train(train_dloader,
          model,
          elbo_criterion,
          optimizer,
          epoch,
          writer,
          kl_disc_criterion=None,
          kl_norm_criterion=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    reconstruct_losses = AverageMeter()
    kl_losses = AverageMeter()
    elbo_losses = AverageMeter()
    mse_losses = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual information
    cmi = args.cmi * alpha_schedule(epoch, args.akb, 1, strategy="exp")
    dmi = args.dmi * alpha_schedule(epoch, args.akb, 1, strategy="exp")
    kl_beta = alpha_schedule(epoch, args.akb, args.kbm)
    posterior_weight = alpha_schedule(epoch, args.apw, args.pwm)
    # mixup parameters:
    if args.mixup:
        mixup_posterior_kl_losses = AverageMeter()
        # mixup_prior_kl_losses = AverageMeter()
        # mixup_elbo_losses = AverageMeter()
        # mixup_reconstruct_losses = AverageMeter()
    print(
        "Begin {} Epoch Training, CMI:{:.4f} DMI:{:.4f} KL-Beta:{:.4f}".format(
            epoch + 1, cmi, dmi, kl_beta))
    for i, (image, *_) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        image = image.float().cuda()
        batch_size = image.size(0)
        reconstruction, norm_mean, norm_log_sigma, disc_log_alpha, *_ = model(
            image)
        reconstruct_loss, continuous_kl_loss, disc_kl_loss = elbo_criterion(
            image, reconstruction, norm_mean, norm_log_sigma, disc_log_alpha)
        kl_loss = torch.abs(continuous_kl_loss -
                            cmi) + torch.abs(disc_kl_loss - dmi)
        elbo_loss = reconstruct_loss + kl_beta * kl_loss
        elbo_loss.backward()
        if args.mixup:
            with torch.no_grad():
                mixed_image, mixed_z_mean, mixed_z_sigma, mixed_disc_alpha, lam = mixup_vae_data(
                    image,
                    norm_mean,
                    norm_log_sigma,
                    disc_log_alpha,
                    alpha=args.ma)
            mixed_reconstruction, norm_mean, norm_log_sigma, disc_log_alpha, *_ = model(
                mixed_image)
            # continuous_kl_posterior_loss = kl_norm_criterion(norm_mean, norm_log_sigma, z_mean_gt=mixed_z_mean,
            #                                                  z_sigma_gt=mixed_z_sigma)
            disc_kl_posterior_loss = kl_disc_criterion(disc_log_alpha,
                                                       mixed_disc_alpha)
            continuous_kl_posterior_loss = (F.mse_loss(norm_mean, mixed_z_mean, reduction="sum") + \
                                            F.mse_loss(torch.exp(norm_log_sigma), mixed_z_sigma,
                                                       reduction="sum")) / batch_size
            # disc_kl_posterior_loss = F.mse_loss(torch.exp(disc_log_alpha), mixed_disc_alpha,
            #                                     reduction="sum") / batch_size
            mixup_kl_posterior_loss = posterior_weight * (
                continuous_kl_posterior_loss + disc_kl_posterior_loss)
            mixup_kl_posterior_loss.backward()
            mixup_posterior_kl_losses.update(float(mixup_kl_posterior_loss),
                                             mixed_image.size(0))

        elbo_losses.update(float(elbo_loss), image.size(0))
        reconstruct_losses.update(float(reconstruct_loss), image.size(0))
        kl_losses.update(float(kl_loss), image.size(0))
        # also track MSE when the BCE reconstruction loss is used
        if args.br:
            mse_loss = F.mse_loss(torch.sigmoid(reconstruction.detach()),
                                  image.detach(),
                                  reduction="sum") / (2 * image.size(0) *
                                                      (args.x_sigma**2))
            mse_losses.update(float(mse_loss), image.size(0))
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                         'ELBO Loss {elbo_loss.val:.4f} ({elbo_loss.avg:.4f})\t' \
                         'Reconstruct Loss {reconstruct_loss.val:.4f} ({reconstruct_loss.avg:.4f})\t' \
                         'KL Loss {kl_loss.val:.4f} ({kl_loss.avg:.4f})\t'.format(
                epoch, i + 1, len(train_dloader), batch_time=batch_time, data_time=data_time,
                elbo_loss=elbo_losses, reconstruct_loss=reconstruct_losses, kl_loss=kl_losses)
            print(train_text)

    if args.mixup:
        # writer.add_scalar(tag="Train/Mixup-ELBO", scalar_value=mixup_elbo_losses.avg, global_step=epoch + 1)
        # writer.add_scalar(tag="Train/Mixup-Reconstruct", scalar_value=mixup_reconstruct_losses.avg,
        #                   global_step=epoch + 1)
        # writer.add_scalar(tag="Train/Mixup-KL-Prior", scalar_value=mixup_prior_kl_losses.avg, global_step=epoch + 1)
        writer.add_scalar(tag="Train/Mixup-KL-Posterior",
                          scalar_value=mixup_posterior_kl_losses.avg,
                          global_step=epoch + 1)
    writer.add_scalar(tag="Train/ELBO",
                      scalar_value=elbo_losses.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Reconstruct",
                      scalar_value=reconstruct_losses.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/KL",
                      scalar_value=kl_losses.avg,
                      global_step=epoch + 1)
    if args.br:
        writer.add_scalar(tag="Train/MSE",
                          scalar_value=mse_losses.avg,
                          global_step=epoch)
    # every few epochs, log raw and reconstructed images to TensorBoard (first 4 of the batch)
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(
                reconstruction[:4, ...]),
                                                nrow=2)
        writer.add_image(tag="Train/Raw_Image",
                         img_tensor=image,
                         global_step=epoch + 1)
        writer.add_image(tag="Train/Reconstruct_Image",
                         img_tensor=reconstruct_image,
                         global_step=epoch + 1)
    return elbo_losses.avg, reconstruct_losses.avg, kl_losses.avg
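mixup_vae_data is project-specific and not shown. Judging from its call sites, it mixes a batch with a permuted copy of itself and returns the mixed images together with the correspondingly mixed posterior statistics, which serve as targets for the posterior-consistency losses above. A speculative sketch consistent with those call sites; the real implementation (and its optimal_match/disc_label options) may differ:

import numpy as np
import torch


def mixup_vae_data(image, norm_mean, norm_log_sigma, disc_log_alpha,
                   alpha=0.2, **kwargs):
    """Speculative sketch inferred from the call sites above."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(image.size(0), device=image.device)
    mixed_image = lam * image + (1 - lam) * image[index]
    mixed_z_mean = lam * norm_mean + (1 - lam) * norm_mean[index]
    # mix standard deviations and class probabilities in their natural space,
    # matching the exp() applied to the log-parameters in the losses above
    sigma = torch.exp(norm_log_sigma)
    mixed_z_sigma = lam * sigma + (1 - lam) * sigma[index]
    alpha_prob = torch.exp(disc_log_alpha)
    mixed_disc_alpha = lam * alpha_prob + (1 - lam) * alpha_prob[index]
    return mixed_image, mixed_z_mean, mixed_z_sigma, mixed_disc_alpha, lam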
Example #10
def test(test_dloader, model, elbo_criterion, epoch, writer):
    reconstruct_losses = AverageMeter()
    kl_losses = AverageMeter()
    elbo_losses = AverageMeter()
    mse_losses = AverageMeter()
    model.eval()
    # mutual information
    cmi = args.cmi * min(1.0, float(epoch / args.adjust_lr[0]))
    dmi = args.dmi * min(1.0, float(epoch / args.adjust_lr[0]))

    for i, (image, *_) in enumerate(test_dloader):
        image = image.float().cuda()
        with torch.no_grad():
            reconstruction, norm_mean, norm_log_sigma, disc_log_alpha, *_ = model(
                image)
            reconstruct_loss, continuous_kl_loss, disc_kl_loss = elbo_criterion(
                image, reconstruction, norm_mean, norm_log_sigma,
                disc_log_alpha)
            kl_loss = torch.abs(continuous_kl_loss -
                                cmi) + torch.abs(disc_kl_loss - dmi)
            elbo_loss = reconstruct_loss + kl_loss
            if args.br:
                mse_loss = F.mse_loss(torch.sigmoid(reconstruction.detach()),
                                      image.detach(),
                                      reduction="sum") / (2 * image.size(0) *
                                                          (args.x_sigma**2))
                mse_losses.update(float(mse_loss), image.size(0))
        elbo_losses.update(float(elbo_loss), image.size(0))
        reconstruct_losses.update(float(reconstruct_loss), image.size(0))
        kl_losses.update(float(kl_loss), image.size(0))
    writer.add_scalar(tag="Test/ELBO",
                      scalar_value=elbo_losses.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Test/Reconstruct",
                      scalar_value=reconstruct_losses.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Test/KL",
                      scalar_value=kl_losses.avg,
                      global_step=epoch + 1)
    if args.br:
        writer.add_scalar(tag="Test/MSE",
                          scalar_value=mse_losses.avg,
                          global_step=epoch)
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(
                reconstruction[:4, ...]),
                                                nrow=2)
        writer.add_image(tag="Test/Raw_Image",
                         img_tensor=image,
                         global_step=epoch + 1)
        writer.add_image(tag="Test/Reconstruct_Image",
                         img_tensor=reconstruct_image,
                         global_step=epoch + 1)
    return elbo_losses.avg, reconstruct_losses.avg, kl_losses.avg
Example #11
File: main.py (project: FengHZ/DCGAN)
def train(train_dloader, generator, discriminator, g_optimizer, d_optimizer, criterion, writer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    generator_loss = AverageMeter()
    discriminator_real_loss = AverageMeter()
    discriminator_fake_loss = AverageMeter()
    dx_record = AverageMeter()
    dgz1_record = AverageMeter()
    dgz2_record = AverageMeter()
    generator.train()
    discriminator.train()
    end = time.time()
    g_optimizer.zero_grad()
    d_optimizer.zero_grad()
    dx_equilibrium = args.dx_equilibrium * (args.decay_equilibrium) ** epoch
    dgz1_equilibrium = args.dgz1_equilibrium * (args.decay_equilibrium) ** epoch
    dgz2_equilibrium = args.dgz2_equilibrium * (args.decay_equilibrium) ** epoch
    for i, (real_image, index, *_) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        real_image = real_image.cuda()
        batch_size = real_image.size(0)
        # create noise for generator
        noise = torch.randn(batch_size, args.latent_dim, 1, 1).cuda()
        # create image label
        real_label = torch.full((batch_size,), args.real_label).cuda()
        # use discriminator to distinguish the real images
        output = discriminator(real_image).view(-1)
        # calculate d_x
        d_x = output.mean().item()
        # calculate the discriminator loss in real image
        d_loss_real = criterion(output, real_label)
        # equilibrium strategy
        if d_x <= 0.5 + dx_equilibrium:
            d_loss_real.backward()
        # use discriminator to distinguish the fake images
        fake = generator(noise)
        # detach fake so this pass only updates the discriminator
        output = discriminator(fake.detach()).view(-1)
        # calculate d_gz_1
        d_gz_1 = output.mean().item()
        # create fake label
        fake_label = torch.full((batch_size,), args.fake_label).cuda()
        # calculate the discriminator loss in fake image
        d_loss_fake = criterion(output, fake_label)
        # equilibrium strategy
        if d_gz_1 >= 0.5 - dgz1_equilibrium:
            d_loss_fake.backward()
        # optimize discriminator
        d_optimizer.step()
        # train the generator so its fake images look more real
        # trick: maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
        g_label = torch.full((batch_size,), args.real_label).cuda()
        output = discriminator(fake).view(-1)
        # calculate d_gz_2
        d_gz_2 = output.mean().item()
        # calculate the g_loss
        g_loss = criterion(output, g_label)
        if d_gz_2 <= 0.5 - dgz2_equilibrium:
            g_loss.backward()
        g_optimizer.step()
        # zero grad each optimizer
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()
        # update d loss and g loss

        generator_loss.update(float(g_loss), batch_size)
        discriminator_fake_loss.update(float(d_loss_fake), batch_size)
        discriminator_real_loss.update(float(d_loss_real), batch_size)
        dx_record.update(float(d_x), batch_size)
        dgz1_record.update(float(d_gz_1), batch_size)
        dgz2_record.update(float(d_gz_2), batch_size)
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'D(x) [{3:.4f}] D(G(z)) [{4:.4f}/{5:.4f}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                         'Discriminator Real Loss {drl.val:.4f} ({drl.avg:.4f})\t' \
                         'Discriminator Fake Loss {dfl.val:.4f} ({dfl.avg:.4f})\t' \
                         'Generator Loss {gl.val:.4f} ({gl.avg:.4f})\t'.format(
                epoch, i + 1, len(train_dloader), d_x, d_gz_1, d_gz_2, batch_time=batch_time,
                data_time=data_time, drl=discriminator_real_loss, dfl=discriminator_fake_loss, gl=generator_loss)
            print(train_text)
    writer.add_scalar(tag="DCGAN/DRL", scalar_value=discriminator_real_loss.avg, global_step=epoch + 1)
    writer.add_scalar(tag="DCGAN/DFL", scalar_value=discriminator_fake_loss.avg, global_step=epoch + 1)
    writer.add_scalar(tag="DCGAN/GL", scalar_value=generator_loss.avg, global_step=epoch + 1)
    writer.add_scalar(tag="DCGAN/dx", scalar_value=dx_record.avg, global_step=epoch + 1)
    writer.add_scalar(tag="DCGAN/dgz1", scalar_value=dgz1_record.avg, global_step=epoch + 1)
    writer.add_scalar(tag="DCGAN/dgz2", scalar_value=dgz2_record.avg, global_step=epoch + 1)
    # at the end of the epoch, log grids of real and fake images
    real_image = utils.make_grid(real_image[:16, ...], nrow=4)
    writer.add_image(tag="Real_Image", img_tensor=(real_image * 0.5) + 0.5, global_step=epoch + 1)
    noise = torch.randn(16, args.latent_dim, 1, 1).cuda()
    fake_image = utils.make_grid(generator(noise), nrow=4)
    writer.add_image(tag="Fake_Image", img_tensor=(fake_image * 0.5) + 0.5, global_step=epoch + 1)
Example #12
def train(train_dloader_u, train_dloader_l, model, elbo_criterion,
          cls_criterion, optimizer, epoch, writer, discrete_latent_dim):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    kl_inferences = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual information
    cmi = alpha_schedule(epoch, args.akb, args.cmi)
    dmi = alpha_schedule(epoch, args.akb, args.dmi)
    # elbo part weight
    ew = alpha_schedule(epoch, args.aew, args.ewm)
    # KL warm-up and posterior weights
    kl_beta_c = alpha_schedule(epoch, args.akb, args.kbmc)
    kl_beta_d = alpha_schedule(epoch, args.akb, args.kbmd)
    pwm = alpha_schedule(epoch, args.apw, args.pwm)
    # unsupervised cls weight
    ucw = alpha_schedule(epoch, round(args.wmf * args.epochs), args.wrd)
    for i, ((image_l, label_l), (image_u, label_u)) in enumerate(
            zip(cycle(train_dloader_l), train_dloader_u)):
        batch_size_l = image_l.size(0)
        batch_size_u = image_u.size(0)
        data_time.update(time.time() - end)
        # for the labeled part, do classification and mixup
        image_l = image_l.float().cuda()
        label_l = label_l.long().cuda()
        label_onehot_l = torch.zeros(batch_size_l,
                                     discrete_latent_dim).cuda().scatter_(
                                         1, label_l.view(-1, 1), 1)
        reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l = model(
            image_l, disc_label=label_l)
        reconstruct_loss_l, continuous_prior_kl_loss_l, disc_prior_kl_loss_l = elbo_criterion(
            image_l, reconstruction_l, norm_mean_l, norm_log_sigma_l,
            disc_log_alpha_l)
        prior_kl_loss_l = kl_beta_c * torch.abs(continuous_prior_kl_loss_l -
                                                cmi) + kl_beta_d * torch.abs(
                                                    disc_prior_kl_loss_l - dmi)
        elbo_loss_l = reconstruct_loss_l + prior_kl_loss_l
        # do optimal transport estimation
        with torch.no_grad():
            smoothed_image_l, smoothed_z_mean_l, smoothed_z_sigma_l, smoothed_disc_alpha_l, smoothed_label_l, smoothed_lambda_l = \
                label_smoothing(
                    image_l,
                    norm_mean_l,
                    norm_log_sigma_l,
                    disc_log_alpha_l,
                    epsilon=args.epsilon,
                    disc_label=label_l)
            smoothed_label_onehot_l = torch.zeros(
                batch_size_l, discrete_latent_dim).cuda().scatter_(
                    1, smoothed_label_l.view(-1, 1), 1)
        smoothed_reconstruction_l, smoothed_norm_mean_l, smoothed_norm_log_sigma_l, smoothed_disc_log_alpha_l, *_ = model(
            smoothed_image_l, True, label_l, smoothed_label_l,
            smoothed_lambda_l)
        disc_posterior_kl_loss_l = smoothed_lambda_l * cls_criterion(
            smoothed_disc_log_alpha_l,
            label_onehot_l) + (1 - smoothed_lambda_l) * cls_criterion(
                smoothed_disc_log_alpha_l, smoothed_label_onehot_l)
        continuous_posterior_kl_loss_l = (F.mse_loss(smoothed_norm_mean_l, smoothed_z_mean_l, reduction="sum") + \
                                          F.mse_loss(torch.exp(smoothed_norm_log_sigma_l), smoothed_z_sigma_l,
                                                     reduction="sum")) / batch_size_l
        elbo_loss_l = elbo_loss_l + kl_beta_c * pwm * continuous_posterior_kl_loss_l
        loss_supervised = ew * elbo_loss_l + disc_posterior_kl_loss_l
        loss_supervised.backward()

        # for the unlabeled part, do classification and mixup
        image_u = image_u.float().cuda()
        label_u = label_u.long().cuda()
        reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u = model(
            image_u)
        # calculate KL(q(y|X) || p(y|X))
        with torch.no_grad():
            label_smooth_u = torch.zeros(
                batch_size_u, discrete_latent_dim).cuda().scatter_(
                    1, label_u.view(-1, 1),
                    1 - 0.001 - 0.001 / (discrete_latent_dim - 1))
            label_smooth_u = label_smooth_u + torch.ones(label_smooth_u.size(
            )).cuda() * 0.001 / (discrete_latent_dim - 1)
            disc_alpha_u = torch.exp(disc_log_alpha_u)
            inference_kl = disc_alpha_u * disc_log_alpha_u - disc_alpha_u * torch.log(
                label_smooth_u)
        kl_inferences.update(float(torch.sum(inference_kl) / batch_size_u),
                             batch_size_u)
        reconstruct_loss_u, continuous_prior_kl_loss_u, disc_prior_kl_loss_u = elbo_criterion(
            image_u, reconstruction_u, norm_mean_u, norm_log_sigma_u,
            disc_log_alpha_u)
        prior_kl_loss_u = kl_beta_c * torch.abs(continuous_prior_kl_loss_u -
                                                cmi) + kl_beta_d * torch.abs(
                                                    disc_prior_kl_loss_u - dmi)
        elbo_loss_u = reconstruct_loss_u + prior_kl_loss_u
        # do mixup part
        with torch.no_grad():
            mixed_image_u, mixed_z_mean_u, mixed_z_sigma_u, mixed_disc_alpha_u, lam_u = \
                mixup_vae_data(
                    image_u,
                    norm_mean_u,
                    norm_log_sigma_u,
                    disc_log_alpha_u,
                    optimal_match=args.om)
        mixed_reconstruction_u, mixed_norm_mean_u, mixed_norm_log_sigma_u, mixed_disc_log_alpha_u, *_ = model(
            mixed_image_u)
        disc_posterior_kl_loss_u = cls_criterion(mixed_disc_log_alpha_u,
                                                 mixed_disc_alpha_u)
        continuous_posterior_kl_loss_u = (F.mse_loss(mixed_norm_mean_u, mixed_z_mean_u, reduction="sum") + \
                                          F.mse_loss(torch.exp(mixed_norm_log_sigma_u), mixed_z_sigma_u,
                                                     reduction="sum")) / batch_size_u
        elbo_loss_u = elbo_loss_u + kl_beta_c * pwm * continuous_posterior_kl_loss_u
        loss_unsupervised = ew * elbo_loss_u + ucw * disc_posterior_kl_loss_u
        loss_unsupervised.backward()
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(
                epoch, i + 1, len(train_dloader_u), batch_time=batch_time, data_time=data_time)
            print(train_text)
    # record unlabeled part loss
    writer.add_scalar(tag="Train/KL_Inference",
                      scalar_value=kl_inferences.avg,
                      global_step=epoch + 1)
    # every few epochs, log raw and reconstructed images to TensorBoard (first 4 of the batch)
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image_u[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(
                reconstruction_u[:4, ...]),
                                                nrow=2)
        writer.add_image(tag="Train/Raw_Image",
                         img_tensor=image,
                         global_step=epoch + 1)
        writer.add_image(tag="Train/Reconstruct_Image",
                         img_tensor=reconstruct_image,
                         global_step=epoch + 1)
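The label_smooth_u targets in Examples #6 and #12 are built in two steps: scatter 1 - ε - ε/(K-1) onto the true class, then add ε/(K-1) everywhere, leaving 1 - ε on the true class and ε/(K-1) on each of the other K-1 classes, so every row sums to 1 (with ε = 0.001 here). The same construction in one step, as a self-contained sketch:

import torch


def smooth_one_hot(labels, num_classes, epsilon=0.001):
    """Rows sum to 1: true class gets 1 - epsilon, others epsilon/(K-1)."""
    off_value = epsilon / (num_classes - 1)
    target = torch.full((labels.size(0), num_classes), off_value)
    target.scatter_(1, labels.view(-1, 1), 1.0 - epsilon)
    return target


# e.g. smooth_one_hot(torch.tensor([2]), 4)
# -> [[0.000333, 0.000333, 0.999, 0.000333]]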
Example #13
def train(train_dloader_u,
          train_dloader_l,
          model,
          elbo_criterion,
          cls_criterion,
          optimizer,
          epoch,
          writer,
          discrete_latent_dim,
          kl_norm_criterion=None,
          kl_disc_criterion=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # train_dloader_u part
    reconstruct_losses_u = AverageMeter()
    continuous_prior_kl_losses_u = AverageMeter()
    discrete_prior_kl_losses_u = AverageMeter()
    elbo_losses_u = AverageMeter()
    mse_losses_u = AverageMeter()
    continuous_posterior_kl_losses_u = AverageMeter()
    discrete_posterior_kl_losses_u = AverageMeter()
    # train_dloader_l part
    reconstruct_losses_l = AverageMeter()
    continuous_prior_kl_losses_l = AverageMeter()
    discrete_prior_kl_losses_l = AverageMeter()
    elbo_losses_l = AverageMeter()
    mse_losses_l = AverageMeter()
    continuous_posterior_kl_losses_l = AverageMeter()
    discrete_posterior_kl_losses_l = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual information
    cmi = alpha_schedule(epoch, args.akb, args.cmi, strategy="exp")
    dmi = alpha_schedule(epoch, args.akb, args.dmi, strategy="exp")
    # elbo part weight
    ew = alpha_schedule(epoch, args.aew, args.ewm)
    # KL warm-up and posterior weights
    kl_beta_c = alpha_schedule(epoch, args.akb, args.kbmc)
    kl_beta_d = alpha_schedule(epoch, args.akb, args.kbmd)
    pwm = alpha_schedule(epoch, args.apw, args.pwm)
    # unsupervised cls weight
    ucw = alpha_schedule(epoch, round(args.wmf * args.epochs), args.wrd)
    for i, ((image_l, label_l),
            (image_u,
             _)) in enumerate(zip(cycle(train_dloader_l), train_dloader_u)):
        if image_l.size(0) != image_u.size(0):
            batch_size = min(image_l.size(0), image_u.size(0))
            image_l = image_l[0:batch_size]
            label_l = label_l[0:batch_size]
            image_u = image_u[0:batch_size]
        else:
            batch_size = image_l.size(0)
        data_time.update(time.time() - end)
        # for the labeled part, do classification and mixup
        image_l = image_l.float().cuda()
        label_l = label_l.long().cuda()
        label_onehot_l = torch.zeros(batch_size,
                                     discrete_latent_dim).cuda().scatter_(
                                         1, label_l.view(-1, 1), 1)
        reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l = model(
            image_l, disc_label=label_l)
        reconstruct_loss_l, continuous_prior_kl_loss_l, disc_prior_kl_loss_l = elbo_criterion(
            image_l, reconstruction_l, norm_mean_l, norm_log_sigma_l,
            disc_log_alpha_l)
        prior_kl_loss_l = kl_beta_c * torch.abs(continuous_prior_kl_loss_l -
                                                cmi) + kl_beta_d * torch.abs(
                                                    disc_prior_kl_loss_l - dmi)
        elbo_loss_l = reconstruct_loss_l + prior_kl_loss_l
        reconstruct_losses_l.update(float(reconstruct_loss_l.detach().item()),
                                    batch_size)
        continuous_prior_kl_losses_l.update(
            float(continuous_prior_kl_loss_l.detach().item()), batch_size)
        discrete_prior_kl_losses_l.update(
            float(disc_prior_kl_loss_l.detach().item()), batch_size)
        # do optimal transport estimation
        with torch.no_grad():
            mixed_image_l, mixed_z_mean_l, mixed_z_sigma_l, mixed_disc_alpha_l, label_mixup_l, lam_l = \
                mixup_vae_data(
                    image_l,
                    norm_mean_l,
                    norm_log_sigma_l,
                    disc_log_alpha_l,
                    alpha=args.mas,
                    disc_label=label_l)
            label_mixup_onehot_l = torch.zeros(
                batch_size,
                discrete_latent_dim).cuda().scatter_(1,
                                                     label_mixup_l.view(-1, 1),
                                                     1)
        mixed_reconstruction_l, mixed_norm_mean_l, mixed_norm_log_sigma_l, mixed_disc_log_alpha_l, *_ = model(
            mixed_image_l,
            mixup=True,
            disc_label=label_l,
            disc_label_mixup=label_mixup_l,
            mixup_lam=lam_l)
        disc_posterior_kl_loss_l = lam_l * cls_criterion(
            mixed_disc_log_alpha_l,
            label_onehot_l) + (1 - lam_l) * cls_criterion(
                mixed_disc_log_alpha_l, label_mixup_onehot_l)
        continuous_posterior_kl_loss_l = (F.mse_loss(mixed_norm_mean_l, mixed_z_mean_l, reduction="sum") + \
                                          F.mse_loss(torch.exp(mixed_norm_log_sigma_l), mixed_z_sigma_l,
                                                     reduction="sum")) / batch_size
        elbo_loss_l = elbo_loss_l + kl_beta_c * pwm * continuous_posterior_kl_loss_l
        elbo_losses_l.update(float(elbo_loss_l.detach().item()), batch_size)
        continuous_posterior_kl_losses_l.update(
            float(continuous_posterior_kl_loss_l.detach().item()), batch_size)
        discrete_posterior_kl_losses_l.update(
            float(disc_posterior_kl_loss_l.detach().item()), batch_size)
        if args.br:
            mse_loss_l = F.mse_loss(torch.sigmoid(reconstruction_l.detach()),
                                    image_l.detach(),
                                    reduction="sum") / (2 * batch_size *
                                                        (args.x_sigma**2))
            mse_losses_l.update(float(mse_loss_l), batch_size)
        loss_supervised = ew * elbo_loss_l + disc_posterior_kl_loss_l
        loss_supervised.backward()

        # for the unlabeled part, do classification and mixup
        image_u = image_u.float().cuda()
        reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u = model(
            image_u)
        reconstruct_loss_u, continuous_prior_kl_loss_u, disc_prior_kl_loss_u = elbo_criterion(
            image_u, reconstruction_u, norm_mean_u, norm_log_sigma_u,
            disc_log_alpha_u)
        prior_kl_loss_u = kl_beta_c * torch.abs(continuous_prior_kl_loss_u -
                                                cmi) + kl_beta_d * torch.abs(
                                                    disc_prior_kl_loss_u - dmi)
        elbo_loss_u = reconstruct_loss_u + prior_kl_loss_u
        reconstruct_losses_u.update(float(reconstruct_loss_u.detach().item()),
                                    batch_size)
        continuous_prior_kl_losses_u.update(
            float(continuous_prior_kl_loss_u.detach().item()), batch_size)
        discrete_prior_kl_losses_u.update(
            float(disc_prior_kl_loss_u.detach().item()), batch_size)
        # do mixup part
        with torch.no_grad():
            mixed_image_u, mixed_z_mean_u, mixed_z_sigma_u, mixed_disc_alpha_u, lam_u = \
                mixup_vae_data(
                    image_u,
                    norm_mean_u,
                    norm_log_sigma_u,
                    disc_log_alpha_u,
                    alpha=args.mau)
        mixed_reconstruction_u, mixed_norm_mean_u, mixed_norm_log_sigma_u, mixed_disc_log_alpha_u, *_ = model(
            mixed_image_u)
        disc_posterior_kl_loss_u = cls_criterion(mixed_disc_log_alpha_u,
                                                 mixed_disc_alpha_u)
        continuous_posterior_kl_loss_u = (F.mse_loss(mixed_norm_mean_u, mixed_z_mean_u, reduction="sum") + \
                                          F.mse_loss(torch.exp(mixed_norm_log_sigma_u), mixed_z_sigma_u,
                                                     reduction="sum")) / batch_size
        elbo_loss_u = elbo_loss_u + kl_beta_c * pwm * continuous_posterior_kl_loss_u
        loss_unsupervised = ew * elbo_loss_u + ucw * disc_posterior_kl_loss_u
        elbo_losses_u.update(float(elbo_loss_u.detach().item()), batch_size)
        continuous_posterior_kl_losses_u.update(
            float(continuous_posterior_kl_loss_u.detach().item()), batch_size)
        discrete_posterior_kl_losses_u.update(
            float(disc_posterior_kl_loss_u.detach().item()), batch_size)
        loss_unsupervised.backward()
        if args.br:
            mse_loss_u = F.mse_loss(torch.sigmoid(reconstruction_u.detach()),
                                    image_u.detach(),
                                    reduction="sum") / (2 * batch_size *
                                                        (args.x_sigma**2))
            mse_losses_u.update(float(mse_loss_u), batch_size)
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                         'Cls Loss Labeled {cls_loss_l.val:.4f} ({cls_loss_l.avg:.4f})\t' \
                         'Cls Loss Unlabeled {cls_loss_u.val:.4f} ({cls_loss_u.avg:.4f})\t' \
                         'Continuous Prior KL Loss Labeled {cpk_loss_l.val:.4f} ({cpk_loss_l.avg:.4f})\t' \
                         'Continuous Prior KL Loss Unlabeled {cpk_loss_u.val:.4f} ({cpk_loss_u.avg:.4f})\t'.format(
                epoch, i + 1, len(train_dloader_u), batch_time=batch_time, data_time=data_time,
                cls_loss_l=discrete_posterior_kl_losses_l, cls_loss_u=discrete_posterior_kl_losses_u,
                cpk_loss_l=continuous_prior_kl_losses_l, cpk_loss_u=continuous_prior_kl_losses_u)
            print(train_text)
    # record unlabeled part loss
    writer.add_scalar(tag="Train/ELBO_U",
                      scalar_value=elbo_losses_u.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Reconstrut_U",
                      scalar_value=reconstruct_losses_u.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Continuous_Prior_KL_U",
                      scalar_value=continuous_prior_kl_losses_u.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Continuous_Posterior_KL_U",
                      scalar_value=continuous_posterior_kl_losses_u.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Discrete_Prior_KL_U",
                      scalar_value=discrete_prior_kl_losses_u.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Discrete_Posterior_KL_U",
                      scalar_value=discrete_posterior_kl_losses_u.avg,
                      global_step=epoch + 1)
    if args.br:
        writer.add_scalar(tag="Train/MSE_U",
                          scalar_value=mse_losses_u.avg,
                          global_step=epoch + 1)
    # record labeled part loss
    writer.add_scalar(tag="Train/ELBO_L",
                      scalar_value=elbo_losses_l.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Reconstruct_L",
                      scalar_value=reconstruct_losses_l.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Continuous_Prior_KL_L",
                      scalar_value=continuous_prior_kl_losses_l.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Continuous_Posterior_KL_L",
                      scalar_value=continuous_posterior_kl_losses_l.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Discrete_Prior_KL_L",
                      scalar_value=discrete_prior_kl_losses_l.avg,
                      global_step=epoch + 1)
    writer.add_scalar(tag="Train/Discrete_Posterior_KL_L",
                      scalar_value=discrete_posterior_kl_losses_l.avg,
                      global_step=epoch + 1)
    if args.br:
        writer.add_scalar(tag="Train/MSE_L",
                          scalar_value=mse_losses_l.avg,
                          global_step=epoch + 1)
    # after several epochs of training, log the raw and reconstructed images to TensorBoard; we just use 4 images
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image_u[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(
                reconstruction_u[:4, ...]),
                                                nrow=2)
        writer.add_image(tag="Train/Raw_Image",
                         img_tensor=image,
                         global_step=epoch + 1)
        writer.add_image(tag="Train/Reconstruct_Image",
                         img_tensor=reconstruct_image,
                         global_step=epoch + 1)
    return discrete_posterior_kl_losses_u.avg, discrete_posterior_kl_losses_l.avg
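
Note: the AverageMeter used throughout these examples is not defined on this
page. A minimal sketch following the common PyTorch ImageNet-example
convention (an assumption; the actual class may differ) would be:

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.avg = 0.0   # running average
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight (e.g. number of samples)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count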
Example #14
def valid(valid_dloader, model, criterion, epoch, writer):
    """
    Here valid dataloader we may need to organize it with each study
    """
    model.eval()
    # collect per-study scores and labels for each finding
    atelectasis_score_dict = defaultdict(list)
    atelectasis_label_dict = defaultdict(list)
    cardiomegaly_score_dict = defaultdict(list)
    cardiomegaly_label_dict = defaultdict(list)
    consolidation_score_dict = defaultdict(list)
    consolidation_label_dict = defaultdict(list)
    edema_score_dict = defaultdict(list)
    edema_label_dict = defaultdict(list)
    pleural_effusion_score_dict = defaultdict(list)
    pleural_effusion_label_dict = defaultdict(list)
    # track the average validation loss
    losses = AverageMeter()
    for idx, (image, index, image_name, label_weight,
              label) in enumerate(valid_dloader):
        image = image.float().cuda()
        label = label.long().cuda()
        label_weight = label_weight.cuda()
        with torch.no_grad():
            prediction_list = model(image)
            loss = criterion(prediction_list, label_weight, label)
            losses.update(float(loss.item()), image.size(0))
        for i, img_name in enumerate(image_name):
            study_name = re.match(r"(.*)patient(.*)\|(.*)", img_name).group(2)
            for j, prediction in enumerate(prediction_list):
                score = prediction[i, 1].item()
                item_label = label[i, j].item()
                if j == 0:
                    atelectasis_label_dict[study_name].append(item_label)
                    atelectasis_score_dict[study_name].append(score)
                elif j == 1:
                    cardiomegaly_label_dict[study_name].append(item_label)
                    cardiomegaly_score_dict[study_name].append(score)
                elif j == 2:
                    consolidation_label_dict[study_name].append(item_label)
                    consolidation_score_dict[study_name].append(score)
                elif j == 3:
                    edema_label_dict[study_name].append(item_label)
                    edema_score_dict[study_name].append(score)
                else:
                    pleural_effusion_label_dict[study_name].append(item_label)
                    pleural_effusion_score_dict[study_name].append(score)
    writer.add_scalar(tag="Valid/cls_loss",
                      scalar_value=losses.avg,
                      global_step=epoch + 1)
    # Calculate ROC AUC.
    # Use the max method to aggregate the scores and labels of each study.
    atelectasis_score, atelectasis_label = get_score_label_array_from_dict(
        atelectasis_score_dict, atelectasis_label_dict)
    cardiomegaly_score, cardiomegaly_label = get_score_label_array_from_dict(
        cardiomegaly_score_dict, cardiomegaly_label_dict)

    consolidation_score, consolidation_label = get_score_label_array_from_dict(
        consolidation_score_dict, consolidation_label_dict)
    edema_score, edema_label = get_score_label_array_from_dict(
        edema_score_dict, edema_label_dict)
    pleural_effusion_score, pleural_effusion_label = get_score_label_array_from_dict(
        pleural_effusion_score_dict, pleural_effusion_label_dict)
    atelectasis_auc = roc_auc_score(atelectasis_label, atelectasis_score)
    cardiomegaly_auc = roc_auc_score(cardiomegaly_label, cardiomegaly_score)
    consolidation_auc = roc_auc_score(consolidation_label, consolidation_score)
    edema_auc = roc_auc_score(edema_label, edema_score)
    pleural_effusion_auc = roc_auc_score(pleural_effusion_label,
                                         pleural_effusion_score)
    writer.add_scalar(tag="Valid/Atelectasis_AUC",
                      scalar_value=atelectasis_auc,
                      global_step=epoch)
    writer.add_scalar(tag="Valid/Cardiomegaly_AUC",
                      scalar_value=cardiomegaly_auc,
                      global_step=epoch)
    writer.add_scalar(tag="Valid/Consolidation_AUC",
                      scalar_value=consolidation_auc,
                      global_step=epoch)
    writer.add_scalar(tag="Valid/Edema_AUC",
                      scalar_value=edema_auc,
                      global_step=epoch)
    writer.add_scalar(tag="Valid/Pleural_Effusion_AUC",
                      scalar_value=pleural_effusion_auc,
                      global_step=epoch)
    return [
        atelectasis_auc, cardiomegaly_auc, consolidation_auc, edema_auc,
        pleural_effusion_auc
    ]
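
The helper get_score_label_array_from_dict is not shown on this page. A
minimal sketch consistent with the "max method" comment above (hypothetical;
the real implementation may differ) could be:

import numpy as np

def get_score_label_array_from_dict(score_dict, label_dict):
    # For each study, keep the maximum score and maximum label across its
    # images ("max method"), returning flat arrays for roc_auc_score.
    scores, labels = [], []
    for study_name in score_dict:
        scores.append(max(score_dict[study_name]))
        labels.append(max(label_dict[study_name]))
    return np.array(scores), np.array(labels)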
Example #15
def train(train_dloader, model, lemniscate, criterion, optimizer, epoch,
          writer, dataset):
    # record the time for loading data and running backward for each batch,
    # and also record the loss value
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    for i, (image, label, *_) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        label = label.cuda()
        image = image.float().cuda()
        feature = model(image)
        output = lemniscate(feature, label)
        loss = criterion(output, label) / args.iter_size
        loss.backward()
        losses.update(float(loss) * args.iter_size, image.size(0))
        if (i + 1) % args.iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                epoch, i, len(train_dloader), batch_time=batch_time,
                data_time=data_time, loss=losses)
            print(train_text)
    writer.add_scalar(tag="{}_train/loss".format(dataset),
                      scalar_value=losses.avg,
                      global_step=epoch)
    return losses.avg
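
Example #15 uses gradient accumulation: each mini-batch loss is divided by
args.iter_size and optimizer.step() fires only every iter_size batches, which
emulates a batch size iter_size times larger. A self-contained sketch of the
same pattern (with hypothetical model/data names):

def train_with_accumulation(model, dloader, criterion, optimizer, iter_size=4):
    # Scale each mini-batch loss by 1/iter_size so the accumulated gradients
    # match those of one large batch, then step every iter_size batches.
    model.train()
    optimizer.zero_grad()
    for i, (x, y) in enumerate(dloader):
        loss = criterion(model(x), y) / iter_size
        loss.backward()
        if (i + 1) % iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()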