def main():
    """Create the model and start the training.

    Builds a segmentation network (``args.model`` is 'DeepLab' or 'FCN8S'),
    two ``FCDiscriminator`` instances, a source-domain loader (GTA5 or SYNTHIA,
    selected by the module-level ``SOURCE_DATA``) and a ``cityscapesDataSet``
    target loader, then runs up to ``args.num_steps_stop`` iterations of
    adversarial training. Every ``args.save_pred_every`` iterations (except
    iteration 0) the model is evaluated with ``proceed_test`` and a checkpoint
    is written to the logger directory; ``model_best.pth.tar`` is a copy of
    the checkpoint with the highest mIoU seen so far.

    Raises:
        ValueError: if ``args.model`` or ``SOURCE_DATA`` is not supported.
    """
    import shutil

    def _to_scalar(loss):
        """Return the value of a single-element loss tensor as a float."""
        # ``numpy()[0]`` only works for the 1-element tensors of old PyTorch;
        # flattening first also handles the 0-dim losses of PyTorch >= 0.4.
        return float(loss.data.cpu().numpy().reshape(-1)[0])

    def _label_like(d_out, label):
        """Return a CUDA tensor shaped like ``d_out`` filled with ``label``."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(label)).cuda()

    # NOTE(review): the two integers are kept in command-line order; the
    # nn.Upsample layers below are sized (input_size[1], input_size[0]),
    # which suggests the string is "width,height" despite the h/w names.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        logger.info("adopting Deeplabv2 base model..")
        model = Res_Deeplab(num_classes=args.num_classes, multi_scale=False)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Checkpoint keys carry a leading component that the model's own keys
        # lack (e.g. "Scale.layer5.conv2d_list.3.weight"), so it is stripped.
        # With 19 classes, "layer5" entries are skipped and the model keeps
        # its own layer5 weights.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.model == "FCN8S":
        logger.info("adopting FCN8S base model..")
        from pytorchgo.model.MyFCN8s import MyFCN8s
        model = MyFCN8s(n_class=NUM_CLASSES)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    else:
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # init D
    # NOTE(review): model_D1 / optimizer_D1 are built, toggled and stepped,
    # but no loss below flows through model_D1, so the *_value1 statistics
    # logged every 100 iterations always stay 0.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda()

    model_D2.train()
    model_D2.cuda()

    max_iters = args.num_steps * args.iter_size * args.batch_size

    if SOURCE_DATA == "GTA5":
        source_dataset = GTA5DataSet(
            args.data_dir,
            args.data_list,
            max_iters=max_iters,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN)
    elif SOURCE_DATA == "SYNTHIA":
        source_dataset = SynthiaDataSet(
            args.data_dir,
            args.data_list,
            LABEL_LIST_PATH,
            max_iters=max_iters,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN)
    else:
        raise ValueError("unsupported source dataset: {}".format(SOURCE_DATA))

    trainloader = data.DataLoader(source_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # TODO: implement model.optim_parameters(args) to handle different
    # models' lr setting

    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    best_mIoU = 0

    model_summary([model, model_D1, model_D2])
    optimizer_summary([optimizer, optimizer_D1, optimizer_D2])

    for i_iter in tqdm(range(args.num_steps_stop),
                       total=args.num_steps_stop,
                       desc="training"):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        lr_D1 = adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            ######################### train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # Supervised segmentation loss on a source batch.
            # next() rather than .next(): the method form is Python 2 only.
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred2 = interp(model(images))

            loss_seg2 = loss_calc(pred2, labels)
            # Divide by iter_size so accumulated gradients are averaged.
            loss = loss_seg2 / args.iter_size
            loss.backward()
            loss_seg_value2 += _to_scalar(loss_seg2) / args.iter_size

            # Adversarial loss on a target batch: the generator is rewarded
            # when D2 labels target predictions as source.
            _, batch = next(targetloader_iter)
            images, _, _, _ = batch
            images = Variable(images).cuda()

            pred_target2 = interp_target(model(images))

            # dim=1 matches the legacy implicit softmax dim for 4-D input.
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_adv_target2 = bce_loss(D_out2,
                                        _label_like(D_out2, source_label))

            loss = args.lambda_adv_target2 * loss_adv_target2 / args.iter_size
            loss.backward()
            loss_adv_target_value2 += (_to_scalar(loss_adv_target2) /
                                       args.iter_size)

            ################################## train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # Source predictions (detached from G) labelled as source.
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2, dim=1))
            loss_D2 = bce_loss(D_out2, _label_like(D_out2, source_label))
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()
            loss_D_value2 += _to_scalar(loss_D2)

            # Target predictions (detached from G) labelled as target.
            pred_target2 = pred_target2.detach()
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))
            loss_D2 = bce_loss(D_out2, _label_like(D_out2, target_label))
            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()
            loss_D_value2 += _to_scalar(loss_D2)

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 100 == 0:
            logger.info(
                'iter = {}/{},loss_seg1 = {:.3f} loss_seg2 = {:.3f} loss_adv1 = {:.3f}, loss_adv2 = {:.3f} loss_D1 = {:.3f} loss_D2 = {:.3f}, lr={:.7f}, lr_D={:.7f}, best miou16= {:.5f}'
                .format(i_iter, args.num_steps_stop, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2,
                        lr, lr_D1, best_mIoU))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            logger.info("saving snapshot.....")
            cur_miou16 = proceed_test(model, input_size)
            is_best = cur_miou16 > best_mIoU
            if is_best:
                best_mIoU = cur_miou16
            log_dir = logger.get_logger_dir()
            checkpoint_path = osp.join(log_dir, 'checkpoint.pth.tar')
            torch.save(
                {
                    'iteration': i_iter,
                    'optim_state_dict': optimizer.state_dict(),
                    'optim_D1_state_dict': optimizer_D1.state_dict(),
                    'optim_D2_state_dict': optimizer_D2.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'model_D1_state_dict': model_D1.state_dict(),
                    'model_D2_state_dict': model_D2.state_dict(),
                    # Record the best score so far, not just this evaluation.
                    'best_mean_iu': best_mIoU,
                }, checkpoint_path)
            if is_best:
                shutil.copy(checkpoint_path,
                            osp.join(log_dir, 'model_best.pth.tar'))
            # NOTE(review): proceed_test may leave the model in eval mode;
            # restoring train mode is a no-op if it did not.
            model.train()

        if i_iter >= args.num_steps_stop - 1:
            break
# Example #2
# 0
def main():
    """Create the model and start the training.

    Trains a DeepLab model (``Res_Deeplab`` wrapped in ``DataParallel``) on
    ``MixedDataSet`` with a supervised segmentation loss on both of the
    model's outputs (``loss_seg2 + args.lambda_seg * loss_seg1``). Snapshots
    go to ``args.snapshot_dir``: every ``args.save_pred_every`` iterations
    (also evaluated with ``show_val``) and once at ``args.num_steps_stop``.
    A per-"epoch" average of ``loss_seg2`` is printed, where an epoch is taken
    to be half the number of lines in ``args.data_list`` iterations.

    Raises:
        ValueError: if ``args.model`` is not 'DeepLab'.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        # The checkpoint is loaded *after* wrapping in DataParallel, so its
        # keys presumably already carry DataParallel's "module." prefix.
        # (An earlier key-remapping step was disabled; keys are used as-is.)
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        model.load_state_dict(saved_state_dict)
    else:
        # Without this branch `model` would be unbound and the next line
        # would fail with an opaque NameError.
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()

    model.cuda(args.gpu)

    cudnn.benchmark = True

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(MixedDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # TODO: implement model.optim_parameters(args) to handle different
    # models' lr setting

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # NOTE(review): wrapping an optimizer in nn.DataParallel is unusual (an
    # optimizer is not an nn.Module). It is kept because adjust_learning_rate
    # receives this wrapper and may rely on `optimizer.module` — confirm
    # before removing.
    optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
    optimizer.module.zero_grad()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    loss2_sum_per_epoch = 0
    epoch = 0
    loss2_epoch = ''
    # Iterations per "epoch" are taken as half the number of lines in the
    # data list. True division is kept deliberately, so an odd line count
    # gives a fractional period.
    with open(args.data_list, 'r') as lbl_list:
        lbl_num = sum(1 for _ in lbl_list) / 2
    mIoUs = []
    i_iters = []

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_seg_value2 = 0

        optimizer.module.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, img_nam = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # Divide by iter_size so accumulated gradients are averaged.
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

        optimizer.module.step()

        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter + strat_snap_iter) + '.pth'))
            show_pred_sv_dir = pre_sv_dir.format(i_iter)
            mIoU = show_val(model.state_dict(), show_pred_sv_dir, gpu)
            mIoUs.append(str(round(np.nanmean(mIoU) * 100, 2)))
            i_iters.append(i_iter)
            # Print the full mIoU history so far, one line per evaluation.
            for snap_iter, miou in zip(i_iters, mIoUs):
                print('i{0}: {1}'.format(snap_iter, miou))

        # Per-epoch bookkeeping uses the last sub-iteration's loss_seg2.
        loss2_sum_per_epoch += loss_seg2.data.cpu().numpy()
        if i_iter % lbl_num == 0 and i_iter != 0:
            epoch += 1
            loss2_per_epoch = loss2_sum_per_epoch / lbl_num
            loss2_epoch += 'epoch = {0}, loss_seg2 = {1:.3f} \n'.format(
                epoch, loss2_per_epoch)
            print(loss2_epoch)
            loss2_sum_per_epoch = 0