def draw_in_tensorboard(writer, images, label_trg, i_iter, pred_main, pred_main_swarp, num_classes, type_):
    grid_image = make_grid(images[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image - {type_}', grid_image, i_iter)

    pred_main_cat = torch.cat((pred_main, pred_main_swarp), dim=-1)
    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        np.argmax(F.softmax(pred_main_cat).cpu().data[0].numpy().transpose(1, 2, 0),
                  axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
        3, normalize=False, range=(0, 255))
    writer.add_image(f'Prediction_main_swarp - {type_}', grid_image, i_iter)

    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        label_trg.cpu().squeeze(), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
        3, normalize=False, range=(0, 255))
    writer.add_image(f'Labels_IAST - {type_}', grid_image, i_iter)

def draw_in_tensorboard(writer, images, i_iter, pred_main, num_classes, type_):
    grid_image = make_grid(images[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image - {type_}', grid_image, i_iter)

    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        np.argmax(F.softmax(pred_main).cpu().data[0].numpy().transpose(1, 2, 0),
                  axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
        3, normalize=False, range=(0, 255))
    writer.add_image(f'Prediction - {type_}', grid_image, i_iter)

    output_sm = F.softmax(pred_main).cpu().data[0].numpy().transpose(1, 2, 0)
    output_ent = np.sum(-np.multiply(output_sm, np.log2(output_sm)), axis=2,
                        keepdims=False)
    grid_image = make_grid(torch.from_numpy(output_ent), 3, normalize=True,
                           range=(0, np.log2(num_classes)))
    writer.add_image(f'Entropy - {type_}', grid_image, i_iter)

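# Note: the entropy map visualized above is H(p) = -sum_c p_c * log2(p_c) per pixel, which
# is bounded by log2(num_classes); make_grid's normalize/range arguments rescale it into
# [0, 1] for display.
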
def draw_in_tensorboard(writer, images, images_aug, i_iter, pred_main, pred_main_aug, type_):
    grid_image = make_grid(images[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image - {type_}', grid_image, i_iter)

    grid_image = make_grid(images_aug[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image_aug - {type_}', grid_image, i_iter)

    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        np.argmax(F.softmax(pred_main).cpu().data[0].numpy().transpose(1, 2, 0),
                  axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
        3, normalize=False, range=(0, 255))
    writer.add_image(f'Prediction - {type_}', grid_image, i_iter)

    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        np.argmax(F.softmax(pred_main_aug).cpu().data[0].numpy().transpose(1, 2, 0),
                  axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
        3, normalize=False, range=(0, 255))
    writer.add_image(f'Prediction_aug - {type_}', grid_image, i_iter)

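# --- reference sketch (not part of the code above) --------------------------------- #
# colorize_mask is imported from the project's visualization utilities and is used by all
# three draw_in_tensorboard variants above. A minimal sketch of the typical ADVENT-style
# implementation follows; the truncated `_PALETTE_SKETCH` list is an assumption for
# illustration, not the project's actual (Cityscapes) palette.
def colorize_mask_sketch(mask):
    from PIL import Image
    _PALETTE_SKETCH = [128, 64, 128, 244, 35, 232, 70, 70, 70]      # first 3 classes only
    _PALETTE_SKETCH += [0] * (256 * 3 - len(_PALETTE_SKETCH))       # pad to 256 palette entries
    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')  # 'P' = palette mode
    new_mask.putpalette(_PALETTE_SKETCH)                            # class id -> RGB colour
    return new_mask
# ------------------------------------------------------------------------------------ #
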
def eval_best(cfg, models, device, test_loader, interp, fixed_test_size, verbose):
    # -------------------------------------------------------- #
    # codes to initialize wandb for storing logs on its cloud
    wandb.init(project='FDA_integration_to_INTRA_DA')
    for key, val in cfg.items():
        wandb.config.update({key: val})
    # -------------------------------------------------------- #
    assert len(models) == 1, 'Not yet supported multi models in this mode'
    assert osp.exists(cfg.TEST.SNAPSHOT_DIR[0]), 'SNAPSHOT_DIR is not found'
    start_iter = cfg.TEST.SNAPSHOT_STEP
    step = cfg.TEST.SNAPSHOT_STEP
    max_iter = cfg.TEST.SNAPSHOT_MAXITER
    cache_path = osp.join(cfg.TEST.SNAPSHOT_DIR[0], 'all_res.pkl')
    if osp.exists(cache_path):
        all_res = pickle_load(cache_path)
    else:
        all_res = {}
    cur_best_miou = -1
    cur_best_model = ''
    for i_iter in range(start_iter, max_iter + 1, step):
        restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0], f'model_{i_iter}.pth')
        if not osp.exists(restore_from):
            # continue
            if cfg.TEST.WAIT_MODEL:
                print('Waiting for model..!')
                while not osp.exists(restore_from):
                    time.sleep(5)
        print("Evaluating model", restore_from)
        if i_iter not in all_res.keys():
            load_checkpoint_for_evaluation(models[0], restore_from, device)
            # eval
            hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
            # for index, batch in enumerate(test_loader):
            #     image, _, _, name = batch
            test_iter = iter(test_loader)
            for index in tqdm(range(len(test_loader))):
                image, label, _, name = next(test_iter)
                if not fixed_test_size:
                    interp = nn.Upsample(size=(label.shape[1], label.shape[2]),
                                         mode='bilinear', align_corners=True)
                with torch.no_grad():
                    pred_main = models[0](image.cuda(device))[1]
                    output = interp(pred_main).cpu().data[0].numpy()
                    output = output.transpose(1, 2, 0)
                    output = np.argmax(output, axis=2)
                label = label.numpy()[0]
                hist += fast_hist(label.flatten(), output.flatten(), cfg.NUM_CLASSES)
                if verbose and index > 0 and index % 100 == 0:
                    print('{:d} / {:d}: {:0.2f}'.format(
                        index, len(test_loader), 100 * np.nanmean(per_class_iu(hist))))
            inters_over_union_classes = per_class_iu(hist)
            all_res[i_iter] = inters_over_union_classes
            pickle_dump(all_res, cache_path)
            # -------------------------------------------------------- #
            # save logs at Weights & Biases
            IoU_classes = {}
            for idx in range(cfg.NUM_CLASSES):
                IoU_classes[test_loader.dataset.class_names[idx]] = round(
                    inters_over_union_classes[idx] * 100, 2)
            wandb.log(IoU_classes, step=i_iter)
            wandb.log({'mIoU19': round(np.nanmean(inters_over_union_classes) * 100, 2)},
                      step=i_iter)
            wandb.log({'val_prediction': wandb.Image(
                colorize_mask(np.asarray(output, dtype=np.uint8)).convert('RGB'))},
                      step=i_iter)
            # -------------------------------------------------------- #
        else:
            inters_over_union_classes = all_res[i_iter]
        computed_miou = round(np.nanmean(inters_over_union_classes) * 100, 2)
        if cur_best_miou < computed_miou:
            cur_best_miou = computed_miou
            cur_best_model = restore_from
        print('\tCurrent mIoU:', computed_miou)
        print('\tCurrent best model:', cur_best_model)
        print('\tCurrent best mIoU:', cur_best_miou)
        wandb.log({'best mIoU': cur_best_miou}, step=i_iter)
        if verbose:
            display_stats(cfg, test_loader.dataset.class_names, inters_over_union_classes)

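# --- reference sketch (not part of the code above) --------------------------------- #
# fast_hist and per_class_iu come from the project's evaluation utilities. A minimal
# sketch of the usual implementation, to make the histogram accumulation above
# self-explanatory: fast_hist builds an n x n confusion matrix over valid label ids, and
# per_class_iu reads IoU = TP / (TP + FP + FN) off its diagonal (nan for absent classes,
# hence the np.nanmean above).
def fast_hist_sketch(label, pred, n):
    k = (label >= 0) & (label < n)                      # drop void / out-of-range ids
    return np.bincount(n * label[k].astype(int) + pred[k],
                       minlength=n ** 2).reshape(n, n)  # rows: ground truth, cols: prediction

def per_class_iu_sketch(hist):
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
# ------------------------------------------------------------------------------------ #
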
def eval_single(cfg, models, device, test_loader, interp, fixed_test_size, verbose):
    assert len(cfg.TEST.RESTORE_FROM) == len(models), 'Number of models are not matched'
    folder_path = cfg.TEST.RESTORE_FROM[0].split('/')[-2]
    folder_path = osp.join(result_root, folder_path, "eval_image")
    if not osp.exists(folder_path):
        os.makedirs(folder_path)
    for checkpoint, model in zip(cfg.TEST.RESTORE_FROM, models):
        load_checkpoint_for_evaluation(model, checkpoint, device)
    # eval
    hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
    for index, batch in tqdm(enumerate(test_loader)):
        image, label, _, name = batch
        if not fixed_test_size:
            interp = nn.Upsample(size=(label.shape[1], label.shape[2]),
                                 mode='bilinear', align_corners=True)
        with torch.no_grad():
            output = None
            for model, model_weight in zip(models, cfg.TEST.MODEL_WEIGHT):
                _, pred_main, pred_boundary = model(image.cuda(device))
                output_ = interp(pred_main).cpu().data[0].numpy()
                output_pred = interp(pred_main)
                if output is None:
                    output = model_weight * output_
                else:
                    output += model_weight * output_

                domain = name[0].split('/')[-2]
                save_path = folder_path + "/" + domain + "_" + name[0].split('/')[-1].split('.')[0]

                # segmentation prediction save
                save_image(torch.from_numpy(np.array(colorize_mask(np.asarray(
                    np.argmax(F.softmax(output_pred).cpu().data[0].numpy().transpose(1, 2, 0),
                              axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
                    save_path + "_3_seg.png", 3, normalize=False, range=(0, 255))
                # boundary prediction save
                pred_boundary = F.interpolate(pred_boundary, label.shape[1:], mode='bilinear')
                save_image(pred_boundary.clone(), save_path + "_1_boundary.png", normalize=True)
                # red boundary visualize
                vis_red_boundary(save_path, pred_boundary.clone())
                # binary boundary prediction save
                save_image_binary(save_path, pred_boundary.clone(), threshold=0.5)
                # color label save
                save_image(torch.from_numpy(np.array(colorize_mask(np.asarray(
                    label.squeeze(0).numpy(), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
                    save_path + "_2_label.png", 3, normalize=False)
            assert output is not None, 'Output is None'
            output = output.transpose(1, 2, 0)
            output = np.argmax(output, axis=2)
        label = label.numpy()[0]
        hist += fast_hist(label.flatten(), output.flatten(), cfg.NUM_CLASSES)
    inters_over_union_classes = per_class_iu(hist)
    print(f'mIoU = \t{round(np.nanmean(inters_over_union_classes) * 100, 2)}')
    if verbose:
        display_stats(cfg, test_loader.dataset.class_names, inters_over_union_classes)

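# --- reference sketch (not part of the code above) --------------------------------- #
# Minimal illustration (hypothetical shapes and weights) of the weighted ensembling in
# eval_single: each model's interpolated score map is scaled by its cfg.TEST.MODEL_WEIGHT
# entry, the scaled maps are summed, and the argmax over classes gives the final label map.
def _ensemble_sketch():
    scores_a = np.random.rand(19, 4, 4)   # (num_classes, H, W) scores from model A
    scores_b = np.random.rand(19, 4, 4)   # (num_classes, H, W) scores from model B
    weights = [0.5, 0.5]                  # e.g. cfg.TEST.MODEL_WEIGHT
    blended = weights[0] * scores_a + weights[1] * scores_b
    return np.argmax(blended.transpose(1, 2, 0), axis=2)   # (H, W) predicted label map
# ------------------------------------------------------------------------------------ #
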
def train_advent(model, trainloader, targetloader, cfg, args):
    ''' UDA training with advent '''
    # Create the model and start the training.
    # pdb.set_trace()
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # -------------------------------------------------------- #
    # codes to initialize wandb for storing logs on its cloud
    wandb.init(project='FDA_integration_to_INTRA_DA')
    wandb.config.update(args)
    for key, val in cfg.items():
        wandb.config.update({key: val})
    wandb.watch(model)
    # -------------------------------------------------------- #

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # restore_from = cfg.TRAIN.RESTORE_FROM_aux
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_aux, restore_from, device)

    # seg maps, i.e. output, level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)
    # restore_from = cfg.TRAIN.RESTORE_FROM_main
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_main, restore_from, device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet. Don't accumulate grads in discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch

        # ----------------------------------------------------------------#
        B, C, H, W = images_source.shape
        mean_images_source = SRC_IMG_MEAN.repeat(B, 1, H, W)
        mean_images = SRC_IMG_MEAN.repeat(B, 1, H, W)

        if args.FDA_mode == 'on':
            # normalize the source and target image
            images_source -= mean_images_source
            images -= mean_images
        elif args.FDA_mode == 'off':
            # keep source and target images as they are: normalization has already been
            # done in the dataset classes (GTA5, Cityscapes) when args.FDA_mode is 'off'
            images_source = images_source
            images = images
        else:
            raise KeyError()
        # ----------------------------------------------------------------#

        # debug:
        # labels = labels.numpy()
        # from matplotlib import pyplot as plt
        # import numpy as np
        # plt.figure(1), plt.imshow(labels[0]), plt.ion(), plt.colorbar(), plt.show()

        # train on source
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        # pdb.set_trace()
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
                + cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(), snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(), snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            # ----------------------------------------------------------------#
            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                current_losses = {'loss_seg_src_aux': loss_seg_src_aux,
                                  'loss_seg_src_main': loss_seg_src_main,
                                  'loss_adv_trg_aux': loss_adv_trg_aux,
                                  'loss_adv_trg_main': loss_adv_trg_main,
                                  'loss_d_aux': loss_d_aux,
                                  'loss_d_main': loss_d_main}
                print_losses(current_losses, i_iter)
                log_losses_tensorboard(writer, current_losses, i_iter)
                draw_in_tensorboard(writer, images + mean_images, i_iter,
                                    pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source + mean_images_source, i_iter,
                                    pred_src_main, num_classes, 'S')
                wandb.log({'loss': current_losses}, step=(i_iter + 1))

            # for every 2500 iterations (25 visualization intervals)
            if i_iter % (cfg.TRAIN.TENSORBOARD_VIZRATE * 25) == cfg.TRAIN.TENSORBOARD_VIZRATE * 25 - 1:
                wandb.log(
                    {'source': wandb.Image(torch.flip(images_source + mean_images_source, [1])
                                           .cpu().data[0].numpy().transpose((1, 2, 0))),
                     'target': wandb.Image(torch.flip(images + mean_images, [1])
                                           .cpu().data[0].numpy().transpose((1, 2, 0))),
                     'pseudo label': wandb.Image(np.asarray(colorize_mask(np.asarray(
                         labels.cpu().data.numpy().transpose(1, 2, 0).reshape((512, 1024)),
                         dtype=np.uint8)).convert('RGB')))},
                    step=(i_iter + 1))
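
# --- reference sketch (not part of the code above) --------------------------------- #
# prob_2_entropy and bce_loss are imported from the project's utilities. A minimal sketch
# of the standard ADVENT formulation, assuming this repo follows it: prob_2_entropy turns
# a softmax map into a per-pixel weighted self-information map (normalized by log2(C)),
# and bce_loss trains the discriminator against a constant domain label (0 = source,
# 1 = target).
def prob_2_entropy_sketch(prob):
    n, c, h, w = prob.size()
    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)

def bce_loss_sketch(y_pred, y_label):
    y_truth = torch.full_like(y_pred, y_label)   # constant 0/1 domain-label map
    return nn.BCEWithLogitsLoss()(y_pred, y_truth)
# ------------------------------------------------------------------------------------ #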