# Example #1
# 0
def __main__(args):
    """Evaluate a pretrained PSPNet on the Cityscapes validation split.

    Relies on module-level names defined elsewhere in this file: ``gpu0``
    (device index passed to ``.cuda``), ``pspnet_path`` (pretrained weights)
    and the ``validate`` helper.

    Args:
        args: dict-like config; ``args['val_batch_size']`` sets the
            validation batch size and the whole object is forwarded to
            ``validate``.
    """
    # Initialize the pretrained network.
    pspnet = PSPNet(n_classes=cityscapes.num_classes).cuda(gpu0)
    pspnet.load_pretrained_model(model_path=pspnet_path)

    # Per-channel means on the 0-255 scale; a std of 1.0 means the Normalize
    # step only subtracts the mean.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])
    val_input_transform = standard_transforms.Compose([
        # presumably RGB -> BGR to match the Caffe-style means; verify
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        # ToTensor yields [0, 1]; the means above are on the 0-255 scale.
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = standard_transforms.Compose(
        [extended_transforms.MaskToTensor()])

    # Inverse of val_input_transform, used for visualization. Without the
    # /255 and the channel flip, ToPILImage would receive 0-255 floats in the
    # flipped channel order (it expects [0, 1] floats), producing
    # out-of-range, colour-swapped images.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels(),
    ])

    visualize = standard_transforms.ToTensor()
    val_set = cityscapes.CityScapes('val',
                                    transform=val_input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=False)
    validate(pspnet, val_loader, cityscapes.num_classes, args,
             restore_transform, visualize)
def main():
    """Train FCN8s on the fine-annotated Cityscapes split, optionally resuming.

    Configuration comes from names defined elsewhere in this file: the
    ``args`` dict, ``ckpt_path`` / ``exp_name`` (checkpoint directory) and
    the ``train`` / ``validate`` / ``check_mkdir`` helpers.

    When ``args['snapshot']`` is non-empty, the network and optimizer state
    are restored from ``<ckpt_path>/<exp_name>/<snapshot>`` and
    ``.../opt_<snapshot>``. The starting epoch and best-so-far metrics are
    parsed from the snapshot file name: after splitting on ``'_'``, the
    values are read from positions 1, 3, 5, 7, 9 and 11.
    """
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    snapshot = args['snapshot']
    if not snapshot:
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0,
                               'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        fields = snapshot.split('_')
        curr_epoch = int(fields[1]) + 1
        args['best_record'] = {'epoch': int(fields[1]),
                               'val_loss': float(fields[3]),
                               'acc': float(fields[5]),
                               'acc_cls': float(fields[7]),
                               'mean_iu': float(fields[9]),
                               'fwavacc': float(fields[11])}

    net.train()

    # Per-channel means on the 0-255 scale; std 1.0 => mean subtraction only.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    # Resize target chosen so that the crop of size input_size covers 87.5%
    # of it (assumes joint_transforms.Scale resizes by one side — confirm).
    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of input_transform, for visualization.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform, transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Biases get twice the base LR and no weight decay; all other params get
    # the base LR with weight decay.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name.endswith('bias')],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if not name.endswith('bias')],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], betas=(args['momentum'], 0.999))

    if snapshot:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # Override the restored LRs with the ones configured for this run.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; `with` guarantees the handle is closed.
    log_file = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    with open(log_file, 'w') as f:
        f.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
# Example #3
# 0
def main():
    """Create the model and start the training.

    Trains an FCN8s "student" network with summed pixel-wise cross-entropy
    (divided by batch size) on the source dataset. After every epoch it
    evaluates mIoU on the target validation set and saves the best weights
    to ``<args.snapshot_dir>/final_fcn.pth``.

    All options come from ``get_arguments()``. ``FCN8s``, the
    ``*DataSetLMDB`` datasets, ``AverageMeter``, ``_evaluate`` and
    ``test_miou`` are project helpers defined elsewhere.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))
    n_class = args.num_classes

    # Create network
    student_net = torch.nn.DataParallel(
        FCN8s(n_class, args.model_path_prefix)).cuda()

    # Per-channel means on the 0-255 scale; std 1.0 => mean subtraction only.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # The target validation set gets no joint transform, so the image is
    # resized here instead.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # NOTE(review): the source dataset type is picked by a substring test on
    # the path ('5' -> GTA5); a Cityscapes path containing '5' would be
    # misclassified. An explicit option would be safer.
    src_dataset_cls = (GTA5DataSetLMDB if '5' in args.data_dir
                       else CityscapesDataSetLMDB)
    src_dataset = src_dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    optimizer = optim.SGD(student_net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # align_corners=False is PyTorch's default since 0.4; stating it keeps
    # the results unchanged and silences the warning.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=False)

    # Summed over pixels; divided by the batch size inside the loop.
    src_criterion = torch.nn.CrossEntropyLoss(ignore_index=255,
                                              size_average=False)

    num_batches = len(src_loader)
    highest = 0

    for epoch in range(args.num_epoch):
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, (src_images, src_label, _) in enumerate(src_loader):
            student_net.train()
            optimizer.zero_grad()

            # train with source
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = student_net(src_images)

            # Segmentation loss, normalized per image.
            cls_loss_value = (src_criterion(src_output, src_label)
                              / src_images.shape[0])
            cls_loss_value.backward()
            optimizer.step()

            # Training-batch mIoU on the raw network output.
            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')

        miu = test_miou(student_net, tgt_val_loader, upsample,
                        './dataset/info.json')
        if miu > highest:
            torch.save(student_net.module.state_dict(),
                       osp.join(args.snapshot_dir, 'final_fcn.pth'))
            highest = miu
            print('>' * 50 + f'save highest with {miu:.2%}')
# Example #4
# 0
def main(args):
    """Self-train a segmentation net using pseudo-labels from generated images.

    Steps, as implemented below:
      1. Build a target loader (images scaled to (h, w), normalized to
         [-1, 1]) and a validation loader whose normalization matches
         ``args.seg_net``.
      2. Load two copies of the segmentation net from ``args.resume``:
         ``net`` (trained) and ``net_static`` (frozen), plus a frozen
         generator from ``define_G()``.
      3. Report mIoU of ``net`` directly and of ``generator -> net``.
      4. For each target batch, ``net_static`` predicts labels on the
         generator output. ``net`` is then trained with cross-entropy on the
         original image against those labels.
      5. Every ``args.checkpoint_freq`` iterations, evaluate and keep the best
         checkpoint. Reload and report it at the end.

    Args:
        args: parsed options (``seg_net``, ``input_size``, ``resume``,
            ``n_classes``, ``percent``, ``save_path_prefix``, ...).

    Raises:
        ValueError: if ``args.seg_net`` is not ``'fcn'`` or ``'deeplab_ibn'``.
    """
    if args.seg_net not in ('fcn', 'deeplab_ibn'):
        raise ValueError(f'unsupported seg_net: {args.seg_net!r}')

    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    # Generator-side normalization: maps [0, 1] images to [-1, 1].
    gen_mean, gen_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    # Segmentation-net input normalization constants.
    fcn_mean = [103.939, 116.779, 123.68]  # 0-255 scale, flipped channels
    imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    tgt_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(gen_mean, gen_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    if args.seg_net == 'fcn':
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            extended_transforms.FlipChannels(),
            standard_transforms.ToTensor(),
            standard_transforms.Lambda(lambda x: x.mul_(255)),
            standard_transforms.Normalize(fcn_mean, [1.0, 1.0, 1.0]),
        ])
    else:
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(imagenet_mean, imagenet_std),
        ])

    tgt_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        joint_transform=joint_transform,
        transform=tgt_input_transform,
        target_transform=target_transform,
    )
    tgt_loader = data.DataLoader(tgt_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=True)
    val_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear',
                           align_corners=True)

    if args.seg_net == 'fcn':
        net = FCN8s(args.n_classes, pretrained=False)
        net_static = FCN8s(args.n_classes, pretrained=False)
        file_name = os.path.join(args.resume, args.fcn_name)
    else:  # 'deeplab_ibn'
        # Previously this branch only created an unused ``deeplab`` local,
        # so loading the weights below failed with UnboundLocalError.
        net = resnet101_ibn_a_deeplab()
        net_static = resnet101_ibn_a_deeplab()
        file_name = os.path.join(args.resume, 'deeplab_ibn.pth')
    # load_state_dict copies the tensors, so one loaded dict can seed both.
    seg_state = torch.load(file_name)
    net.load_state_dict(seg_state)
    net_static.load_state_dict(seg_state)
    for param in net_static.parameters():
        param.requires_grad = False

    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate,
                                args.momentum)
    net = torch.nn.DataParallel(net.cuda())
    net_static = torch.nn.DataParallel(net_static.cuda())
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    gen_model = define_G()
    gen_model.load_state_dict(
        torch.load(os.path.join(args.resume, args.gen_name)))
    gen_model.eval()
    for param in gen_model.parameters():
        param.requires_grad = False
    gen_model = torch.nn.DataParallel(gen_model.cuda())

    def seg_normalize(x):
        """Map a [0, 1] NCHW image batch to the segmentation net's input space."""
        if args.seg_net == 'fcn':
            # Reverse the first three channels, rescale to 0-255 and subtract
            # the per-channel means (mirrors the fcn val_input_transform).
            mean = x.new_tensor(fcn_mean).view(1, -1, 1, 1)
            return x[:, [2, 1, 0]].mul(255.0) - mean
        # The previous version rebound a loop variable here and returned x
        # unchanged, so no normalization was ever applied.
        mean = x.new_tensor(imagenet_mean).view(1, -1, 1, 1)
        std = x.new_tensor(imagenet_std).view(1, -1, 1, 1)
        return (x - mean) / std

    def de_normalize(x, mean=gen_mean, std=gen_std):
        """Invert the [-1, 1] normalization of an NCHW batch: x * std + mean."""
        # Adds the mean (the previous version added std; identical only
        # because both defaults are 0.5).
        mean_t = x.new_tensor(mean).view(1, -1, 1, 1)
        std_t = x.new_tensor(std).view(1, -1, 1, 1)
        return x * std_t + mean_t

    # ###################################################
    # direct test with gen
    # ###################################################
    print('Direct Test')
    test_miou(net, val_loader, upsample, './dataset/info.json')
    direct_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(gen_mean, gen_std),
    ])
    val_dataset_direct = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=direct_input_transform,
        target_transform=target_transform,
    )
    val_loader_direct = data.DataLoader(val_dataset_direct,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True,
                                        drop_last=False)

    class NewModel(object):
        """Callable chaining generator -> de-normalize -> seg-normalize -> seg net."""

        def __init__(self, gen_net, val_net):
            self.gen_net = gen_net
            self.val_net = val_net

        def __call__(self, x):
            x = de_normalize(self.gen_net(x))
            return self.val_net(seg_normalize(x))

        def eval(self):
            self.gen_net.eval()
            self.val_net.eval()

    new_model = NewModel(gen_model, net)
    print('Test with Gen')
    test_miou(new_model, val_loader_direct, upsample, './dataset/info.json')

    num_batches = len(tgt_loader)
    highest = 0
    best_path = os.path.join(args.save_path_prefix, 'cityscapes_best_fcn.pth')
    softmax = torch.nn.Softmax(dim=1)

    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(tgt_loader):
            iteration = batch_index + 1 + epoch * num_batches

            net.train()
            net_static.eval()  # the frozen teacher always runs in eval mode

            img, _, _ = batch_data
            img = img.cuda()
            data_time_rec.update(time.time() - tem_time)

            with torch.no_grad():
                gen_output = gen_model(img)
                gen_seg_output_logits = net_static(
                    seg_normalize(de_normalize(gen_output)))
            ori_seg_output_logits = net(seg_normalize(de_normalize(img)))

            # Pseudo-labels: teacher argmax and its softmax confidence.
            max_value, label = torch.max(softmax(gen_seg_output_logits), dim=1)
            label_mask = torch.zeros_like(label, dtype=torch.uint8)
            # Iterate over every class the teacher can predict (was a
            # hard-coded 19).
            for tem_label in range(gen_seg_output_logits.shape[1]):
                tem_mask = label == tem_label
                if torch.sum(tem_mask) < 5:
                    continue
                value_vec = max_value[tem_mask]
                # Clamp k to [1, N]: topk with k == 0 returns an empty tensor
                # and the [0][0] lookup below would raise IndexError.
                k = min(value_vec.shape[0],
                        max(1, int(args.percent * value_vec.shape[0])))
                # NOTE(review): topk is sorted descending, so [0][0] is the
                # class maximum and `max_value > large_value` is never true
                # for this class; the mask therefore stays empty. If a
                # percentile threshold was intended, [0][-1] may be meant —
                # confirm the intended semantics of args.percent.
                large_value = torch.topk(value_vec, k)[0][0]
                large_mask = max_value > large_value
                label_mask = label_mask | (tem_mask & large_mask)
            label[label_mask] = 255

            loss = criterion(ori_seg_output_logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')
            if iteration % args.checkpoint_freq == 0:
                mean_iu = test_miou(net,
                                    val_loader,
                                    upsample,
                                    './dataset/info.json',
                                    print_results=False)
                if mean_iu > highest:
                    torch.save(net.module.state_dict(), best_path)
                    highest = mean_iu
                    print(f'save fcn model with {mean_iu:.2%}')

    print(('-' * 100 + '\n') * 3)
    print('>' * 50 + 'Final Model')
    # A checkpoint exists only if some evaluation improved on 0; otherwise
    # report the model as it is instead of crashing on a missing file.
    if os.path.exists(best_path):
        net.module.load_state_dict(torch.load(best_path))
    else:
        print(f'no checkpoint found at {best_path}; evaluating last model')
    test_miou(net, val_loader, upsample, './dataset/info.json')

    writer.close()
# Example #5
# 0
def main():
    """Create the model and start the training.

    Trains DeepLabV3 (built by ``get_deeplabV3``) on the source dataset for
    ``args.iterations`` iterations. The learning rate is set every step by
    ``adjust_learning_rate``. Every ``args.eval_freq`` iterations the model
    is scored on the target validation set; whenever mIoU improves, the
    weights are saved to ``<args.snapshot_dir>/<iter>_<round(1000*mIoU)>.pth``.

    Raises:
        ValueError: if the source loader yields no batches (the iteration
            loop could otherwise never terminate).
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))
    n_class = args.num_classes

    # Create network
    if args.bn_sync:
        print('Using Sync BN')
        # Rebinds BatchNorm2d inside the deeplabv3 module before the net is
        # built, presumably so its layers use InPlaceABNSync instead.
        deeplabv3.BatchNorm2d = partial(InPlaceABNSync, activation='none')
    net = get_deeplabV3(n_class, args.model_path_prefix)
    if not args.bn_sync:
        net.freeze_bn()
    net = torch.nn.DataParallel(net).cuda()

    # Per-channel means on the 0-255 scale; std 1.0 => mean subtraction only.
    mean_std = ([104.00698793, 116.66876762, 122.67891434], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # The target validation set gets no joint transform, so the image is
    # resized here instead.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # NOTE(review): dataset type chosen by a substring test on the path
    # ('5' -> GTA5); an explicit option would be safer.
    src_dataset_cls = (GTA5DataSetLMDB if '5' in args.data_dir
                       else CityscapesDataSetLMDB)
    src_dataset = src_dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Freeze the parameters of every torch.nn.BatchNorm2d module and give
    # the optimizer only the parameters that remain trainable.
    for module in net.module.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = False
    trainable_params = [p for p in net.module.parameters() if p.requires_grad]
    optimizer = optim.SGD(
        [{'params': trainable_params, 'lr': args.learning_rate}],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # align_corners=False is PyTorch's default since 0.4; stating it keeps
    # the results unchanged and silences the warning.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=False)

    criterion = CriterionDSN(ignore_index=255)

    if len(src_loader) == 0:
        raise ValueError('source loader yields no batches; '
                         'training could never reach args.iterations')
    max_iter = args.iterations
    i_iter = 0
    highest_miu = 0

    while True:
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for src_images, src_label, _ in src_loader:
            i_iter += 1
            lr = adjust_learning_rate(args, optimizer, i_iter, max_iter)
            net.train()
            optimizer.zero_grad()

            # train with source
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = net(src_images)

            # Segmentation Loss
            cls_loss_value = criterion(src_output, src_label)
            cls_loss_value.backward()
            optimizer.step()

            # Training mIoU from the first network output, upsampled to the
            # input size. Metrics need no autograd graph.
            with torch.no_grad():
                main_logits = torch.nn.functional.interpolate(
                    input=src_output[0],
                    size=(h, w),
                    mode='bilinear',
                    align_corners=True)
                _, predict_labels = torch.max(main_logits, 1)
            lbl_pred = predict_labels.cpu().numpy()
            lbl_true = src_label.cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if i_iter % args.print_freq == 0:
                print(
                    f'Iter: [{i_iter}/{max_iter}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')
            if i_iter % args.eval_freq == 0:
                miu = test_miou(net, tgt_val_loader, upsample,
                                './dataset/info.json')
                if miu > highest_miu:
                    torch.save(
                        net.module.state_dict(),
                        osp.join(args.snapshot_dir,
                                 f'{i_iter:d}_{miu*1000:.0f}.pth'))
                    highest_miu = miu
                print(f'>>>>>>>>>Learning Rate {lr}<<<<<<<<<')
            # `>=` also terminates when max_iter <= 0 (`==` never would).
            if i_iter >= max_iter:
                return
def main():
    """Fine-tune a Cityscapes-pretrained PSPNet for binary KITTI segmentation.

    Loads Caffe PSPNet weights and freezes every pretrained parameter. It then
    replaces ``cbr_final``, ``dropout`` and ``classification`` with freshly
    constructed (trainable) layers, the classifier sized for
    ``kitti_binary.num_classes``. Training is delegated to the ``train``
    helper with a BCE-with-logits loss.

    Reads the module-level ``args`` dict plus ``ckpt_path`` / ``exp_name``.
    When ``args['snapshot']`` is non-empty, network and optimizer state are
    restored. The best-so-far record is parsed from the snapshot file name:
    after splitting on ``'_'``, the values are read from positions 1, 3, 5
    and 7.
    """
    net = PSPNet(19)
    net.load_pretrained_model(
        model_path='./Caffe-PSPNet/pspnet101_cityscapes.caffemodel')
    # Freeze the pretrained weights; only the layers replaced below train.
    for param in net.parameters():
        param.requires_grad = False
    net.cbr_final = conv2DBatchNormRelu(4096, 128, 3, 1, 1, False)
    net.dropout = nn.Dropout2d(p=0.1, inplace=True)
    net.classification = nn.Conv2d(128, kitti_binary.num_classes, 1, 1, 0)

    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters()
                                 if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    snapshot = args['snapshot']
    if not snapshot:
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'accu': 0
        }
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        fields = snapshot.split('_')
        args['best_record'] = {
            'epoch': int(fields[1]),
            'iter': int(fields[3]),
            'val_loss': float(fields[5]),
            'accu': float(fields[7])
        }
    net.cuda(args['gpu']).train()

    # Per-channel means on the 0-255 scale; std 1.0 => mean subtraction only.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = kitti_binary.KITTI(mode='train',
                                   joint_transform=train_joint_transform,
                                   transform=train_input_transform,
                                   target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = kitti_binary.KITTI(mode='val',
                                 transform=val_input_transform,
                                 target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=8,
                            shuffle=False)

    # pos_weight > 1 slightly up-weights the positive class.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 1.05)).cuda(
        args['gpu'])

    # Biases: 2x base LR, no weight decay. Everything else: base LR + decay.
    bias_params = [param for name, param in net.named_parameters()
                   if name.endswith('bias')]
    weight_params = [param for name, param in net.named_parameters()
                     if not name.endswith('bias')]
    optimizer = optim.SGD([
        {'params': bias_params, 'lr': 2 * args['lr']},
        {'params': weight_params, 'lr': args['lr'],
         'weight_decay': args['weight_decay']},
    ], momentum=args['momentum'], nesterov=True)

    if snapshot:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # Override the restored LRs with the ones configured for this run.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; `with` guarantees the handle is closed.
    log_file = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    with open(log_file, 'w') as f:
        f.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, args, val_loader)