コード例 #1
0
def main():
    """Create the model and start the training.

    Configuration is read from the module-level ``args`` together with the
    ``SOURCE_DATA``, ``NUM_CLASSES``, ``LABEL_LIST_PATH`` and ``IMG_MEAN``
    globals.

    Only the single-output path is trained here: the segmentation network is
    optimised with a segmentation loss on source images and an adversarial
    loss on target images, and ``model_D2`` is trained to tell the two domains
    apart.  ``model_D1`` and the ``*_value1`` counters are created and logged
    but ``model_D1`` is never run forward in this function.

    Every ``args.save_pred_every`` iterations the model is scored with
    ``proceed_test`` and a checkpoint is written into the logger directory;
    when the score improves the checkpoint is also copied to
    ``model_best.pth.tar``.

    Raises:
        ValueError: if ``args.model`` or ``SOURCE_DATA`` is not supported.
    """

    # NOTE(review): the first number is used as the width further down
    # (nn.Upsample receives (input_size[1], input_size[0]) as its (H, W)
    # size), so the strings are presumably "W,H" -- confirm against the
    # argument parser.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        logger.info("adopting Deeplabv2 base model..")
        model = Res_Deeplab(num_classes=args.num_classes, multi_scale=False)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy the pretrained tensors into the model's state dict, dropping
        # the first dotted component of each key (e.g.
        # "Scale.layer5.conv2d_list.3.weight" -> "layer5...").  The layer5
        # tensors are skipped when training with 19 classes.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.model == "FCN8S":
        logger.info("adopting FCN8S base model..")
        from pytorchgo.model.MyFCN8s import MyFCN8s
        model = MyFCN8s(n_class=NUM_CLASSES)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda()

    model_D2.train()
    model_D2.cuda()

    # The datasets are asked for exactly as many samples as the whole run
    # consumes, so the loaders below are iterated once and never restarted.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    if SOURCE_DATA == "GTA5":
        source_dataset = GTA5DataSet(
            args.data_dir,
            args.data_list,
            max_iters=max_iters,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN)
    elif SOURCE_DATA == "SYNTHIA":
        source_dataset = SynthiaDataSet(
            args.data_dir,
            args.data_list,
            LABEL_LIST_PATH,
            max_iters=max_iters,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN)
    else:
        raise ValueError("unsupported SOURCE_DATA: {}".format(SOURCE_DATA))

    trainloader = data.DataLoader(source_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_labels(d_out, value):
        """Return a CUDA tensor shaped like ``d_out`` filled with ``value``."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(value)).cuda()

    best_mIoU = 0

    model_summary([model, model_D1, model_D2])
    optimizer_summary([optimizer, optimizer_D1, optimizer_D2])

    for i_iter in tqdm(range(args.num_steps_stop),
                       total=args.num_steps_stop,
                       desc="training"):

        # Per-iteration running losses, used only for logging.  The "1"
        # variants stay at zero because only the second head is trained.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        lr_D1 = adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-batches before the
        # single optimizer step below; each loss is scaled accordingly.
        for sub_i in range(args.iter_size):

            ######################### train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            # The built-in next() works on Python 2 and 3; the iterator's
            # .next() method exists on Python 2 only.
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = loss_calc(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value2 += loss_seg2.data.cpu().numpy()[0] / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _, _ = batch
            images = Variable(images).cuda()

            pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            D_out2 = model_D2(F.softmax(pred_target2))

            # Target predictions are scored against the *source* label so
            # that the segmentation network is pushed to fool D.
            loss_adv_target2 = bce_loss(D_out2,
                                        domain_labels(D_out2, source_label))

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            )[0] / args.iter_size

            ################################## train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached: no gradient reaches the generator)
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2))

            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, source_label))

            # Halved because D sees one source and one target batch.
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()[0]

            # train with target
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, target_label))

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 100 == 0:
            logger.info(
                'iter = {}/{},loss_seg1 = {:.3f} loss_seg2 = {:.3f} loss_adv1 = {:.3f}, loss_adv2 = {:.3f} loss_D1 = {:.3f} loss_D2 = {:.3f}, lr={:.7f}, lr_D={:.7f}, best miou16= {:.5f}'
                .format(i_iter, args.num_steps_stop, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2,
                        lr, lr_D1, best_mIoU))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            logger.info("saving snapshot.....")
            cur_miou16 = proceed_test(model, input_size)
            is_best = best_mIoU < cur_miou16
            if is_best:
                best_mIoU = cur_miou16
            checkpoint_path = osp.join(logger.get_logger_dir(),
                                       'checkpoint.pth.tar')
            torch.save(
                {
                    'iteration': i_iter,
                    'optim_state_dict': optimizer.state_dict(),
                    'optim_D1_state_dict': optimizer_D1.state_dict(),
                    'optim_D2_state_dict': optimizer_D2.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'model_D1_state_dict': model_D1.state_dict(),
                    'model_D2_state_dict': model_D2.state_dict(),
                    # NOTE(review): despite the key name this stores the
                    # score of the current snapshot, not the best so far.
                    'best_mean_iu': cur_miou16,
                }, checkpoint_path)
            if is_best:
                import shutil
                shutil.copy(
                    checkpoint_path,
                    osp.join(logger.get_logger_dir(), 'model_best.pth.tar'))

        if i_iter >= args.num_steps_stop - 1:
            break
コード例 #2
0
def main():
    """Create the model and start the training.

    Configuration is read from the module-level ``args`` and ``IMG_MEAN``.

    The segmentation network returns two predictions per image.  Both are
    trained with a segmentation loss on source images and an adversarial loss
    on target images, while the discriminators ``model_D1`` / ``model_D2``
    learn to separate source predictions from target predictions.

    Model and discriminator weights are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and once more when
    ``args.num_steps_stop`` is reached, at which point training stops.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """

    # NOTE(review): the first number is used as the width further down
    # (nn.Upsample receives (input_size[1], input_size[0]) as its (H, W)
    # size), so the strings are presumably "W,H" -- confirm against the
    # argument parser.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy the pretrained tensors into the model's state dict, dropping
        # the first dotted component of each key (e.g.
        # "Scale.layer5.conv2d_list.3.weight" -> "layer5...").  The layer5
        # tensors are skipped when training with 19 classes.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        # Previously this fell through and failed later with a NameError on
        # ``model``; fail early with a clear message instead.
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # The datasets are asked for exactly as many samples as the whole run
    # consumes, so the loaders below are iterated once and never restarted.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(SynthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=max_iters,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(CamvidDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_labels(d_out, value):
        """Return a tensor on ``args.gpu`` shaped like ``d_out`` filled with ``value``."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(value)).cuda(args.gpu)

    def save_snapshot(tag):
        """Write the generator and both discriminators, named after ``tag``."""
        prefix = 'GTA5_' + str(tag)
        torch.save(model.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '.pth'))
        torch.save(model_D1.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '_D1.pth'))
        torch.save(model_D2.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '_D2.pth'))

    for i_iter in range(args.num_steps):

        # Per-iteration running losses, used only for the progress print.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-batches before the
        # single optimizer step below; each loss is scaled accordingly.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            # The built-in next() works on Python 2 and 3; the iterator's
            # .next() method exists on Python 2 only.
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is what softmax picks implicitly for the 4-D output of
            # the bilinear upsampling; spelling it out avoids the
            # implicit-dimension deprecation warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Target predictions are scored against the *source* label so
            # that the segmentation network is pushed to fool D.
            loss_adv_target1 = bce_loss(D_out1,
                                        domain_labels(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2,
                                        domain_labels(D_out2, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached: no gradient reaches the generator)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, source_label))

            # Halved because each D sees one source and one target batch.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Early stop: the final files are named after args.num_steps
            # even though training ends at args.num_steps_stop.
            print('save model ...')
            save_snapshot(args.num_steps)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_snapshot(i_iter)
コード例 #3
0
def main():
    """Create the model and start the training.

    Trains an ``FCN8s`` network (wrapped in ``DataParallel``) on the source
    dataset with a summed cross-entropy loss divided by the batch size.  After
    every epoch the network is scored with ``test_miou`` on the target
    validation loader, and its weights are saved to
    ``<args.snapshot_dir>/final_fcn.pth`` whenever the score beats the best
    one seen so far.
    """
    # Local import: ``os`` is only needed here to prepare the output directory.
    import os

    args = get_arguments()

    # Make sure the snapshot directory exists up front; otherwise the first
    # torch.save() would fail only after a full epoch of training.
    os.makedirs(args.snapshot_dir, exist_ok=True)

    w, h = map(int, args.input_size.split(','))

    w_target, h_target = map(int, args.input_size_target.split(','))

    # Create network
    student_net = FCN8s(args.num_classes, args.model_path_prefix)
    student_net = torch.nn.DataParallel(student_net)

    student_net = student_net.cuda()

    # Per-channel (mean, std) applied after the image has been channel-flipped
    # and rescaled to 0..255 -- presumably Caffe-style BGR means; confirm.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # The source dataset class is picked from the data directory name.
    if 'syn' in args.data_dir:
        src_dataset = SynthiaDataSet(
            args.data_dir, args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    else:
        src_dataset = Cityscapes16DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    src_loader = data.DataLoader(
        src_dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True
    )

    # No joint transform here: validation labels are not resized.
    tgt_val_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_target, args.data_list_target,
        transform=val_input_transform, target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True,
    )

    optimizer = optim.SGD(
        student_net.parameters(), lr=args.learning_rate,
        momentum=args.momentum, weight_decay=args.weight_decay
    )

    # Passed to test_miou together with the validation loader.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear')

    # Summed (not averaged) loss; it is divided by the batch size below.
    src_criterion = torch.nn.CrossEntropyLoss(
        ignore_index=255, size_average=False
    )

    num_batches = len(src_loader)
    highest = 0

    for epoch in range(args.num_epoch):

        # Fresh running averages for every epoch.
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            student_net.train()
            optimizer.zero_grad()

            # train with source
            src_images, src_label, _ = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time()-tem_time)

            src_output = student_net(src_images)

            # Segmentation Loss
            cls_loss_value = src_criterion(src_output, src_label)
            cls_loss_value /= src_images.shape[0]

            total_loss = cls_loss_value
            total_loss.backward()
            optimizer.step()

            # Training-batch mean IoU, for logging only.
            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, args.num_classes)

            cls_loss_rec.update(cls_loss_value.detach_().item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time()-tem_time)
            tem_time = time.time()

            if (batch_index+1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}'
                )

        miu = test_miou(student_net, tgt_val_loader,
                        upsample, './dataset/info_synthia.json', num_classes=args.num_classes)
        if miu > highest:
            # Save the wrapped module's weights, not the DataParallel wrapper's.
            torch.save(student_net.module.state_dict(), osp.join(
                args.snapshot_dir, 'final_fcn.pth'))
            print(f'save highest with {miu:.2%}')
            highest = miu
コード例 #4
0
ファイル: train_gen_seg_synthia.py プロジェクト: LcDog/APL
def main(args):
    """Build the data loaders and run ``StyleTrans`` training.

    Args:
        args: parsed options.  This function reads ``tensorboard_log_dir``,
            ``input_size`` (a ``"W,H"`` string), ``batch_size``,
            ``num_workers`` and the ``data_dir*`` / ``data_list*`` entries for
            the source, target and validation sets; the whole object is also
            handed to ``StyleTrans``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    # Close the writer even when setup or training raises, so it is not
    # leaked on failure.
    try:
        w, h = map(int, args.input_size.split(','))

        joint_transform = joint_transforms.Compose([
            joint_transforms.FreeScale((h, w)),
            joint_transforms.RandomHorizontallyFlip(),
        ])
        normalize = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # NOTE(review): source images are not normalized here, unlike the
        # target and validation images -- presumably handled inside
        # StyleTrans; confirm.
        src_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
        ])
        tgt_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        target_transform = extended_transforms.MaskToTensor()

        src_dataset = SynthiaDataSet(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=src_input_transform,
            target_transform=target_transform,
        )
        src_loader = data.DataLoader(src_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True,
                                     drop_last=True)
        tgt_dataset = Cityscapes16DataSetLMDB(
            args.data_dir_target,
            args.data_list_target,
            joint_transform=joint_transform,
            transform=tgt_input_transform,
            target_transform=target_transform,
        )
        tgt_loader = data.DataLoader(tgt_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True,
                                     drop_last=True)

        # Validation keeps every sample and a fixed order.
        val_dataset = Cityscapes16DataSetLMDB(
            args.data_dir_val,
            args.data_list_val,
            transform=val_input_transform,
            target_transform=target_transform,
        )
        val_loader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     pin_memory=True,
                                     drop_last=False)

        style_trans = StyleTrans(args, info_json='./dataset/info_synthia.json')
        style_trans.train(src_loader, tgt_loader, val_loader, writer)
    finally:
        writer.close()
コード例 #5
0
def _save_adversarial_snapshot(trainer, snapshot_dir, tag):
    """Write the state dicts of ``trainer.G``, ``trainer.D1`` and ``trainer.D2``.

    Files are named ``GTA5_<tag>.pth``, ``GTA5_<tag>_D1.pth`` and
    ``GTA5_<tag>_D2.pth`` inside ``snapshot_dir``.
    """
    # NOTE(review): the 'GTA5_' prefix is kept verbatim even though main()
    # trains on SynthiaDataSet -- other scripts may rely on these file names.
    stem = 'GTA5_' + str(tag)
    torch.save(trainer.G.state_dict(), osp.join(snapshot_dir, stem + '.pth'))
    torch.save(trainer.D1.state_dict(),
               osp.join(snapshot_dir, stem + '_D1.pth'))
    torch.save(trainer.D2.state_dict(),
               osp.join(snapshot_dir, stem + '_D2.pth'))


def main():
    """Create the model and start the training.

    Reads its whole configuration from the module-level ``args`` and rewrites
    ``args.input_size``, ``args.crop_size``, ``args.input_size_target``,
    ``args.multi_gpu`` and (when tensorboard is enabled) ``args.log_dir`` in
    place.

    Each outer step adjusts the learning rates, runs ``args.iter_size``
    generator (and optionally discriminator) updates, logs the accumulated
    losses and periodically saves snapshots. Training stops early once
    ``args.num_steps_stop`` steps have run.
    """
    # Size options arrive as "a,b" strings and are parsed as (w, h).
    w, h = map(int, args.input_size.split(','))
    args.input_size = (w, h)

    # crop_size is stored in (h, w) order, unlike the two resize sizes.
    # NOTE(review): presumably what the dataset classes expect -- confirm.
    w, h = map(int, args.crop_size.split(','))
    args.crop_size = (h, w)

    w, h = map(int, args.input_size_target.split(','))
    args.input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    # Negative ids are dropped from the comma separated list.
    gpu_ids = [gid for gid in map(int, args.gpu_ids.split(',')) if gid >= 0]

    # args.multi_gpu is set before AD_Trainer is built, exactly as before,
    # in case the trainer reads it during construction.
    args.multi_gpu = len(gpu_ids) > 1
    trainer = AD_Trainer(args)
    if args.multi_gpu:
        trainer.G = torch.nn.DataParallel(trainer.G, gpu_ids)
        trainer.D1 = torch.nn.DataParallel(trainer.D1, gpu_ids)
        trainer.D2 = torch.nn.DataParallel(trainer.D2, gpu_ids)

    print(trainer)

    # Every sub-iteration consumes one batch, hence the max_iters product.
    total_samples = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(SynthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=total_samples,
        resize_size=args.input_size,
        crop_size=args.crop_size,
        scale=True,
        mirror=True,
        mean=IMG_MEAN,
        autoaug=args.autoaug),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

    trainloader_iter = iter(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=total_samples,
        resize_size=args.input_size_target,
        crop_size=args.crop_size,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set,
        autoaug=args.autoaug_target),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   drop_last=True)

    targetloader_iter = iter(targetloader)

    # set up tensor board
    writer = None
    if args.tensorboard:
        args.log_dir += '/' + os.path.basename(args.snapshot_dir)
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    try:
        for i_iter in range(args.num_steps):

            loss_seg_value1 = 0
            loss_adv_target_value1 = 0
            loss_D_value1 = 0

            loss_seg_value2 = 0
            loss_adv_target_value2 = 0
            loss_D_value2 = 0

            adjust_learning_rate(trainer.gen_opt, i_iter, args)
            adjust_learning_rate_D(trainer.dis1_opt, i_iter, args)
            adjust_learning_rate_D(trainer.dis2_opt, i_iter, args)

            for sub_i in range(args.iter_size):
                # One source batch and one target batch per sub-iteration.
                batch = next(trainloader_iter)
                batch_t = next(targetloader_iter)

                images, labels, _, _ = batch
                images = images.cuda()
                labels = labels.long().cuda()
                images_t, labels_t, _, _ = batch_t
                images_t = images_t.cuda()
                labels_t = labels_t.long().cuda()

                with Timer("Elapsed time in update: %f"):
                    # train G
                    (loss_seg1, loss_seg2, loss_adv_target1,
                     loss_adv_target2, loss_me, loss_kl, pred1, pred2,
                     pred_target1, pred_target2,
                     val_loss) = trainer.gen_update(images, images_t, labels,
                                                    labels_t, i_iter)
                    loss_seg_value1 += loss_seg1.item() / args.iter_size
                    loss_seg_value2 += loss_seg2.item() / args.iter_size
                    loss_adv_target_value1 += (loss_adv_target1 /
                                               args.iter_size)
                    loss_adv_target_value2 += (loss_adv_target2 /
                                               args.iter_size)
                    # Only the last sub-iteration's value is reported.
                    loss_me_value = loss_me

                    # train D, only when both adversarial weights are active
                    if (args.lambda_adv_target1 > 0
                            and args.lambda_adv_target2 > 0):
                        loss_D1, loss_D2 = trainer.dis_update(
                            pred1, pred2, pred_target1, pred_target2)
                        loss_D_value1 += loss_D1.item()
                        loss_D_value2 += loss_D2.item()
                    else:
                        loss_D_value1 = 0
                        loss_D_value2 = 0

            # Drop the predictions before logging/saving.
            del pred1, pred2, pred_target1, pred_target2

            # Scalars are written every 100 steps.
            if writer is not None and i_iter % 100 == 0:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_me_target': loss_me_value,
                    'loss_kl_target': loss_kl,
                    'loss_D1': loss_D_value1,
                    'loss_D2': loss_D_value2,
                    'val_loss': val_loss,
                }
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                '\033[1m iter = %8d/%8d \033[0m loss_seg1 = %.3f loss_seg2 = %.3f loss_me = %.3f  loss_kl = %.3f loss_adv1 = %.3f, loss_adv2 = %.3f loss_D1 = %.3f loss_D2 = %.3f, val_loss=%.3f'
                % (i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                   loss_me_value, loss_kl, loss_adv_target_value1,
                   loss_adv_target_value2, loss_D_value1, loss_D_value2,
                   val_loss))

            # clear loss
            del loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, val_loss

            # Early stop: save the final weights under num_steps_stop and quit.
            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                _save_adversarial_snapshot(trainer, args.snapshot_dir,
                                           args.num_steps_stop)
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                _save_adversarial_snapshot(trainer, args.snapshot_dir, i_iter)
    finally:
        # Close the event writer even if training aborts with an exception.
        if writer is not None:
            writer.close()