Ejemplo n.º 1
0
def main(args):
    """Build the data pipelines and run style-transfer training.

    Creates a ``GTA5DataSetLMDB`` source loader, a ``CityscapesDataSetLMDB``
    target loader and a ``CityscapesDataSetLMDB`` validation loader, then
    passes them to ``StyleTrans.train`` together with a ``SummaryWriter``.

    Args:
        args: Option namespace. This function reads ``tensorboard_log_dir``,
            ``input_size`` (a ``"W,H"`` string), ``data_dir``/``data_list``,
            ``data_dir_target``/``data_list_target``,
            ``data_dir_val``/``data_list_val``, ``batch_size`` and
            ``num_workers``. The whole object is also forwarded to
            ``StyleTrans``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    # Close the writer even when dataset construction or training raises,
    # so that it is never left open on the error path.
    try:
        # The option is written width-first, but FreeScale is given (h, w).
        w, h = map(int, args.input_size.split(','))

        # Applied jointly to image and mask for both training datasets.
        joint_transform = joint_transforms.Compose([
            joint_transforms.FreeScale((h, w)),
            joint_transforms.RandomHorizontallyFlip(),
        ])
        normalize = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # NOTE(review): source images only go through ToTensor, while target
        # and validation images are additionally normalised -- confirm this
        # asymmetry is what StyleTrans expects.
        src_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
        ])
        tgt_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        # Validation has no joint transform, so the image is rescaled here.
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        target_transform = extended_transforms.MaskToTensor()

        src_dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=src_input_transform,
            target_transform=target_transform,
        )
        src_loader = data.DataLoader(
            src_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True,
        )

        tgt_dataset = CityscapesDataSetLMDB(
            args.data_dir_target, args.data_list_target,
            joint_transform=joint_transform,
            transform=tgt_input_transform,
            target_transform=target_transform,
        )
        tgt_loader = data.DataLoader(
            tgt_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True,
        )

        val_dataset = CityscapesDataSetLMDB(
            args.data_dir_val, args.data_list_val,
            transform=val_input_transform,
            target_transform=target_transform,
        )
        # Validation keeps a fixed order and every sample.
        val_loader = data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True, drop_last=False,
        )

        style_trans = StyleTrans(args)
        style_trans.train(src_loader, tgt_loader, val_loader, writer)
    finally:
        writer.close()
Ejemplo n.º 2
0
def main(args):
    """Train the ``resnet101_ibn_a_deeplab`` segmentation network.

    Trains with plain SGD at a constant learning rate (no scheduler is
    called), logs the loss and periodic prediction figures to TensorBoard,
    and after every epoch evaluates with ``test_miou`` and saves a checkpoint
    named ``{epoch}_{mIoU*100}.pth``.

    Args:
        args: Option namespace. This function reads ``tensorboard_log_dir``,
            ``input_size`` and ``input_size_target`` (``"W,H"`` strings),
            ``data_dir``/``data_list``, ``data_dir_target``/
            ``data_list_target``, ``batch_size``, ``num_workers``,
            ``model_path_prefix``, ``n_classes``, ``learning_rate``,
            ``momentum``, ``ignore_index``, ``num_epoch``, ``show_img_freq``,
            ``print_freq`` and ``save_path_prefix``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    # Close the writer even if setup or training raises.
    try:
        w, h = map(int, args.input_size.split(','))
        w_target, h_target = map(int, args.input_size_target.split(','))

        joint_transform = joint_transforms.Compose([
            joint_transforms.FreeScale((h, w)),
            joint_transforms.RandomHorizontallyFlip(),
        ])
        normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        target_transform = extended_transforms.MaskToTensor()

        # The dataset type is chosen from the path: any '5' in ``data_dir``
        # selects the GTA5 reader, otherwise the Cityscapes reader is used.
        if '5' in args.data_dir:
            dataset_cls = GTA5DataSetLMDB
        else:
            dataset_cls = CityscapesDataSetLMDB
        dataset = dataset_cls(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
        loader = data.DataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

        # Validation images get no joint transform (no rescale, no flip).
        val_dataset = CityscapesDataSetLMDB(
            args.data_dir_target,
            args.data_list_target,
            transform=input_transform,
            target_transform=target_transform)
        val_loader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

        upsample = nn.Upsample(size=(h_target, w_target),
                               mode='bilinear',
                               align_corners=True)

        net = resnet101_ibn_a_deeplab(args.model_path_prefix,
                                      n_classes=args.n_classes)
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum)
        # Inputs are moved with ``.cuda()`` below, so place the wrapped model
        # on the GPU explicitly; DataParallel needs the module on its first
        # device. This is a no-op if the factory already returned a CUDA model.
        net = torch.nn.DataParallel(net).cuda()
        # ``reduction='sum'`` is the non-deprecated spelling of
        # ``size_average=False``: the loss is summed, not averaged.
        criterion = torch.nn.CrossEntropyLoss(reduction='sum',
                                              ignore_index=args.ignore_index)

        num_batches = len(loader)
        for epoch in range(args.num_epoch):
            # Meters are re-created so the printed averages are per-epoch.
            loss_rec = AverageMeter()
            data_time_rec = AverageMeter()
            batch_time_rec = AverageMeter()

            tem_time = time.time()
            for batch_index, (img, label, name) in enumerate(loader):
                # Global 1-based step used as the TensorBoard x-axis.
                iteration = batch_index + 1 + epoch * num_batches

                # Re-enter train mode each step (test_miou may switch modes).
                net.train()
                img = img.cuda()
                label_cuda = label.cuda()
                data_time_rec.update(time.time() - tem_time)

                output = net(img)
                loss = criterion(output, label_cuda)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                loss_rec.update(loss_value)
                writer.add_scalar('A_seg_loss', loss_value, iteration)
                batch_time_rec.update(time.time() - tem_time)
                tem_time = time.time()

                if (batch_index + 1) % args.print_freq == 0:
                    print(
                        f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                        f'Time: {batch_time_rec.avg:.2f}   '
                        f'Data: {data_time_rec.avg:.2f}   '
                        f'Loss: {loss_rec.avg:.2f}')

                if (batch_index + 1) % args.show_img_freq == 0:
                    # Plot prediction (top) against ground truth (bottom) for
                    # the first sample of the batch.
                    base_lr = optimizer.param_groups[0]["lr"]
                    prediction = torch.argmax(output, dim=1).detach()[0, ...].cpu()
                    fig, axes = plt.subplots(2, 1, figsize=(12, 14))
                    axes = axes.flat
                    axes[0].imshow(colorize_mask(prediction.numpy()))
                    axes[0].set_title(name[0])
                    axes[1].imshow(colorize_mask(label[0, ...].numpy()))
                    axes[1].set_title(f'seg_true_{base_lr:.6f}')
                    writer.add_figure('A_seg', fig, iteration)

            mean_iu = test_miou(net, val_loader, upsample,
                                './ae_seg/dataset/info.json')
            # Save the unwrapped module so the checkpoint has no
            # DataParallel "module." prefix.
            torch.save(
                net.module.state_dict(),
                os.path.join(args.save_path_prefix,
                             f'{epoch:d}_{mean_iu*100:.0f}.pth'))
    finally:
        writer.close()
Ejemplo n.º 3
0
def main(args):
    """Train a ``PSP`` segmentation network with an auxiliary loss.

    Builds the data loaders, wraps model and criterion in
    ``DataParallelModel`` / ``DataParallelCriterion``, trains with SGD under
    ``utils.LR_Scheduler``, and after every epoch evaluates with ``test_miou``
    and saves a checkpoint named ``{epoch}_{mIoU*100}.pth``.

    Args:
        args: Option namespace. This function reads ``tensorboard_log_dir``,
            ``input_size`` and ``input_size_target`` (``"W,H"`` strings),
            ``data_dir``/``data_list``, ``data_dir_target``/
            ``data_list_target``, ``batch_size``, ``num_workers``,
            ``n_classes``, ``model_path_prefix``, ``learning_rate``,
            ``momentum``, ``weight_decay``, ``lr_scheduler``, ``lr_step``,
            ``num_epoch``, ``print_freq`` and ``save_path_prefix``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    # Close the writer even if setup or training raises.
    try:
        w, h = map(int, args.input_size.split(','))
        w_target, h_target = map(int, args.input_size_target.split(','))

        joint_transform = joint_transforms.Compose([
            joint_transforms.FreeScale((h, w)),
            joint_transforms.RandomHorizontallyFlip(),
        ])
        normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        target_transform = extended_transforms.MaskToTensor()

        # The dataset type is chosen from the path: any '5' in ``data_dir``
        # selects the GTA5 reader, otherwise the Cityscapes reader is used.
        if '5' in args.data_dir:
            dataset_cls = GTA5DataSetLMDB
        else:
            dataset_cls = CityscapesDataSetLMDB
        dataset = dataset_cls(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
        loader = data.DataLoader(
            dataset, batch_size=args.batch_size,
            shuffle=True, num_workers=args.num_workers, pin_memory=True,
        )

        # Validation images get no joint transform (no rescale, no flip).
        val_dataset = CityscapesDataSetLMDB(
            args.data_dir_target, args.data_list_target,
            transform=input_transform, target_transform=target_transform,
        )
        val_loader = data.DataLoader(
            val_dataset, batch_size=args.batch_size,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
        )

        upsample = nn.Upsample(size=(h_target, w_target),
                               mode='bilinear', align_corners=True)

        net = PSP(
            nclass=args.n_classes, backbone='resnet101',
            root=args.model_path_prefix, norm_layer=BatchNorm2d,
        )

        # The head and auxiliary layer use 10x the backbone learning rate.
        params_list = [
            {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
            {'params': net.head.parameters(), 'lr': args.learning_rate * 10},
            {'params': net.auxlayer.parameters(), 'lr': args.learning_rate * 10},
        ]
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        criterion = SegmentationLosses(nclass=args.n_classes, aux=True,
                                       ignore_index=255)

        net = DataParallelModel(net).cuda()
        criterion = DataParallelCriterion(criterion).cuda()

        logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
        scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                       args.num_epoch, len(loader),
                                       logger=logger, lr_step=args.lr_step)

        # Evaluation goes through the ``Eval`` wrapper, not ``net`` itself.
        net_eval = Eval(net)

        num_batches = len(loader)
        # NOTE(review): best_pred is never updated in this function, so the
        # scheduler always receives 0.0 -- confirm whether it should track
        # the best mean IoU.
        best_pred = 0.0
        for epoch in range(args.num_epoch):
            # Meters are re-created so the printed averages are per-epoch.
            loss_rec = AverageMeter()
            data_time_rec = AverageMeter()
            batch_time_rec = AverageMeter()

            tem_time = time.time()
            for batch_index, (img, label, name) in enumerate(loader):
                # The scheduler is invoked once per batch, before the step.
                scheduler(optimizer, batch_index, epoch, best_pred)
                # Global 1-based step used as the TensorBoard x-axis.
                iteration = batch_index + 1 + epoch * num_batches

                net.train()
                img = img.cuda()
                label_cuda = label.cuda()
                data_time_rec.update(time.time() - tem_time)

                output = net(img)
                loss = criterion(output, label_cuda)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                loss_rec.update(loss_value)
                writer.add_scalar('A_seg_loss', loss_value, iteration)
                batch_time_rec.update(time.time() - tem_time)
                tem_time = time.time()

                if (batch_index + 1) % args.print_freq == 0:
                    print(
                        f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                        f'Time: {batch_time_rec.avg:.2f}   '
                        f'Data: {data_time_rec.avg:.2f}   '
                        f'Loss: {loss_rec.avg:.2f}'
                    )

            mean_iu = test_miou(net_eval, val_loader, upsample,
                                './style_seg/dataset/info.json')
            # Save the unwrapped module's weights.
            torch.save(
                net.module.state_dict(),
                os.path.join(args.save_path_prefix,
                             f'{epoch:d}_{mean_iu*100:.0f}.pth')
            )
    finally:
        writer.close()
Ejemplo n.º 4
0
def main():
    """Create the FCN8s model and train it on the source dataset.

    Options come from ``get_arguments()``. Each epoch trains over the source
    loader with SGD and a summed cross-entropy loss divided by the batch
    size, tracking a running training mean IoU. After each epoch the model
    is evaluated with ``test_miou`` on the target validation loader, and the
    weights are written to ``final_fcn.pth`` in ``args.snapshot_dir``
    whenever the score improves.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))
    n_class = args.num_classes

    # Create network
    student_net = FCN8s(n_class, args.model_path_prefix)
    student_net = torch.nn.DataParallel(student_net).cuda()

    # Per-channel means on a 0-255 scale with unit std; images are scaled
    # back to 0-255 after ToTensor to match.
    # NOTE(review): FlipChannels presumably reorders RGB to BGR for
    # Caffe-style pretrained weights -- confirm.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # Validation has no joint transform, so only the image is rescaled here.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # The dataset type is chosen from the path: any '5' in ``data_dir``
    # selects the GTA5 reader, otherwise the Cityscapes reader is used.
    if '5' in args.data_dir:
        src_dataset_cls = GTA5DataSetLMDB
    else:
        src_dataset_cls = CityscapesDataSetLMDB
    src_dataset = src_dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    optimizer = optim.SGD(student_net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # ``align_corners=False`` is what the unspecified default resolves to;
    # stating it keeps the behaviour and avoids the deprecation warning.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=False)

    # ``reduction='sum'`` is the non-deprecated spelling of
    # ``size_average=False``; the sum is divided by the batch size below.
    src_criterion = torch.nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='sum')

    num_batches = len(src_loader)
    highest = 0

    for epoch in range(args.num_epoch):
        # Meters are re-created so the printed averages are per-epoch.
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            student_net.train()
            optimizer.zero_grad()

            # train with source
            src_images, src_label, src_img_name = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = student_net(src_images)

            # Segmentation loss: summed over pixels, averaged over the batch.
            cls_loss_value = src_criterion(src_output, src_label)
            cls_loss_value = cls_loss_value / src_images.shape[0]

            cls_loss_value.backward()
            optimizer.step()

            # Training mean IoU for this batch, computed on the CPU. The
            # class count follows the model instead of a hard-coded 19.
            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')

        miu = test_miou(student_net, tgt_val_loader, upsample,
                        './dataset/info.json')
        # Only the best-scoring weights are kept; the file is overwritten.
        if miu > highest:
            torch.save(student_net.module.state_dict(),
                       osp.join(args.snapshot_dir, 'final_fcn.pth'))
            highest = miu
            print('>' * 50 + f'save highest with {miu:.2%}')
Ejemplo n.º 5
0
def main():
    """Create the DeepLabV3 model and train it for ``args.iterations`` steps.

    Options come from ``get_arguments()``. Training loops over the source
    loader repeatedly, adjusting the learning rate every step through
    ``adjust_learning_rate``. Every ``args.eval_freq`` steps the model is
    evaluated with ``test_miou`` on the target validation loader and a
    checkpoint named ``{iter}_{mIoU*1000}.pth`` is saved when the score
    improves. The function returns once the step counter reaches
    ``args.iterations``.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))
    n_class = args.num_classes

    # Create network
    if args.bn_sync:
        print('Using Sync BN')
        # Swap the module-level norm layer before the model is constructed.
        deeplabv3.BatchNorm2d = partial(InPlaceABNSync, activation='none')
    net = get_deeplabV3(n_class, args.model_path_prefix)
    if not args.bn_sync:
        # NOTE(review): ``net.train()`` is called on every step below; if
        # freeze_bn() works by putting BN layers in eval mode, that call may
        # undo it -- confirm.
        net.freeze_bn()
    net = torch.nn.DataParallel(net)
    net = net.cuda()

    # Per-channel means on a 0-255 scale with unit std; images are scaled
    # back to 0-255 after ToTensor to match.
    mean_std = ([104.00698793, 116.66876762, 122.67891434], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # Validation has no joint transform, so only the image is rescaled here.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # The dataset type is chosen from the path: any '5' in ``data_dir``
    # selects the GTA5 reader, otherwise the Cityscapes reader is used.
    if '5' in args.data_dir:
        src_dataset_cls = GTA5DataSetLMDB
    else:
        src_dataset_cls = CityscapesDataSetLMDB
    src_dataset = src_dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # freeze bn: plain BatchNorm2d parameters are excluded from optimisation.
    for module in net.module.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = False
    optimizer = optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad,
                             net.module.parameters()),
            'lr': args.learning_rate
        }],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # ``align_corners=False`` is what the unspecified default resolves to;
    # stating it keeps the behaviour and avoids the deprecation warning.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=False)

    criterion = CriterionDSN(ignore_index=255)

    max_iter = args.iterations
    i_iter = 0
    highest_miu = 0

    # Each pass of the outer loop is one sweep over the source loader; the
    # function returns from inside once the iteration budget is used up.
    while True:
        # Meters are re-created so the printed averages cover one sweep.
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for src_data in src_loader:
            i_iter += 1
            lr = adjust_learning_rate(args, optimizer, i_iter, max_iter)
            net.train()
            optimizer.zero_grad()

            # train with source
            src_images, src_label, src_img_name = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = net(src_images)

            # Segmentation loss over the model outputs.
            cls_loss_value = criterion(src_output, src_label)
            cls_loss_value.backward()
            optimizer.step()

            # Training mean IoU uses only the first output, resized to the
            # input resolution. It is metric-only, so no graph is recorded;
            # ``interpolate`` replaces the deprecated ``upsample``.
            with torch.no_grad():
                resized_output = torch.nn.functional.interpolate(
                    input=src_output[0],
                    size=(h, w),
                    mode='bilinear',
                    align_corners=True)
                _, predict_labels = torch.max(resized_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            # The class count follows the model instead of a hard-coded 19.
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if i_iter % args.print_freq == 0:
                print(
                    f'Iter: [{i_iter}/{max_iter}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')
            if i_iter % args.eval_freq == 0:
                miu = test_miou(net, tgt_val_loader, upsample,
                                './dataset/info.json')
                if miu > highest_miu:
                    torch.save(
                        net.module.state_dict(),
                        osp.join(args.snapshot_dir,
                                 f'{i_iter:d}_{miu*1000:.0f}.pth'))
                    highest_miu = miu
                print(f'>>>>>>>>>Learning Rate {lr}<<<<<<<<<')
            # ``>=`` instead of ``==`` so the loop still ends when the
            # counter starts beyond the target (e.g. a non-positive budget).
            if i_iter >= max_iter:
                return