def test(test_dloader, model, elbo_criterion, epoch, writer):
    """Evaluate the VAE on ``test_dloader`` and log the ELBO terms.

    The KL term is ``|KL_cont - cmi| + |KL_disc - dmi|``, where the targets
    ``cmi``/``dmi`` grow linearly from 0 to ``args.cmi``/``args.dmi`` over the
    first ``args.adjust_lr[0]`` epochs and then stay constant.

    Args:
        test_dloader: iterable of batches whose first element is the image
            tensor; any further elements are ignored.
        model: network returning ``(reconstruction, norm_mean, norm_log_sigma,
            disc_log_alpha, *extra)``.
        elbo_criterion: callable returning ``(reconstruct_loss,
            continuous_kl_loss, disc_kl_loss)``.
        epoch: zero-based epoch index; every scalar/image is logged at
            ``epoch + 1``.
        writer: object exposing ``add_scalar`` / ``add_image``.

    Returns:
        ``(elbo_avg, reconstruct_avg, kl_avg)``, sample-weighted averages.
    """
    reconstruct_losses = AverageMeter()
    kl_losses = AverageMeter()
    elbo_losses = AverageMeter()
    mse_losses = AverageMeter()
    model.eval()
    # mutual-information (KL capacity) targets, linearly warmed up
    warmup = min(1.0, float(epoch / args.adjust_lr[0]))
    cmi = args.cmi * warmup
    dmi = args.dmi * warmup
    image = reconstruction = None
    with torch.no_grad():
        for image, *_ in test_dloader:
            image = image.float().cuda()
            batch_size = image.size(0)
            reconstruction, norm_mean, norm_log_sigma, disc_log_alpha, *_ = model(image)
            reconstruct_loss, continuous_kl_loss, disc_kl_loss = elbo_criterion(
                image, reconstruction, norm_mean, norm_log_sigma, disc_log_alpha)
            kl_loss = torch.abs(continuous_kl_loss - cmi) + torch.abs(disc_kl_loss - dmi)
            elbo_loss = reconstruct_loss + kl_loss
            if args.br:
                # With a BCE reconstruction loss, additionally track the
                # sum-of-squares error of the sigmoid-decoded image, scaled by
                # 1 / (2 * batch_size * x_sigma^2).
                mse_loss = F.mse_loss(torch.sigmoid(reconstruction.detach()), image.detach(),
                                      reduction="sum") / (2 * batch_size * (args.x_sigma ** 2))
                mse_losses.update(float(mse_loss), batch_size)
            elbo_losses.update(float(elbo_loss), batch_size)
            reconstruct_losses.update(float(reconstruct_loss), batch_size)
            kl_losses.update(float(kl_loss), batch_size)

    step = epoch + 1
    writer.add_scalar(tag="Test/ELBO", scalar_value=elbo_losses.avg, global_step=step)
    writer.add_scalar(tag="Test/Reconstruct", scalar_value=reconstruct_losses.avg, global_step=step)
    writer.add_scalar(tag="Test/KL", scalar_value=kl_losses.avg, global_step=step)
    if args.br:
        # Logged at `epoch + 1` like every other scalar (it used `epoch`,
        # putting the MSE curve one step behind the others).
        writer.add_scalar(tag="Test/MSE", scalar_value=mse_losses.avg, global_step=step)
    # Log the first 4 images of the last batch; skipped if the loader was empty.
    if epoch % args.reconstruct_freq == 0 and reconstruction is not None:
        raw_grid = utils.make_grid(image[:4, ...], nrow=2)
        reconstruct_grid = utils.make_grid(torch.sigmoid(reconstruction[:4, ...]), nrow=2)
        writer.add_image(tag="Test/Raw_Image", img_tensor=raw_grid, global_step=step)
        writer.add_image(tag="Test/Reconstruct_Image", img_tensor=reconstruct_grid, global_step=step)
    return elbo_losses.avg, reconstruct_losses.avg, kl_losses.avg
def train(train_dloader, model, criterion, optimizer, epoch, writer, dataset):
    """Run one training epoch of the VAE and log its average losses.

    The per-batch loss is ``reconstruct_loss + kl_loss``, both returned by
    ``criterion(image, reconstruction, model.z_mean, model.z_log_sigma,
    model.z_sigma)``; the model's latent statistics are read as attributes
    right after its forward pass. Gradients are clipped by value when
    ``args.gd_clip_flag`` is set. Averages are logged under
    ``<dataset>_train/...`` at ``global_step=epoch``.

    Returns:
        ``(loss_avg, reconstruct_avg, kl_avg)`` over the epoch.
    """
    progress_template = ('Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Total Loss {total_loss.val:.4f} ({total_loss.avg:.4f})\t'
                         'Reconstruct Loss {reconstruct_loss.val:.4f} ({reconstruct_loss.avg:.4f})\t'
                         'KL Loss {kl_loss.val:.4f} ({kl_loss.avg:.4f})\t')
    batch_timer, load_timer = AverageMeter(), AverageMeter()
    recon_meter, kl_meter, total_meter = AverageMeter(), AverageMeter(), AverageMeter()

    model.train()
    tick = time.time()
    optimizer.zero_grad()
    for step, (batch, *_) in enumerate(train_dloader):
        load_timer.update(time.time() - tick)
        batch = batch.float().cuda()
        count = batch.size(0)

        decoded = model(batch)
        recon_term, kl_term = criterion(batch, decoded, model.z_mean, model.z_log_sigma, model.z_sigma)
        total = recon_term + kl_term
        total.backward()
        if args.gd_clip_flag:
            torch.nn.utils.clip_grad_value_(model.parameters(), args.gd_clip_value)

        for meter, value in ((total_meter, total), (recon_meter, recon_term), (kl_meter, kl_term)):
            meter.update(float(value), count)

        optimizer.step()
        optimizer.zero_grad()
        batch_timer.update(time.time() - tick)
        tick = time.time()

        if not step % args.print_freq:
            print(progress_template.format(epoch, step, len(train_dloader),
                                           batch_time=batch_timer, data_time=load_timer,
                                           total_loss=total_meter, reconstruct_loss=recon_meter,
                                           kl_loss=kl_meter))

    for suffix, meter in (("loss", total_meter), ("reconstruct_loss", recon_meter), ("kl_loss", kl_meter)):
        writer.add_scalar(tag="{}_train/{}".format(dataset, suffix), scalar_value=meter.avg, global_step=epoch)
    return total_meter.avg, recon_meter.avg, kl_meter.avg
def test(test_dloader, model, elbo_criterion, epoch, writer, discrete_latent_dim):
    """Evaluate the VAE's ELBO terms and the accuracy of its discrete latent.

    The discrete posterior ``exp(disc_log_alpha)`` is used as the class score.
    Top-1 accuracy is always logged; top-5 accuracy only when
    ``args.dataset == "Cifar100"``. A reconstruction grid of the last batch is
    logged every ``args.reconstruct_freq`` epochs. Everything is logged at
    ``epoch + 1``.

    Args:
        test_dloader: iterable of ``(image, label)`` batches.
        model: network returning ``(reconstruction, norm_mean, norm_log_sigma,
            disc_log_alpha, *extra)``.
        elbo_criterion: callable returning ``(reconstruct_loss,
            continuous_kl_loss, discrete_kl_loss)``.
        epoch: zero-based epoch index.
        writer: object exposing ``add_scalar`` / ``add_image``.
        discrete_latent_dim: number of discrete latent categories. Labels are
            now compared directly, so this is no longer needed; it is kept for
            interface compatibility.

    Returns:
        ``(top1_accuracy, top5_accuracy)`` over the loader. When there are
        fewer than 5 categories, the second value is top-``K`` accuracy over
        all ``K`` categories (``torch.topk`` with ``k=5`` used to raise).
    """
    continuous_kl_losses = AverageMeter()
    discrete_kl_losses = AverageMeter()
    mse_losses = AverageMeter()
    elbo_losses = AverageMeter()
    model.eval()
    all_score = []
    all_label = []
    with torch.no_grad():
        for image, label in test_dloader:
            image = image.float().cuda()
            label = label.long().cuda()
            batch_size = image.size(0)
            reconstruction, norm_mean, norm_log_sigma, disc_log_alpha, *_ = model(image)
            _, continuous_kl_loss, discrete_kl_loss = elbo_criterion(
                image, reconstruction, norm_mean, norm_log_sigma, disc_log_alpha)
            mse_loss = F.mse_loss(torch.sigmoid(reconstruction.detach()), image.detach(),
                                  reduction="sum") / (2 * batch_size * (args.x_sigma ** 2))
            mse_losses.update(float(mse_loss), batch_size)
            # class scores are the discrete posterior probabilities
            all_score.append(torch.exp(disc_log_alpha))
            all_label.append(label)
            continuous_kl_losses.update(float(continuous_kl_loss.item()), batch_size)
            discrete_kl_losses.update(float(discrete_kl_loss.item()), batch_size)
            # NOTE(review): the 0.01 KL weight is hard-coded; confirm it matches training.
            elbo_losses.update(float(mse_loss + 0.01 * (continuous_kl_loss + discrete_kl_loss)), batch_size)

    step = epoch + 1
    writer.add_scalar(tag="Test/KL(q(z|X)||p(z))", scalar_value=continuous_kl_losses.avg, global_step=step)
    writer.add_scalar(tag="Test/KL(q(y|X)||p(y))", scalar_value=discrete_kl_losses.avg, global_step=step)
    writer.add_scalar(tag="Test/log(p(X|z,y))", scalar_value=mse_losses.avg, global_step=step)
    writer.add_scalar(tag="Test/ELBO", scalar_value=elbo_losses.avg, global_step=step)

    scores = torch.cat(all_score, dim=0)
    labels = torch.cat(all_label, dim=0).view(-1, 1)
    # torch.topk requires k <= number of categories.
    _, y_pred = torch.topk(scores, k=min(5, scores.size(1)), dim=1)
    hits = labels == y_pred  # (N, k): True where the label is the j-th best guess
    num_samples = labels.size(0)
    test_top_1_accuracy = float(hits[:, :1].sum().item()) / num_samples
    test_top_5_accuracy = float(hits.sum().item()) / num_samples
    writer.add_scalar(tag="Test/top1 accuracy", scalar_value=test_top_1_accuracy, global_step=step)
    if args.dataset == "Cifar100":
        writer.add_scalar(tag="Test/top 5 accuracy", scalar_value=test_top_5_accuracy, global_step=step)
    if epoch % args.reconstruct_freq == 0:
        raw_grid = utils.make_grid(image[:4, ...], nrow=2)
        reconstruct_grid = utils.make_grid(torch.sigmoid(reconstruction[:4, ...]), nrow=2)
        writer.add_image(tag="Test/Raw_Image", img_tensor=raw_grid, global_step=step)
        writer.add_image(tag="Test/Reconstruct_Image", img_tensor=reconstruct_grid, global_step=step)
    return test_top_1_accuracy, test_top_5_accuracy
def train(train_dloader, model, criterion, optimizer, epoch, writer):
    """Train ``model`` for one epoch with a plain supervised loss.

    Each batch does forward -> ``criterion(logits, label)`` -> backward ->
    ``optimizer.step()``. Progress is printed every ``args.print_freq``
    batches, and the epoch-average loss is logged as ``Train/cls_loss`` at
    ``epoch + 1``.

    Returns:
        Sample-weighted average loss over the epoch.
    """
    template = ('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Cls Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})')
    batch_timer = AverageMeter()
    load_timer = AverageMeter()
    loss_meter = AverageMeter()

    model.train()
    last = time.time()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(train_dloader):
        load_timer.update(time.time() - last)
        inputs = inputs.float().cuda()
        targets = targets.long().cuda()

        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        loss_meter.update(float(batch_loss.item()), inputs.size(0))
        optimizer.step()
        optimizer.zero_grad()

        batch_timer.update(time.time() - last)
        last = time.time()
        if not step % args.print_freq:
            print(template.format(epoch, step + 1, len(train_dloader),
                                  batch_time=batch_timer, data_time=load_timer, cls_loss=loss_meter))

    writer.add_scalar(tag="Train/cls_loss", scalar_value=loss_meter.avg, global_step=epoch + 1)
    return loss_meter.avg
def train(train_dloader, model, lemniscate, criterion, optimizer, epoch, writer, dataset):
    """Train for one epoch with gradient accumulation over ``args.iter_size`` batches.

    For each batch, ``model(image)`` produces features that go through
    ``lemniscate(feature, label)`` and then ``criterion``. The loss is divided
    by ``args.iter_size``, and the optimizer steps once every
    ``args.iter_size`` batches. Gradients left over from a final, incomplete
    accumulation group are applied at the end of the epoch. Previously they
    were silently discarded by the next epoch's ``zero_grad``.

    Returns:
        Sample-weighted average of the un-scaled loss over the epoch; it is
        also logged as ``<dataset>_train/loss`` at ``global_step=epoch``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    pending = 0  # batches back-propagated since the last optimizer step
    for i, (image, label, *_) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        label = label.cuda()
        image = image.float().cuda()
        feature = model(image)
        output = lemniscate(feature, label)
        # Divide so that summing gradients over `iter_size` batches averages them.
        loss = criterion(output, label) / args.iter_size
        loss.backward()
        pending += 1
        # Record the un-scaled per-batch loss.
        losses.update(float(loss) * args.iter_size, image.size(0))
        if (i + 1) % args.iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = ('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t').format(
                epoch, i, len(train_dloader), batch_time=batch_time,
                data_time=data_time, loss=losses)
            print(train_text)
    if pending:
        # Apply the trailing partial accumulation group instead of dropping it.
        optimizer.step()
        optimizer.zero_grad()
    writer.add_scalar(tag="{}_train/loss".format(dataset), scalar_value=losses.avg, global_step=epoch)
    return losses.avg
def test(valid_dloader, test_dloader, model, criterion, epoch, writer, num_classes, zca_mean=None,
         zca_components=None):
    """Evaluate ``model`` on the validation and test splits and log metrics.

    Each split is processed in turn, first ``Valid`` and then ``Test``. For
    each one, ``<Split>/cls_loss`` and ``<Split>/top 1 accuracy`` are logged,
    plus ``<Split>/top 5 accuracy`` when ``args.dataset == "Cifar100"``, all
    at ``epoch + 1``. Images are transformed with ``apply_zca`` when
    ``args.zca`` is set.

    Args:
        valid_dloader: iterable of ``(image, label)`` validation batches.
        test_dloader: iterable of ``(image, label)`` test batches.
        model: classifier returning logits.
        criterion: loss callable ``criterion(logits, label)``.
        epoch: zero-based epoch index.
        writer: object exposing ``add_scalar``.
        num_classes: number of classes. Labels are now compared directly, so
            this is no longer needed; it is kept for interface compatibility.
        zca_mean, zca_components: forwarded to ``apply_zca`` when ``args.zca``.

    Returns:
        Average classification loss on the test split.
    """
    model.eval()
    step = epoch + 1

    # Evaluate one split, log its metrics under `prefix`, return its mean loss.
    def evaluate_split(dloader, prefix):
        losses = AverageMeter()
        all_score = []
        all_label = []
        with torch.no_grad():
            for image, label in dloader:
                image = image.float().cuda()
                if args.zca:
                    image = apply_zca(image, zca_mean, zca_components)
                label = label.long().cuda()
                cls_result = model(image)
                loss = criterion(cls_result, label)
                losses.update(float(loss.item()), image.size(0))
                all_score.append(torch.softmax(cls_result, dim=1))
                all_label.append(label)
        writer.add_scalar(tag="{}/cls_loss".format(prefix), scalar_value=losses.avg, global_step=step)

        scores = torch.cat(all_score, dim=0)
        labels = torch.cat(all_label, dim=0).view(-1, 1)
        # torch.topk requires k <= number of classes.
        _, y_pred = torch.topk(scores, k=min(5, scores.size(1)), dim=1)
        hits = labels == y_pred  # (N, k)
        top_1_accuracy = float(hits[:, :1].sum().item()) / labels.size(0)
        top_5_accuracy = float(hits.sum().item()) / labels.size(0)
        writer.add_scalar(tag="{}/top 1 accuracy".format(prefix), scalar_value=top_1_accuracy,
                          global_step=step)
        if args.dataset == "Cifar100":
            writer.add_scalar(tag="{}/top 5 accuracy".format(prefix), scalar_value=top_5_accuracy,
                              global_step=step)
        return losses.avg

    evaluate_split(valid_dloader, "Valid")
    return evaluate_split(test_dloader, "Test")
def train(train_dloader_l, train_dloader_u, model, criterion_l, criterion_u, optimizer, epoch, writer,
          alpha, zca_mean=None, zca_components=None):
    """Run one epoch of semi-supervised training on paired labeled/unlabeled batches.

    The epoch length is set by ``train_dloader_u``; the labeled loader is
    re-iterated as often as needed. Each step pairs one labeled and one
    unlabeled batch, trimmed to a common size, and uses one of three losses:

    * ``args.mixup``: input mixup on the labeled batch (``args.mas``). The
      unlabeled batch is mixed (``args.mau``) with targets taken from the
      model's own no-grad softmax predictions, and trained with a soft-target
      cross-entropy. Total loss is ``supervised + alpha * unsupervised``.
    * ``args.manifold_mixup``: the same idea, but mixup happens inside the
      model at layers ``args.mll``. The unsupervised term is
      ``criterion_u(softmax(pred), mixed_targets)``, and the total is
      ``supervised + 10 * alpha * unsupervised``.
    * otherwise: plain supervised ``criterion_l``; the unsupervised term is 0.

    Returns:
        Sample-weighted average total loss over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_supervised = AverageMeter()
    losses_unsupervised = AverageMeter()

    # Re-iterate `loader` indefinitely. Unlike itertools.cycle, which caches
    # every batch of the first pass and replays it, this restarts the loader
    # on each pass so shuffling/augmentation is redone.
    def repeat_loader(loader):
        while True:
            yielded = False
            for batch in loader:
                yielded = True
                yield batch
            if not yielded:  # empty loader: stop instead of spinning forever
                return

    model.train()
    end = time.time()
    optimizer.zero_grad()
    epoch_length = len(train_dloader_u)
    for i, ((image_l, label_l), (image_u, _)) in enumerate(
            zip(repeat_loader(train_dloader_l), train_dloader_u)):
        # Trim both batches to a common size (last batches may differ).
        bt_size = min(image_l.size(0), image_u.size(0))
        image_l, label_l, image_u = image_l[:bt_size], label_l[:bt_size], image_u[:bt_size]
        data_time.update(time.time() - end)
        image_l = image_l.float().cuda()
        image_u = image_u.float().cuda()
        label_l = label_l.long().cuda()
        if args.zca:
            image_l = apply_zca(image_l, zca_mean, zca_components)
            image_u = apply_zca(image_u, zca_mean, zca_components)

        if args.mixup:
            mixed_image_l, label_a, label_b, lam = mixup_data(image_l, label_l, args.mas)
            cls_result_l = model(mixed_image_l)
            loss_supervised = mixup_criterion(criterion_l, cls_result_l, label_a, label_b, lam)
            # Pseudo-labels for the unlabeled batch carry no gradient.
            with torch.no_grad():
                label_u_approx = torch.softmax(model(image_u), dim=1)
            mixed_image_u, label_a_approx, label_b_approx, lam = mixup_data(
                image_u, label_u_approx, args.mau)
            cls_result_u = torch.log_softmax(model(mixed_image_u), dim=1)
            label_u_approx_mixup = lam * label_a_approx + (1 - lam) * label_b_approx
            # soft-target cross-entropy
            loss_unsupervised = -1 * torch.mean(torch.sum(label_u_approx_mixup * cls_result_u, dim=1))
            loss = loss_supervised + alpha * loss_unsupervised
        elif args.manifold_mixup:
            cls_result_l, label_a, label_b, lam = model(
                image_l, mixup_alpha=args.mas, label=label_l, manifold_mixup=True,
                mixup_layer_list=args.mll)
            loss_supervised = mixup_criterion(criterion_l, cls_result_l, label_a, label_b, lam)
            # Pseudo-labels for the unlabeled batch carry no gradient.
            with torch.no_grad():
                label_u_approx = torch.softmax(model(image_u), dim=1)
            # NOTE(review): the unlabeled manifold mixup uses `args.mas` (the
            # labeled alpha), while input mixup above uses `args.mau`. The
            # unsupervised weight is also 10 * alpha here vs alpha above.
            # Confirm both differences are intended.
            cls_result_u, label_a_approx, label_b_approx, lam = model(
                image_u, mixup_alpha=args.mas, label=label_u_approx, manifold_mixup=True,
                mixup_layer_list=args.mll)
            cls_result_u = torch.softmax(cls_result_u, dim=1)
            label_u_approx_mixup = lam * label_a_approx + (1 - lam) * label_b_approx
            loss_unsupervised = criterion_u(cls_result_u, label_u_approx_mixup)
            loss = loss_supervised + 10 * alpha * loss_unsupervised
        else:
            cls_result_l = model(image_l)
            loss = criterion_l(cls_result_l, label_l)
            loss_supervised = loss.detach()
            loss_unsupervised = torch.zeros(loss.size())

        loss.backward()
        losses.update(float(loss.item()), bt_size)
        losses_supervised.update(float(loss_supervised.item()), bt_size)
        losses_unsupervised.update(float(loss_unsupervised.item()), bt_size)
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = ('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Cls Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})\t'
                          'Regularization Loss {reg_loss.val:.4f} ({reg_loss.avg:.4f})\t'
                          'Total Loss {total_loss.val:.4f} ({total_loss.avg:.4f})\t').format(
                epoch, i + 1, epoch_length, batch_time=batch_time, data_time=data_time,
                cls_loss=losses_supervised, reg_loss=losses_unsupervised, total_loss=losses)
            print(train_text)
    writer.add_scalar(tag="Train/cls_loss", scalar_value=losses_supervised.avg, global_step=epoch + 1)
    writer.add_scalar(tag="Train/reg_loss", scalar_value=losses_unsupervised.avg, global_step=epoch + 1)
    writer.add_scalar(tag="Train/total_loss", scalar_value=losses.avg, global_step=epoch + 1)
    return losses.avg
def test(target_domain, source_domains, test_dloader_list, model_list, classifier_list, epoch, writer,
         num_classes=126, top_5_accuracy=True):
    """Evaluate the target domain and every source domain; log loss and accuracy.

    Index 0 of ``test_dloader_list``, ``model_list`` and ``classifier_list``
    belongs to ``target_domain``, and index ``s + 1`` belongs to
    ``source_domains[s]``. For each domain the cross-entropy loss and top-1
    accuracy are logged, plus top-5 accuracy when ``top_5_accuracy`` is True.
    Tags are ``Test/target_domain_<name>_...`` or
    ``Test/source_domain_<name>_...``, logged at ``epoch + 1``. Only the
    target-domain accuracy is printed.

    Args:
        target_domain: name of the target domain (used in tags/prints).
        source_domains: sequence of source-domain names.
        test_dloader_list: loaders of ``(image, label)`` batches, target first.
        model_list: feature extractors, target first.
        classifier_list: classifier heads, target first.
        epoch: zero-based epoch index.
        writer: object exposing ``add_scalar``.
        num_classes: number of classes. Labels are now compared directly, so
            this is no longer needed; it is kept for interface compatibility.
        top_5_accuracy: whether to compute and log top-5 accuracy (capped at
            the number of classes).
    """
    task_criterion = nn.CrossEntropyLoss().cuda()
    for model in model_list:
        model.eval()
    for classifier in classifier_list:
        classifier.eval()
    step = epoch + 1

    # Evaluate one domain, log its metrics, return (top1, top5 or None).
    def evaluate_domain(role, domain, dloader, feature_extractor, head):
        loss_meter = AverageMeter()
        scores = []
        labels = []
        with torch.no_grad():
            for image, label in dloader:
                image = image.cuda()
                label = label.long().cuda()
                output = head(feature_extractor(image))
                loss_meter.update(float(task_criterion(output, label).item()), image.size(0))
                scores.append(torch.softmax(output, dim=1))
                labels.append(label)
        writer.add_scalar(tag="Test/{}_domain_{}_loss".format(role, domain),
                          scalar_value=loss_meter.avg, global_step=step)
        scores = torch.cat(scores, dim=0)
        labels = torch.cat(labels, dim=0).view(-1, 1)
        # torch.topk requires k <= number of classes.
        _, y_pred = torch.topk(scores, k=min(5, scores.size(1)), dim=1)
        hits = labels == y_pred  # (N, k)
        top_1 = float(hits[:, :1].sum().item()) / labels.size(0)
        writer.add_scalar(tag="Test/{}_domain_{}_accuracy_top1".format(role, domain),
                          scalar_value=top_1, global_step=step)
        top_5 = None
        if top_5_accuracy:
            top_5 = float(hits.sum().item()) / labels.size(0)
            writer.add_scalar(tag="Test/{}_domain_{}_accuracy_top5".format(role, domain),
                              scalar_value=top_5, global_step=step)
        return top_1, top_5

    top_1_accuracy_t, top_5_accuracy_t = evaluate_domain(
        "target", target_domain, test_dloader_list[0], model_list[0], classifier_list[0])
    if top_5_accuracy:
        print("Target Domain {} Accuracy Top1 :{:.3f} Top5:{:.3f}".format(
            target_domain, top_1_accuracy_t, top_5_accuracy_t))
    else:
        print("Target Domain {} Accuracy {:.3f}".format(target_domain, top_1_accuracy_t))

    for s_i, domain_s in enumerate(source_domains):
        evaluate_domain("source", domain_s, test_dloader_list[s_i + 1], model_list[s_i + 1],
                        classifier_list[s_i + 1])
def train(train_dloader_u, train_dloader_l, model, elbo_criterion, cls_criterion, optimizer, epoch, writer,
          discrete_latent_dim):
    """Run one epoch of semi-supervised VAE training.

    Note the argument order: the unlabeled loader comes first. The epoch
    length is set by ``train_dloader_u``; the labeled loader is re-iterated as
    needed. Each step uses:

    * labeled batch: ``ew * (reconstruct + prior_kl)`` plus
      ``cls_criterion(disc_log_alpha, one_hot(label))``, with the model called
      with ``disc_label=label``;
    * unlabeled batch: ``ew * (reconstruct + prior_kl)``.

    Here ``prior_kl = kl_beta_c * |KL_cont - cmi| + kl_beta_d * |KL_disc - dmi|``,
    and all weights and targets come from ``alpha_schedule`` for this epoch.
    Gradients from both parts are accumulated before a single
    ``optimizer.step()``. For monitoring only, KL(q(y|X) || label-smoothed
    one-hot of the unlabeled batch's labels) is averaged and logged as
    ``Train/KL_Inference``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    kl_inferences = AverageMeter()

    # Re-iterate `loader` indefinitely. Unlike itertools.cycle, which caches
    # every batch of the first pass and replays it, this restarts the loader
    # on each pass so shuffling/augmentation is redone.
    def repeat_loader(loader):
        while True:
            yielded = False
            for batch in loader:
                yielded = True
                yield batch
            if not yielded:  # empty loader: stop instead of spinning forever
                return

    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual-information (KL capacity) targets
    cmi = alpha_schedule(epoch, args.akb, args.cmi)
    dmi = alpha_schedule(epoch, args.akb, args.dmi)
    # weight of the ELBO terms
    ew = alpha_schedule(epoch, args.aew, args.ewm)
    # weights of the continuous / discrete prior-KL terms
    kl_beta_c = alpha_schedule(epoch, args.akb, args.kbmc)
    kl_beta_d = alpha_schedule(epoch, args.akb, args.kbmd)
    image_u = reconstruction_u = None
    for i, ((image_l, label_l), (image_u, label_u)) in enumerate(
            zip(repeat_loader(train_dloader_l), train_dloader_u)):
        # Trim both batches to a common size (last batches may differ).
        batch_size = min(image_l.size(0), image_u.size(0))
        image_l, label_l = image_l[:batch_size], label_l[:batch_size]
        image_u, label_u = image_u[:batch_size], label_u[:batch_size]
        data_time.update(time.time() - end)

        # labeled part: ELBO + classification of the discrete latent
        image_l = image_l.float().cuda()
        label_l = label_l.long().cuda()
        label_onehot_l = torch.zeros(batch_size, discrete_latent_dim).cuda().scatter_(
            1, label_l.view(-1, 1), 1)
        reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l = model(
            image_l, disc_label=label_l)
        reconstruct_loss_l, continuous_prior_kl_loss_l, disc_prior_kl_loss_l = elbo_criterion(
            image_l, reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l)
        prior_kl_loss_l = (kl_beta_c * torch.abs(continuous_prior_kl_loss_l - cmi)
                           + kl_beta_d * torch.abs(disc_prior_kl_loss_l - dmi))
        elbo_loss_l = reconstruct_loss_l + prior_kl_loss_l
        disc_posterior_kl_loss_l = cls_criterion(disc_log_alpha_l, label_onehot_l)
        loss_supervised = ew * elbo_loss_l + disc_posterior_kl_loss_l
        loss_supervised.backward()

        # unlabeled part: ELBO only
        image_u = image_u.float().cuda()
        label_u = label_u.long().cuda()
        reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u = model(image_u)
        # Monitoring only: KL(q(y|X) || smoothed one-hot of label_u). The true
        # class gets 1 - 0.001 and the other classes share 0.001 equally.
        with torch.no_grad():
            label_smooth_u = torch.zeros(batch_size, discrete_latent_dim).cuda().scatter_(
                1, label_u.view(-1, 1), 1 - 0.001 - 0.001 / (discrete_latent_dim - 1))
            label_smooth_u = label_smooth_u + torch.ones(label_smooth_u.size()).cuda() * 0.001 / (
                discrete_latent_dim - 1)
            disc_alpha_u = torch.exp(disc_log_alpha_u)
            inference_kl = disc_alpha_u * disc_log_alpha_u - disc_alpha_u * torch.log(label_smooth_u)
            kl_inferences.update(float(torch.sum(inference_kl) / batch_size), batch_size)
        reconstruct_loss_u, continuous_prior_kl_loss_u, disc_prior_kl_loss_u = elbo_criterion(
            image_u, reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u)
        prior_kl_loss_u = (kl_beta_c * torch.abs(continuous_prior_kl_loss_u - cmi)
                           + kl_beta_d * torch.abs(disc_prior_kl_loss_u - dmi))
        elbo_loss_u = reconstruct_loss_u + prior_kl_loss_u
        loss_unsupervised = ew * elbo_loss_u
        loss_unsupervised.backward()

        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = ('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t').format(
                epoch, i + 1, len(train_dloader_u), batch_time=batch_time, data_time=data_time)
            print(train_text)
    writer.add_scalar(tag="Train/KL_Inference", scalar_value=kl_inferences.avg, global_step=epoch + 1)
    # Every `reconstruct_freq` epochs, log the first 4 unlabeled images of the
    # last batch and their reconstructions (skipped if no batch was seen).
    if epoch % args.reconstruct_freq == 0 and reconstruction_u is not None:
        with torch.no_grad():
            raw_grid = utils.make_grid(image_u[:4, ...], nrow=2)
            reconstruct_grid = utils.make_grid(torch.sigmoid(reconstruction_u[:4, ...]), nrow=2)
        writer.add_image(tag="Train/Raw_Image", img_tensor=raw_grid, global_step=epoch + 1)
        writer.add_image(tag="Train/Reconstruct_Image", img_tensor=reconstruct_grid, global_step=epoch + 1)
def test(valid_dloader, test_dloader, model, criterion, epoch, writer, num_classes):
    """Evaluate ``model`` on the validation and test splits and log metrics.

    Each split is processed in turn, first ``Valid`` and then ``Test``. For
    each one, ``<Split>/cls_loss`` and ``<Split>/top 1 accuracy`` are logged,
    plus ``<Split>/top 5 accuracy`` when ``args.dataset == "Cifar100"``, all
    at ``epoch + 1``.

    Args:
        valid_dloader: iterable of ``(image, label)`` validation batches.
        test_dloader: iterable of ``(image, label)`` test batches.
        model: classifier returning logits.
        criterion: loss callable ``criterion(logits, label)``.
        epoch: zero-based epoch index.
        writer: object exposing ``add_scalar``.
        num_classes: number of classes. Labels are now compared directly, so
            this is no longer needed; it is kept for interface compatibility.

    Returns:
        Top-1 accuracy on the test split.
    """
    model.eval()
    step = epoch + 1

    # Evaluate one split, log its metrics under `prefix`, return its top-1 accuracy.
    def evaluate_split(dloader, prefix):
        losses = AverageMeter()
        all_score = []
        all_label = []
        with torch.no_grad():
            for image, label in dloader:
                image = image.float().cuda()
                label = label.long().cuda()
                cls_result = model(image)
                loss = criterion(cls_result, label)
                losses.update(float(loss.item()), image.size(0))
                all_score.append(torch.softmax(cls_result, dim=1))
                all_label.append(label)
        writer.add_scalar(tag="{}/cls_loss".format(prefix), scalar_value=losses.avg, global_step=step)

        scores = torch.cat(all_score, dim=0)
        labels = torch.cat(all_label, dim=0).view(-1, 1)
        # torch.topk requires k <= number of classes.
        _, y_pred = torch.topk(scores, k=min(5, scores.size(1)), dim=1)
        hits = labels == y_pred  # (N, k)
        top_1_accuracy = float(hits[:, :1].sum().item()) / labels.size(0)
        top_5_accuracy = float(hits.sum().item()) / labels.size(0)
        writer.add_scalar(tag="{}/top 1 accuracy".format(prefix), scalar_value=top_1_accuracy,
                          global_step=step)
        if args.dataset == "Cifar100":
            writer.add_scalar(tag="{}/top 5 accuracy".format(prefix), scalar_value=top_5_accuracy,
                              global_step=step)
        return top_1_accuracy

    evaluate_split(valid_dloader, "Valid")
    return evaluate_split(test_dloader, "Test")
def train(train_dloader, model, elbo_criterion, optimizer, epoch, writer, kl_disc_criterion=None,
          kl_norm_criterion=None):
    """Run one epoch of VAE training, optionally with a mixup posterior-consistency loss.

    The per-batch ELBO loss is
    ``reconstruct + kl_beta * (|KL_cont - cmi| + |KL_disc - dmi|)``.

    When ``args.mixup`` is set:

    * ``mixup_vae_data`` builds a mixed batch and mixed posterior targets from
      the current batch, without gradient.
    * The mixed batch is re-encoded and penalised with
      ``posterior_weight * (continuous + discrete)``. The continuous term is
      the per-sample summed squared error of the latent means and of
      ``exp(log_sigma)``; the discrete term is ``kl_disc_criterion``.

    Both losses are back-propagated before one ``optimizer.step()``.

    Args:
        train_dloader: iterable of batches whose first element is the image.
        model: network returning ``(reconstruction, norm_mean, norm_log_sigma,
            disc_log_alpha, *extra)``.
        elbo_criterion: callable returning ``(reconstruct_loss,
            continuous_kl_loss, disc_kl_loss)``.
        optimizer: torch optimizer over the model parameters.
        epoch: zero-based epoch index; scalars/images are logged at ``epoch + 1``.
        writer: object exposing ``add_scalar`` / ``add_image``.
        kl_disc_criterion: required when ``args.mixup`` is set.
        kl_norm_criterion: unused; kept for interface compatibility.

    Returns:
        ``(elbo_avg, reconstruct_avg, kl_avg)`` over the epoch.

    Raises:
        ValueError: if ``args.mixup`` is set but ``kl_disc_criterion`` is None.
    """
    if args.mixup and kl_disc_criterion is None:
        raise ValueError("kl_disc_criterion is required when args.mixup is set")
    batch_time = AverageMeter()
    data_time = AverageMeter()
    reconstruct_losses = AverageMeter()
    kl_losses = AverageMeter()
    elbo_losses = AverageMeter()
    mse_losses = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual-information (KL capacity) targets
    cmi = args.cmi * alpha_schedule(epoch, args.akb, 1, strategy="exp")
    dmi = args.dmi * alpha_schedule(epoch, args.akb, 1, strategy="exp")
    kl_beta = alpha_schedule(epoch, args.akb, args.kbm)
    posterior_weight = alpha_schedule(epoch, args.apw, args.pwm)
    if args.mixup:
        mixup_posterior_kl_losses = AverageMeter()
    print("Begin {} Epoch Training, CMI:{:.4f} DMI:{:.4f} KL-Beta{:.4f}".format(
        epoch + 1, cmi, dmi, kl_beta))
    image = reconstruction = None
    for i, (image, *_) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        image = image.float().cuda()
        batch_size = image.size(0)
        reconstruction, norm_mean, norm_log_sigma, disc_log_alpha, *_ = model(image)
        reconstruct_loss, continuous_kl_loss, disc_kl_loss = elbo_criterion(
            image, reconstruction, norm_mean, norm_log_sigma, disc_log_alpha)
        kl_loss = torch.abs(continuous_kl_loss - cmi) + torch.abs(disc_kl_loss - dmi)
        elbo_loss = reconstruct_loss + kl_beta * kl_loss
        elbo_loss.backward()
        if args.mixup:
            # Mixup targets come from the current posteriors without gradient.
            with torch.no_grad():
                mixed_image, mixed_z_mean, mixed_z_sigma, mixed_disc_alpha, lam = mixup_vae_data(
                    image, norm_mean, norm_log_sigma, disc_log_alpha, alpha=args.ma)
            _, mix_norm_mean, mix_norm_log_sigma, mix_disc_log_alpha, *_ = model(mixed_image)
            disc_kl_posterior_loss = kl_disc_criterion(mix_disc_log_alpha, mixed_disc_alpha)
            continuous_kl_posterior_loss = (
                F.mse_loss(mix_norm_mean, mixed_z_mean, reduction="sum")
                + F.mse_loss(torch.exp(mix_norm_log_sigma), mixed_z_sigma, reduction="sum")) / batch_size
            mixup_kl_posterior_loss = posterior_weight * (
                continuous_kl_posterior_loss + disc_kl_posterior_loss)
            mixup_kl_posterior_loss.backward()
            mixup_posterior_kl_losses.update(float(mixup_kl_posterior_loss), mixed_image.size(0))
        elbo_losses.update(float(elbo_loss), batch_size)
        reconstruct_losses.update(float(reconstruct_loss), batch_size)
        kl_losses.update(float(kl_loss), batch_size)
        # calculate mse_losses if we use bce reconstruction loss
        if args.br:
            mse_loss = F.mse_loss(torch.sigmoid(reconstruction.detach()), image.detach(),
                                  reduction="sum") / (2 * batch_size * (args.x_sigma ** 2))
            mse_losses.update(float(mse_loss), batch_size)
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = ('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'ELBO Loss {elbo_loss.val:.4f} ({elbo_loss.avg:.4f})\t'
                          'Reconstruct Loss {reconstruct_loss.val:.4f} ({reconstruct_loss.avg:.4f})\t'
                          'KL Loss {kl_loss.val:.4f} ({kl_loss.avg:.4f})\t').format(
                epoch, i + 1, len(train_dloader), batch_time=batch_time, data_time=data_time,
                elbo_loss=elbo_losses, reconstruct_loss=reconstruct_losses, kl_loss=kl_losses)
            print(train_text)

    step = epoch + 1
    if args.mixup:
        writer.add_scalar(tag="Train/Mixup-KL-Posterior", scalar_value=mixup_posterior_kl_losses.avg,
                          global_step=step)
    writer.add_scalar(tag="Train/ELBO", scalar_value=elbo_losses.avg, global_step=step)
    writer.add_scalar(tag="Train/Reconstruct", scalar_value=reconstruct_losses.avg, global_step=step)
    writer.add_scalar(tag="Train/KL", scalar_value=kl_losses.avg, global_step=step)
    if args.br:
        # Logged at `epoch + 1` like every other scalar (it used `epoch`).
        writer.add_scalar(tag="Train/MSE", scalar_value=mse_losses.avg, global_step=step)
    # Every `reconstruct_freq` epochs, log the first 4 images of the last batch
    # and their reconstructions (skipped if the loader yielded no batches).
    if epoch % args.reconstruct_freq == 0 and reconstruction is not None:
        with torch.no_grad():
            raw_grid = utils.make_grid(image[:4, ...], nrow=2)
            reconstruct_grid = utils.make_grid(torch.sigmoid(reconstruction[:4, ...]), nrow=2)
        writer.add_image(tag="Train/Raw_Image", img_tensor=raw_grid, global_step=step)
        writer.add_image(tag="Train/Reconstruct_Image", img_tensor=reconstruct_grid, global_step=step)
    return elbo_losses.avg, reconstruct_losses.avg, kl_losses.avg
def train(train_dloader, generator, discriminator, g_optimizer, d_optimizer, criterion, writer, epoch):
    """Train a DCGAN generator/discriminator pair for one epoch with equilibrium gating.

    For every batch the discriminator scores the real images and (detached) generated images.
    Each of the two discriminator losses is back-propagated only when its gate is open, then
    ``d_optimizer`` steps. The generator is then scored against the updated discriminator and
    ``g_optimizer`` steps only when the generator gate is open. Gate margins start at
    ``args.dx_equilibrium`` / ``args.dgz1_equilibrium`` / ``args.dgz2_equilibrium`` and are
    multiplied by ``args.decay_equilibrium ** epoch``.

    Per-epoch averages are written under ``DCGAN/*`` and, at the end of the epoch, a grid of
    real images and a grid of freshly generated images are logged.

    Args:
        train_dloader: iterable of batches whose first element is the real image tensor.
        generator: maps noise of shape ``(N, args.latent_dim, 1, 1)`` to images.
        discriminator: maps images to scores, flattened with ``.view(-1)`` before the loss.
        g_optimizer: optimizer for the generator parameters.
        d_optimizer: optimizer for the discriminator parameters.
        criterion: loss called as ``criterion(output, target)`` with float targets
            (presumably a BCE-style loss -- confirm against the caller).
        writer: TensorBoard-style writer exposing ``add_scalar`` / ``add_image``.
        epoch: current epoch; drives the margin decay and is logged at step ``epoch + 1``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    generator_loss = AverageMeter()
    discriminator_real_loss = AverageMeter()
    discriminator_fake_loss = AverageMeter()
    dx_record = AverageMeter()
    dgz1_record = AverageMeter()
    dgz2_record = AverageMeter()
    generator.train()
    discriminator.train()
    end = time.time()
    g_optimizer.zero_grad()
    d_optimizer.zero_grad()
    # Equilibrium margins shrink geometrically with the epoch index.
    margin_decay = args.decay_equilibrium ** epoch
    dx_equilibrium = args.dx_equilibrium * margin_decay
    dgz1_equilibrium = args.dgz1_equilibrium * margin_decay
    dgz2_equilibrium = args.dgz2_equilibrium * margin_decay
    for i, (real_image, *_) in enumerate(train_dloader):
        data_time.update(time.time() - end)
        real_image = real_image.cuda()
        batch_size = real_image.size(0)
        # create noise for generator
        noise = torch.randn(batch_size, args.latent_dim, 1, 1).cuda()
        # Targets are forced to float: torch.full infers an integer dtype from an integer
        # fill value, which float-target criteria (e.g. BCELoss) reject.
        real_label = torch.full((batch_size,), args.real_label, dtype=torch.float, device=real_image.device)
        fake_label = torch.full((batch_size,), args.fake_label, dtype=torch.float, device=real_image.device)

        # --- discriminator on real images ---
        output = discriminator(real_image).view(-1)
        d_x = output.mean().item()
        d_loss_real = criterion(output, real_label)
        # equilibrium strategy: skip the real-image update while D(x) > 0.5 + margin
        if d_x <= 0.5 + dx_equilibrium:
            d_loss_real.backward()

        # --- discriminator on generated images (detached, so only D receives gradients) ---
        fake = generator(noise)
        output = discriminator(fake.detach()).view(-1)
        d_gz_1 = output.mean().item()
        d_loss_fake = criterion(output, fake_label)
        # equilibrium strategy: skip the fake-image update while D(G(z)) < 0.5 - margin
        if d_gz_1 >= 0.5 - dgz1_equilibrium:
            d_loss_fake.backward()
        # NOTE(review): this steps even if neither loss was back-propagated; optimizers with
        # momentum/Adam state may still move the parameters in that case.
        d_optimizer.step()

        # --- generator: maximise log D(G(z)) by targeting the real label (non-saturating loss) ---
        output = discriminator(fake).view(-1)
        d_gz_2 = output.mean().item()
        g_loss = criterion(output, real_label)
        if d_gz_2 <= 0.5 - dgz2_equilibrium:
            g_loss.backward()
            g_optimizer.step()
        # g_loss.backward() also deposits gradients on the discriminator; clear both.
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()

        # update d loss and g loss
        generator_loss.update(float(g_loss), batch_size)
        discriminator_fake_loss.update(float(d_loss_fake), batch_size)
        discriminator_real_loss.update(float(d_loss_real), batch_size)
        dx_record.update(float(d_x), batch_size)
        dgz1_record.update(float(d_gz_1), batch_size)
        dgz2_record.update(float(d_gz_2), batch_size)
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = ('Epoch: [{0}][{1}/{2}]\t'
                          'D(x) [{3:.4f}] D(G(z)) [{4:.4f}/{5:.4f}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Discriminator Real Loss {drl.val:.4f} ({drl.avg:.4f})\t'
                          'Discriminator Fake Loss {dfl.val:.4f} ({dfl.avg:.4f})\t'
                          'Generator Loss {gl.val:.4f} ({gl.avg:.4f})\t').format(
                epoch, i + 1, len(train_dloader), d_x, d_gz_1, d_gz_2,
                batch_time=batch_time, data_time=data_time,
                drl=discriminator_real_loss, dfl=discriminator_fake_loss, gl=generator_loss)
            print(train_text)

    epoch_scalars = (
        ("DCGAN/DRL", discriminator_real_loss),
        ("DCGAN/DFL", discriminator_fake_loss),
        ("DCGAN/GL", generator_loss),
        ("DCGAN/dx", dx_record),
        ("DCGAN/dgz1", dgz1_record),
        ("DCGAN/dgz2", dgz2_record),
    )
    for tag, meter in epoch_scalars:
        writer.add_scalar(tag=tag, scalar_value=meter.avg, global_step=epoch + 1)

    # At the end of the epoch, log real images from the last batch and 16 fresh samples.
    # `x * 0.5 + 0.5` assumes images are normalised to [-1, 1] -- confirm with the data pipeline.
    real_grid = utils.make_grid(real_image[:16, ...], nrow=4)
    writer.add_image(tag="Real_Image", img_tensor=(real_grid * 0.5) + 0.5, global_step=epoch + 1)
    # Samples are for visualisation only, so no autograd graph is needed.
    # NOTE(review): the generator is still in train() mode here, so BatchNorm layers (if any)
    # use/update batch statistics from these 16 samples.
    with torch.no_grad():
        noise = torch.randn(16, args.latent_dim, 1, 1).cuda()
        fake_grid = utils.make_grid(generator(noise), nrow=4)
    writer.add_image(tag="Fake_Image", img_tensor=(fake_grid * 0.5) + 0.5, global_step=epoch + 1)
def train(train_dloader_u, train_dloader_l, model, elbo_criterion, cls_criterion, optimizer, epoch, writer, discrete_latent_dim):
    """Train the semi-supervised VAE for one epoch (label-smoothing variant).

    Each iteration pairs a labeled batch (the labeled loader is wrapped in ``cycle``) with an
    unlabeled batch and accumulates two backward passes before a single ``optimizer.step()``:

    * labeled: ELBO with ``|KL - target|`` prior terms, a continuous posterior-consistency MSE
      on inputs produced by ``label_smoothing``, and a ``smoothed_lambda_l``-weighted
      ``cls_criterion`` against the original and the smoothed one-hot labels;
    * unlabeled: the same ELBO, a continuous consistency MSE on ``mixup_vae_data`` inputs, and
      ``cls_criterion`` between the posterior on the mixed image and the mixed discrete target.

    The unlabeled loader's labels are used only to monitor KL(q(y|x) || smoothed one-hot label),
    computed under ``torch.no_grad()`` and logged as ``Train/KL_Inference``; they enter no loss.

    Args:
        train_dloader_u: yields ``(image, label)`` unlabeled batches (labels for monitoring only).
        train_dloader_l: yields ``(image, label)`` labeled batches.
        model: on a plain forward returns
            ``(reconstruction, norm_mean, norm_log_sigma, disc_log_alpha)``.
        elbo_criterion: returns ``(reconstruct_loss, continuous_prior_kl, discrete_prior_kl)``.
        cls_criterion: called as ``cls_criterion(disc_log_alpha, target_distribution)``.
        optimizer: optimizer over the model parameters.
        epoch: current epoch; drives ``alpha_schedule`` and the logged step ``epoch + 1``.
        writer: TensorBoard-style writer.
        discrete_latent_dim: number of discrete latent categories (width of the one-hot labels).

    Returns:
        None.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    kl_inferences = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual information targets for the continuous / discrete prior KL terms
    cmi = alpha_schedule(epoch, args.akb, args.cmi)
    dmi = alpha_schedule(epoch, args.akb, args.dmi)
    # elbo part weight
    ew = alpha_schedule(epoch, args.aew, args.ewm)
    # mixup parameters
    kl_beta_c = alpha_schedule(epoch, args.akb, args.kbmc)
    kl_beta_d = alpha_schedule(epoch, args.akb, args.kbmd)
    pwm = alpha_schedule(epoch, args.apw, args.pwm)
    # unsupervised cls weight
    ucw = alpha_schedule(epoch, round(args.wmf * args.epochs), args.wrd)
    # NOTE(review): if `cycle` is itertools.cycle, it stores every labeled batch from the first
    # pass and replays them in the same order (no reshuffle/re-augmentation) -- confirm intent.
    # The epoch length is set by train_dloader_u because zip() stops at the shorter iterable.
    for i, ((image_l, label_l), (image_u, label_u)) in enumerate(
            zip(cycle(train_dloader_l), train_dloader_u)):
        batch_size_l = image_l.size(0)
        batch_size_u = image_u.size(0)
        data_time.update(time.time() - end)
        # for the labeled part, do classification and mixup
        image_l = image_l.float().cuda()
        label_l = label_l.long().cuda()
        label_onehot_l = torch.zeros(batch_size_l, discrete_latent_dim).cuda().scatter_(
            1, label_l.view(-1, 1), 1)
        reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l = model(
            image_l, disc_label=label_l)
        reconstruct_loss_l, continuous_prior_kl_loss_l, disc_prior_kl_loss_l = elbo_criterion(
            image_l, reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l)
        # penalise the distance of each prior KL from its scheduled target
        prior_kl_loss_l = kl_beta_c * torch.abs(continuous_prior_kl_loss_l - cmi) + kl_beta_d * torch.abs(
            disc_prior_kl_loss_l - dmi)
        elbo_loss_l = reconstruct_loss_l + prior_kl_loss_l
        # do optimal transport estimation: label_smoothing() (defined elsewhere) builds smoothed
        # inputs, latent targets, labels and a mixing weight without tracking gradients
        with torch.no_grad():
            smoothed_image_l, smoothed_z_mean_l, smoothed_z_sigma_l, smoothed_disc_alpha_l, smoothed_label_l, smoothed_lambda_l = \
                label_smoothing(
                    image_l,
                    norm_mean_l,
                    norm_log_sigma_l,
                    disc_log_alpha_l,
                    epsilon=args.epsilon,
                    disc_label=label_l)
            smoothed_label_onehot_l = torch.zeros(
                batch_size_l, discrete_latent_dim).cuda().scatter_(
                1, smoothed_label_l.view(-1, 1), 1)
        smoothed_reconstruction_l, smoothed_norm_mean_l, smoothed_norm_log_sigma_l, smoothed_disc_log_alpha_l, *_ = model(
            smoothed_image_l, True, label_l, smoothed_label_l, smoothed_lambda_l)
        # classification loss against both labels, weighted by the smoothing coefficient
        disc_posterior_kl_loss_l = smoothed_lambda_l * cls_criterion(
            smoothed_disc_log_alpha_l, label_onehot_l) + (1 - smoothed_lambda_l) * cls_criterion(
            smoothed_disc_log_alpha_l, smoothed_label_onehot_l)
        # posterior consistency: re-encoded mean and sigma should match the smoothed targets
        continuous_posterior_kl_loss_l = (F.mse_loss(smoothed_norm_mean_l, smoothed_z_mean_l, reduction="sum") + \
                                          F.mse_loss(torch.exp(smoothed_norm_log_sigma_l), smoothed_z_sigma_l,
                                                     reduction="sum")) / batch_size_l
        elbo_loss_l = elbo_loss_l + kl_beta_c * pwm * continuous_posterior_kl_loss_l
        loss_supervised = ew * elbo_loss_l + disc_posterior_kl_loss_l
        # gradients accumulate here; the optimizer steps once after the unlabeled backward below
        loss_supervised.backward()
        # for the unlabeled part, do classification and mixup
        image_u = image_u.float().cuda()
        label_u = label_u.long().cuda()
        reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u = model(
            image_u)
        # calculate the KL(q_y_x|p_y_x) for monitoring only (no gradients, not part of the loss).
        # The smoothed target puts 1 - 0.001 on the true class and 0.001 / (K - 1) on each other
        # class, so it sums to 1 and log() stays finite.
        # disc_log_alpha_u is presumably a log-probability vector (exp() is treated as q(y|x)).
        with torch.no_grad():
            label_smooth_u = torch.zeros(
                batch_size_u, discrete_latent_dim).cuda().scatter_(
                1, label_u.view(-1, 1), 1 - 0.001 - 0.001 / (discrete_latent_dim - 1))
            label_smooth_u = label_smooth_u + torch.ones(label_smooth_u.size(
            )).cuda() * 0.001 / (discrete_latent_dim - 1)
            disc_alpha_u = torch.exp(disc_log_alpha_u)
            inference_kl = disc_alpha_u * disc_log_alpha_u - disc_alpha_u * torch.log(
                label_smooth_u)
        kl_inferences.update(float(torch.sum(inference_kl) / batch_size_u), batch_size_u)
        reconstruct_loss_u, continuous_prior_kl_loss_u, disc_prior_kl_loss_u = elbo_criterion(
            image_u, reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u)
        prior_kl_loss_u = kl_beta_c * torch.abs(continuous_prior_kl_loss_u - cmi) + kl_beta_d * torch.abs(
            disc_prior_kl_loss_u - dmi)
        elbo_loss_u = reconstruct_loss_u + prior_kl_loss_u
        # do mixup part: mixup_vae_data() (defined elsewhere) builds mixed images and latent /
        # discrete targets without tracking gradients
        with torch.no_grad():
            mixed_image_u, mixed_z_mean_u, mixed_z_sigma_u, mixed_disc_alpha_u, lam_u = \
                mixup_vae_data(
                    image_u,
                    norm_mean_u,
                    norm_log_sigma_u,
                    disc_log_alpha_u,
                    optimal_match=args.om)
        mixed_reconstruction_u, mixed_norm_mean_u, mixed_norm_log_sigma_u, mixed_disc_log_alpha_u, *_ = model(
            mixed_image_u)
        # the posterior on the mixed image should match the mixed discrete target
        disc_posterior_kl_loss_u = cls_criterion(mixed_disc_log_alpha_u, mixed_disc_alpha_u)
        continuous_posterior_kl_loss_u = (F.mse_loss(mixed_norm_mean_u, mixed_z_mean_u, reduction="sum") + \
                                          F.mse_loss(torch.exp(mixed_norm_log_sigma_u), mixed_z_sigma_u,
                                                     reduction="sum")) / batch_size_u
        elbo_loss_u = elbo_loss_u + kl_beta_c * pwm * continuous_posterior_kl_loss_u
        loss_unsupervised = ew * elbo_loss_u + ucw * disc_posterior_kl_loss_u
        loss_unsupervised.backward()
        # single update from the accumulated labeled + unlabeled gradients
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = 'Epoch: [{0}][{1}/{2}]\t' \
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'.format(
                epoch, i + 1, len(train_dloader_u), batch_time=batch_time, data_time=data_time)
            print(train_text)
    # record unlabeled part loss
    writer.add_scalar(tag="Train/KL_Inference", scalar_value=kl_inferences.avg, global_step=epoch + 1)
    # every args.reconstruct_freq epochs, log the first 4 images of the last unlabeled batch
    # and their reconstructions (sigmoid applied) to the image board
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image_u[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(
                reconstruction_u[:4, ...]), nrow=2)
        writer.add_image(tag="Train/Raw_Image", img_tensor=image, global_step=epoch + 1)
        writer.add_image(tag="Train/Reconstruct_Image", img_tensor=reconstruct_image, global_step=epoch + 1)
def train(train_dloader_u, train_dloader_l, model, elbo_criterion, cls_criterion, optimizer, epoch, writer,
          discrete_latent_dim, kl_norm_criterion=None, kl_disc_criterion=None):
    """Train the semi-supervised VAE for one epoch on paired labeled / unlabeled batches (mixup variant).

    Each iteration takes one labeled batch (the labeled loader is wrapped in ``cycle``) and one
    unlabeled batch, truncates both to a common batch size, and accumulates two backward passes
    before a single ``optimizer.step()``:

    * labeled: ELBO with ``|KL - target|`` prior terms, a continuous posterior-consistency MSE on
      ``mixup_vae_data`` inputs, and an ``lam_l``-weighted ``cls_criterion`` against the two
      source labels of each mixed image;
    * unlabeled: the same ELBO and continuous consistency terms, plus ``cls_criterion`` between
      the posterior on the mixed image and the mixed discrete target, weighted by ``ucw``.

    Schedule values (targets ``cmi``/``dmi``, ELBO weight, KL betas, posterior weight and the
    unsupervised classification weight) come from ``alpha_schedule`` for the current epoch.

    Args:
        train_dloader_u: yields ``(image, label)`` unlabeled batches; the label is ignored.
        train_dloader_l: yields ``(image, label)`` labeled batches.
        model: on a plain forward returns
            ``(reconstruction, norm_mean, norm_log_sigma, disc_log_alpha)``.
        elbo_criterion: returns ``(reconstruct_loss, continuous_prior_kl, discrete_prior_kl)``.
        cls_criterion: called as ``cls_criterion(disc_log_alpha, target_distribution)``.
        optimizer: optimizer over the model parameters.
        epoch: current epoch; every scalar/image is logged at step ``epoch + 1``.
        writer: TensorBoard-style writer.
        discrete_latent_dim: number of discrete latent categories (width of the one-hot labels).
        kl_norm_criterion: unused; accepted so existing callers keep working.
        kl_disc_criterion: unused; accepted so existing callers keep working.

    Returns:
        ``(discrete_posterior_kl_avg_unlabeled, discrete_posterior_kl_avg_labeled)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # Meters for the unlabeled (train_dloader_u) part.
    reconstruct_losses_u = AverageMeter()
    continuous_prior_kl_losses_u = AverageMeter()
    discrete_prior_kl_losses_u = AverageMeter()
    elbo_losses_u = AverageMeter()
    mse_losses_u = AverageMeter()
    continuous_posterior_kl_losses_u = AverageMeter()
    discrete_posterior_kl_losses_u = AverageMeter()
    # Meters for the labeled (train_dloader_l) part.
    reconstruct_losses_l = AverageMeter()
    continuous_prior_kl_losses_l = AverageMeter()
    discrete_prior_kl_losses_l = AverageMeter()
    elbo_losses_l = AverageMeter()
    mse_losses_l = AverageMeter()
    continuous_posterior_kl_losses_l = AverageMeter()
    discrete_posterior_kl_losses_l = AverageMeter()
    model.train()
    end = time.time()
    optimizer.zero_grad()
    # mutual information targets for the continuous / discrete prior KL terms
    cmi = alpha_schedule(epoch, args.akb, args.cmi, strategy="exp")
    dmi = alpha_schedule(epoch, args.akb, args.dmi, strategy="exp")
    # elbo part weight
    ew = alpha_schedule(epoch, args.aew, args.ewm)
    # mixup parameters
    kl_beta_c = alpha_schedule(epoch, args.akb, args.kbmc)
    kl_beta_d = alpha_schedule(epoch, args.akb, args.kbmd)
    pwm = alpha_schedule(epoch, args.apw, args.pwm)
    # unsupervised cls weight
    ucw = alpha_schedule(epoch, round(args.wmf * args.epochs), args.wrd)
    # NOTE(review): if `cycle` is itertools.cycle, it stores every labeled batch from the first
    # pass and replays them in the same order (no reshuffle/re-augmentation) -- confirm intent.
    for i, ((image_l, label_l), (image_u, _)) in enumerate(zip(cycle(train_dloader_l), train_dloader_u)):
        # Paired batches may differ in size (e.g. a short final batch): truncate to the common size.
        batch_size = min(image_l.size(0), image_u.size(0))
        image_l, label_l, image_u = image_l[:batch_size], label_l[:batch_size], image_u[:batch_size]
        data_time.update(time.time() - end)

        # ---- labeled part: ELBO, mixup consistency and classification ----
        image_l = image_l.float().cuda()
        label_l = label_l.long().cuda()
        label_onehot_l = torch.zeros(batch_size, discrete_latent_dim).cuda().scatter_(1, label_l.view(-1, 1), 1)
        reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l = model(image_l, disc_label=label_l)
        reconstruct_loss_l, continuous_prior_kl_loss_l, disc_prior_kl_loss_l = elbo_criterion(
            image_l, reconstruction_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l)
        # penalise the distance of each prior KL from its scheduled target
        prior_kl_loss_l = (kl_beta_c * torch.abs(continuous_prior_kl_loss_l - cmi)
                           + kl_beta_d * torch.abs(disc_prior_kl_loss_l - dmi))
        elbo_loss_l = reconstruct_loss_l + prior_kl_loss_l
        reconstruct_losses_l.update(reconstruct_loss_l.item(), batch_size)
        continuous_prior_kl_losses_l.update(continuous_prior_kl_loss_l.item(), batch_size)
        discrete_prior_kl_losses_l.update(disc_prior_kl_loss_l.item(), batch_size)
        # do optimal transport estimation: build mixed inputs/targets without tracking gradients
        with torch.no_grad():
            mixed_image_l, mixed_z_mean_l, mixed_z_sigma_l, mixed_disc_alpha_l, label_mixup_l, lam_l = mixup_vae_data(
                image_l, norm_mean_l, norm_log_sigma_l, disc_log_alpha_l, alpha=args.mas, disc_label=label_l)
            label_mixup_onehot_l = torch.zeros(batch_size, discrete_latent_dim).cuda().scatter_(
                1, label_mixup_l.view(-1, 1), 1)
        mixed_reconstruction_l, mixed_norm_mean_l, mixed_norm_log_sigma_l, mixed_disc_log_alpha_l, *_ = model(
            mixed_image_l, mixup=True, disc_label=label_l, disc_label_mixup=label_mixup_l, mixup_lam=lam_l)
        # classification loss against both source labels, weighted by the mixup coefficient
        disc_posterior_kl_loss_l = (lam_l * cls_criterion(mixed_disc_log_alpha_l, label_onehot_l)
                                    + (1 - lam_l) * cls_criterion(mixed_disc_log_alpha_l, label_mixup_onehot_l))
        # posterior consistency: re-encoded mean and sigma should match the mixed latent targets
        continuous_posterior_kl_loss_l = (
            F.mse_loss(mixed_norm_mean_l, mixed_z_mean_l, reduction="sum")
            + F.mse_loss(torch.exp(mixed_norm_log_sigma_l), mixed_z_sigma_l, reduction="sum")) / batch_size
        elbo_loss_l = elbo_loss_l + kl_beta_c * pwm * continuous_posterior_kl_loss_l
        elbo_losses_l.update(elbo_loss_l.item(), batch_size)
        continuous_posterior_kl_losses_l.update(continuous_posterior_kl_loss_l.item(), batch_size)
        discrete_posterior_kl_losses_l.update(disc_posterior_kl_loss_l.item(), batch_size)
        if args.br:
            # MSE of the sigmoid reconstruction, tracked for monitoring only
            mse_loss_l = F.mse_loss(torch.sigmoid(reconstruction_l.detach()), image_l.detach(),
                                    reduction="sum") / (2 * batch_size * (args.x_sigma**2))
            mse_losses_l.update(float(mse_loss_l), batch_size)
        loss_supervised = ew * elbo_loss_l + disc_posterior_kl_loss_l
        # gradients accumulate; the optimizer steps once after the unlabeled backward below
        loss_supervised.backward()

        # ---- unlabeled part: ELBO, mixup consistency and classification ----
        image_u = image_u.float().cuda()
        reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u = model(image_u)
        reconstruct_loss_u, continuous_prior_kl_loss_u, disc_prior_kl_loss_u = elbo_criterion(
            image_u, reconstruction_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u)
        prior_kl_loss_u = (kl_beta_c * torch.abs(continuous_prior_kl_loss_u - cmi)
                           + kl_beta_d * torch.abs(disc_prior_kl_loss_u - dmi))
        elbo_loss_u = reconstruct_loss_u + prior_kl_loss_u
        reconstruct_losses_u.update(reconstruct_loss_u.item(), batch_size)
        continuous_prior_kl_losses_u.update(continuous_prior_kl_loss_u.item(), batch_size)
        discrete_prior_kl_losses_u.update(disc_prior_kl_loss_u.item(), batch_size)
        # do mixup part
        with torch.no_grad():
            mixed_image_u, mixed_z_mean_u, mixed_z_sigma_u, mixed_disc_alpha_u, lam_u = mixup_vae_data(
                image_u, norm_mean_u, norm_log_sigma_u, disc_log_alpha_u, alpha=args.mau)
        mixed_reconstruction_u, mixed_norm_mean_u, mixed_norm_log_sigma_u, mixed_disc_log_alpha_u, *_ = model(
            mixed_image_u)
        # the posterior on the mixed image should match the mixed discrete target
        disc_posterior_kl_loss_u = cls_criterion(mixed_disc_log_alpha_u, mixed_disc_alpha_u)
        continuous_posterior_kl_loss_u = (
            F.mse_loss(mixed_norm_mean_u, mixed_z_mean_u, reduction="sum")
            + F.mse_loss(torch.exp(mixed_norm_log_sigma_u), mixed_z_sigma_u, reduction="sum")) / batch_size
        elbo_loss_u = elbo_loss_u + kl_beta_c * pwm * continuous_posterior_kl_loss_u
        loss_unsupervised = ew * elbo_loss_u + ucw * disc_posterior_kl_loss_u
        elbo_losses_u.update(elbo_loss_u.item(), batch_size)
        continuous_posterior_kl_losses_u.update(continuous_posterior_kl_loss_u.item(), batch_size)
        discrete_posterior_kl_losses_u.update(disc_posterior_kl_loss_u.item(), batch_size)
        loss_unsupervised.backward()
        if args.br:
            mse_loss_u = F.mse_loss(torch.sigmoid(reconstruction_u.detach()), image_u.detach(),
                                    reduction="sum") / (2 * batch_size * (args.x_sigma**2))
            mse_losses_u.update(float(mse_loss_u), batch_size)
        # single update from the accumulated labeled + unlabeled gradients
        optimizer.step()
        optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            train_text = ('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Cls Loss Labeled {cls_loss_l.val:.4f} ({cls_loss_l.avg:.4f})\t'
                          'Cls Loss Unlabeled {cls_loss_u.val:.4f} ({cls_loss_u.avg:.4f})\t'
                          'Continuous Prior KL Loss Labeled {cpk_loss_l.val:.4f} ({cpk_loss_l.avg:.4f})\t'
                          'Continuous Prior KL Loss Unlabeled {cpk_loss_u.val:.4f} ({cpk_loss_u.avg:.4f})\t').format(
                epoch, i + 1, len(train_dloader_u), batch_time=batch_time, data_time=data_time,
                cls_loss_l=discrete_posterior_kl_losses_l, cls_loss_u=discrete_posterior_kl_losses_u,
                cpk_loss_l=continuous_prior_kl_losses_l, cpk_loss_u=continuous_prior_kl_losses_u)
            print(train_text)

    # Epoch averages: unlabeled part first, then labeled part.
    unlabeled_scalars = [
        ("Train/ELBO_U", elbo_losses_u),
        # this tag was previously misspelled "Train/Reconstrut_U"; it now mirrors Train/Reconstruct_L
        ("Train/Reconstruct_U", reconstruct_losses_u),
        ("Train/Continuous_Prior_KL_U", continuous_prior_kl_losses_u),
        ("Train/Continuous_Posterior_KL_U", continuous_posterior_kl_losses_u),
        ("Train/Discrete_Prior_KL_U", discrete_prior_kl_losses_u),
        ("Train/Discrete_Posterior_KL_U", discrete_posterior_kl_losses_u),
    ]
    if args.br:
        unlabeled_scalars.append(("Train/MSE_U", mse_losses_u))
    labeled_scalars = [
        ("Train/ELBO_L", elbo_losses_l),
        ("Train/Reconstruct_L", reconstruct_losses_l),
        ("Train/Continuous_Prior_KL_L", continuous_prior_kl_losses_l),
        ("Train/Continuous_Posterior_KL_L", continuous_posterior_kl_losses_l),
        ("Train/Discrete_Prior_KL_L", discrete_prior_kl_losses_l),
        ("Train/Discrete_Posterior_KL_L", discrete_posterior_kl_losses_l),
    ]
    if args.br:
        labeled_scalars.append(("Train/MSE_L", mse_losses_l))
    for tag, meter in unlabeled_scalars + labeled_scalars:
        writer.add_scalar(tag=tag, scalar_value=meter.avg, global_step=epoch + 1)

    # every args.reconstruct_freq epochs, log the first 4 images of the last unlabeled batch
    # and their reconstructions (sigmoid applied) to the image board
    if epoch % args.reconstruct_freq == 0:
        with torch.no_grad():
            image = utils.make_grid(image_u[:4, ...], nrow=2)
            reconstruct_image = utils.make_grid(torch.sigmoid(reconstruction_u[:4, ...]), nrow=2)
        writer.add_image(tag="Train/Raw_Image", img_tensor=image, global_step=epoch + 1)
        writer.add_image(tag="Train/Reconstruct_Image", img_tensor=reconstruct_image, global_step=epoch + 1)
    return discrete_posterior_kl_losses_u.avg, discrete_posterior_kl_losses_l.avg
def valid(valid_dloader, model, criterion, epoch, writer):
    """Evaluate the five-finding classifier on the validation set and log per-study ROC-AUC.

    Image-level scores and labels are grouped by a study key parsed from each image name,
    reduced to one score/label per study by ``get_score_label_array_from_dict``, and scored
    with ``roc_auc_score`` separately for each finding.

    Args:
        valid_dloader: yields ``(image, index, image_name, label_weight, label)`` batches;
            ``label`` has one column per prediction head.
        model: returns a list of per-finding predictions; column 1 of each is used as the
            score (presumably the positive-class probability/logit -- TODO confirm).
        criterion: called as ``criterion(prediction_list, label_weight, label)``.
        epoch: current epoch; every scalar is logged at ``global_step=epoch + 1``.
        writer: TensorBoard-style writer.

    Returns:
        ``[atelectasis_auc, cardiomegaly_auc, consolidation_auc, edema_auc, pleural_effusion_auc]``.

    Raises:
        ValueError: if an image name does not match ``<...>patient<study>|<...>``.
    """
    # Finding order matches the prediction heads / label columns and the returned AUC list.
    finding_tags = ("Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural_Effusion")
    last_bucket = len(finding_tags) - 1
    study_pattern = re.compile(r"(.*)patient(.*)\|(.*)")

    model.eval()
    # One {study_name: [values]} dict per finding, for scores and for labels.
    score_dicts = [defaultdict(list) for _ in finding_tags]
    label_dicts = [defaultdict(list) for _ in finding_tags]
    losses = AverageMeter()
    for image, _index, image_name, label_weight, label in valid_dloader:
        image = image.float().cuda()
        label = label.long().cuda()
        label_weight = label_weight.cuda()
        with torch.no_grad():
            prediction_list = model(image)
            loss = criterion(prediction_list, label_weight, label)
        losses.update(float(loss.item()), image.size(0))
        # Copy to Python once per head/batch instead of one .item() (and device sync) per element.
        head_scores = [prediction[:, 1].tolist() for prediction in prediction_list]
        label_rows = label.tolist()
        for i, img_name in enumerate(image_name):
            match = study_pattern.match(img_name)
            if match is None:
                raise ValueError("cannot parse study name from image name {!r}".format(img_name))
            study_name = match.group(2)
            for j, scores in enumerate(head_scores):
                # Any head beyond the fifth is pooled into the last (Pleural Effusion) bucket.
                bucket = min(j, last_bucket)
                label_dicts[bucket][study_name].append(label_rows[i][j])
                score_dicts[bucket][study_name].append(scores[i])
    writer.add_scalar(tag="Valid/cls_loss", scalar_value=losses.avg, global_step=epoch + 1)

    # Calculate AUC ROC on study-level arrays. Per the original note, studies are reduced with a
    # "max" method inside get_score_label_array_from_dict (not visible here -- confirm).
    aucs = []
    for score_dict, label_dict in zip(score_dicts, label_dicts):
        study_scores, study_labels = get_score_label_array_from_dict(score_dict, label_dict)
        aucs.append(roc_auc_score(study_labels, study_scores))
    # Logged at epoch + 1 like Valid/cls_loss, so loss and AUCs of one epoch share a step.
    for tag, auc in zip(finding_tags, aucs):
        writer.add_scalar(tag="Valid/{}_AUC".format(tag), scalar_value=auc, global_step=epoch + 1)
    return aucs