示例#1
0
def main(args):
    """Create the model and run evaluation over the test list.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` (an ``http`` URL or a local path), runs every image
    yielded by the test loader through it and writes two files per image into
    ``args.save``: the argmax label map under the image's own file name and a
    colorized version named ``<stem>_color.png``.

    Args:
        args: Parsed options. Reads ``gpu``, ``save``, ``model``,
            ``num_classes``, ``restore_from``, ``data_dir``, ``img_list``,
            ``lbl_list`` and ``set``.

    Raises:
        ValueError: If ``args.model`` is neither ``'DeeplabMulti'`` nor
            ``'DeeplabVGG'``.
    """
    gpu0 = args.gpu

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        # Only swap in the VGG checkpoint when the default was left untouched.
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early instead of hitting an unbound `model` further down.
        raise ValueError('unsupported model: %r' % (args.model,))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(ListDataSet(args.data_dir,
                                             args.img_list,
                                             args.lbl_list,
                                             crop_size=(1024, 512),
                                             mean=IMG_MEAN,
                                             split=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # Predictions are upsampled to a fixed 1024x2048 (H, W) output.
    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    # DeeplabMulti returns two heads; only the second one is evaluated.
    multi_output = args.model == 'DeeplabMulti'

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, _, name = batch
            image = image.cuda(gpu0)
            if multi_output:
                _, logits = model(image)
            else:
                logits = model(image)
            output = interp(logits).cpu().data[0].numpy()

            # (C, H, W) -> (H, W, C), then per-pixel argmax over classes.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))
示例#2
0
def main():
    """Create the model and run evaluation over the selected split.

    Parses the command line via ``get_arguments()``, builds the network
    selected by ``args.model``, restores its weights, runs every image from
    the Cityscapes loader through it and writes, per image, the argmax label
    map and a colorized ``<stem>_color.png`` into ``args.save``.

    Raises:
        ValueError: If ``args.model`` is not one of ``'DeeplabMulti'``,
            ``'Oracle'`` or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early instead of hitting an unbound `model` further down.
        raise ValueError('unsupported model: %r' % (args.model,))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # Load onto the CPU first so GPU-saved checkpoints also work with
        # --cpu; the model is moved to the target device right below.
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')
    model.load_state_dict(saved_state_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(
        cityscapesDataSet(args.data_dir, args.data_list,
                          crop_size=input_size, mean=IMG_MEAN,
                          scale=False, mirror=False, set=args.set),
        batch_size=1, shuffle=False, pin_memory=True)

    # nn.Upsample takes (H, W); input_size is stored as (W, H).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)

    # DeeplabMulti returns four values; only the second one is evaluated.
    multi_output = args.model == 'DeeplabMulti'

    # Pure inference: do not build an autograd graph for every image.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            if multi_output:
                _, logits, _, _ = model(image)
            else:
                logits = model(image)
            output = interp(logits).cpu().data[0].numpy()

            # (C, H, W) -> (H, W, C), then per-pixel argmax over classes.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
示例#3
0
def main():
    """Build the segmentation network and run adversarial training.

    Reads its configuration from the module-level ``args``. Each iteration
    updates the segmentation model with a supervised loss on a source batch
    plus an adversarial loss on a target batch, then updates the two
    discriminators held in ``model_D`` on detached features/predictions.
    Snapshots are written to ``args.snapshot_dir``.
    """
    setup_seed(666)
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes arrive as "W,H" strings.
    src_w, src_h = map(int, args.input_size.split(','))
    input_size = (src_w, src_h)

    tgt_w, tgt_h = map(int, args.input_size_target.split(','))
    input_size_target = (tgt_w, tgt_h)

    cudnn.enabled = True

    # Create network
    model = DeeplabMulti(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        parts = key.split('.')
        # With 19 classes the checkpoint's layer5 weights are not copied.
        if args.num_classes == 19 and parts[1] == 'layer5':
            continue
        # The checkpoint's 'layer4.2' entries are stored under 'layer5.0'.
        if parts[1] == 'layer4' and parts[2] == '2':
            parts[1:3] = ['layer5', '0']
        new_params['.'.join(parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)
    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D: index 0 is fed the feature tensor, index 1 the softmax output
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([
        FCDiscriminator(num_classes=num_class_list[0]).train().to(device),
        OutspaceDiscriminator(num_classes=num_class_list[1]).train().to(device),
    ])

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    source_dataset = GTA5DataSet(
        args.data_dir, args.data_list,
        max_iters=args.num_steps * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)
    trainloader = data.DataLoader(
        source_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    target_dataset = cityscapesDataSet(
        args.data_dir_target, args.data_list_target,
        max_iters=args.num_steps * args.batch_size,
        crop_size=input_size_target,
        scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
        set=args.set)
    targetloader = data.DataLoader(
        target_dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # Named bce_loss, but it is a mean-squared-error criterion.
    bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(
        size=(input_size_target[1], input_size_target[0]),
        mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_loss(d_out, label):
        """Loss of a discriminator output against a constant label map."""
        target = torch.full(d_out.size(), float(label),
                            dtype=torch.float32, device=device)
        return bce_loss(d_out, target)

    def set_discriminator_grad(flag):
        """Switch gradient tracking for every discriminator parameter."""
        for param in model_D.parameters():
            param.requires_grad = flag

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # ---- train G (discriminators frozen) ----
        set_discriminator_grad(False)

        # source batch: supervised segmentation loss
        _, batch = next(trainloader_iter)
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feat_source, pred_source = model(images, model_D, 'source')
        pred_source = interp(pred_source)

        loss_seg = seg_loss(pred_source, labels)
        loss_seg.backward()

        # target batch: adversarial loss, asking D to answer "source"
        _, batch = next(targetloader_iter)
        images, _, _ = batch
        images = images.to(device)

        feat_target, pred_target = model(images, model_D, 'target')
        pred_target = interp_target(pred_target)

        adv_feat = domain_loss(model_D[0](feat_target), source_label)
        adv_out = domain_loss(model_D[1](F.softmax(pred_target, dim=1)),
                              source_label)
        loss_adv = (adv_feat + adv_out) * 0.01
        loss_adv.backward()

        optimizer.step()

        # ---- train D (on detached generator outputs) ----
        set_discriminator_grad(True)

        # source samples
        d_src_feat = domain_loss(model_D[0](feat_source.detach()),
                                 source_label)
        d_src_out = domain_loss(
            model_D[1](F.softmax(pred_source.detach(), dim=1)),
            source_label)
        loss_D_source = d_src_feat + d_src_out
        loss_D_source.backward()

        # target samples
        d_tgt_feat = domain_loss(model_D[0](feat_target.detach()),
                                 target_label)
        d_tgt_out = domain_loss(
            model_D[1](F.softmax(pred_target.detach(), dim=1)),
            target_label)
        loss_D_target = d_tgt_feat + d_tgt_out
        loss_D_target.backward()

        optimizer_D.step()

        if i_iter % 10 == 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D_s = {4:.3f}, loss_D_t = {5:.3f}'.format(
                i_iter, args.num_steps, loss_seg.item(), loss_adv.item(),
                loss_D_source.item(), loss_D_target.item()))

        # Final checkpoint, then stop early.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            final_prefix = osp.join(args.snapshot_dir,
                                    'GTA5_' + str(args.num_steps_stop))
            torch.save(model.state_dict(), final_prefix + '.pth')
            torch.save(model_D.state_dict(), final_prefix + '_D.pth')
            break

        # Periodic checkpoint (never at iteration 0).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            snap_prefix = osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter))
            torch.save(model.state_dict(), snap_prefix + '.pth')
            torch.save(model_D.state_dict(), snap_prefix + '_D.pth')
def main():
    """Create the model and start the adversarial training.

    Reads its configuration from the module-level ``args``. Trains a
    two-headed segmentation network on source batches with a supervised loss,
    and on target batches with an adversarial loss against two
    discriminators (one per head). Gradients are accumulated over
    ``args.iter_size`` sub-iterations before each optimizer step. Snapshots
    of all three networks are written to ``args.snapshot_dir``.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'`` or ``args.gan`` is
            neither ``'Vanilla'`` nor ``'LS'``.
    """
    # Sizes arrive as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the first
            # component is dropped. With 19 classes the checkpoint's layer5
            # weights are not copied.
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        # Fail early instead of hitting an unbound `model` further down.
        raise ValueError('unsupported model: %r' % (args.model,))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # One discriminator per segmentation head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Named bce_loss for both modes; 'LS' actually uses mean squared error.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('unsupported gan type: %r' % (args.gan,))

    # nn.Upsample takes (H, W); the sizes are stored as (W, H).
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):
        # Running sums used only for the progress line below.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
            # ---- train G: do not accumulate gradients in D ----
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = images.cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target: ask the discriminators to answer "source"
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is spelled out; the implicit-dim form is deprecated and
            # resolves to the same dimension for 4-D input.
            D1_out = model_D1(F.softmax(pred_target1, dim=1))
            D2_out = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D1_out, torch.full_like(D1_out, source_label))
            loss_adv_target2 = bce_loss(
                D2_out, torch.full_like(D2_out, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1 + \
                   args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size

            loss.backward()

            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D: bring back requires_grad ----
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so G receives no gradient)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D1_out = model_D1(F.softmax(pred1, dim=1))
            D2_out = model_D2(F.softmax(pred2, dim=1))

            # The factor 1/2 splits the D loss between source and target.
            loss_D1 = bce_loss(
                D1_out,
                torch.full_like(D1_out, source_label)) / args.iter_size / 2
            loss_D2 = bce_loss(
                D2_out,
                torch.full_like(D2_out, source_label)) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D1_out = model_D1(F.softmax(pred_target1, dim=1))
            D2_out = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D1_out,
                torch.full_like(D1_out, target_label)) / args.iter_size / 2
            loss_D2 = bce_loss(
                D2_out,
                torch.full_like(D2_out, target_label)) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        # Final checkpoint, then stop early.
        if i_iter >= args.num_steps_stop - 1:
            # Was a Python 2 print statement split over two lines, which
            # prints nothing under Python 3.
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic checkpoint (never at iteration 0).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
示例#5
0
def main():
    """Create the model and start the adversarial training.

    Reads its configuration from the module-level ``args``. Trains a
    two-headed segmentation network on source batches with a supervised loss,
    and on target batches with an adversarial loss against two
    discriminators (one per head). The adversarial loss is scaled by
    ``sqrt(1 - domainess)``, where ``domainess`` comes with each source
    batch. Gradients are accumulated over ``args.iter_size`` sub-iterations
    before each optimizer step. Scalars are optionally logged to TensorBoard
    and snapshots are written to ``args.snapshot_dir``.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'`` or ``args.gan`` is
            neither ``'Vanilla'`` nor ``'LS'``.
    """
    # Select the device.
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Source and target crop sizes, given as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create the segmentation model.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            # Download the pre-trained weights.
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # A .pth file was given directly.
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the first
            # component is dropped. With 19 classes the checkpoint's layer5
            # weights are not copied.
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        # Fail early instead of hitting an unbound `model` further down.
        raise ValueError('unsupported model: %r' % (args.model,))

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # Create the discriminators, one per segmentation head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D2.train()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                          crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Named bce_loss for both modes; 'LS' actually uses mean squared error.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('unsupported gan type: %r' % (args.gan,))
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample takes (H, W); the sizes are stored as (W, H).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(
        size=(input_size_target[1], input_size_target[0]),
        mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        # Running sums used only for logging.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G: don't accumulate grads in D ----
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = next(trainloader_iter)

            images, labels, _, _, domainess = batch
            # Adversarial weight; moved to `device` so it can be multiplied
            # with the losses computed there.
            # NOTE(review): with batch_size > 1 this presumably has one entry
            # per sample, which would make the weighted loss non-scalar and
            # break backward()/item() below - confirm the intended reduction.
            adw = torch.sqrt(1 - domainess).float().to(device)
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target: ask the discriminators to answer "source"
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is spelled out; the implicit-dim form is deprecated and
            # resolves to the same dimension for 4-D input.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = adw * bce_loss(
                D_out1, torch.full_like(D_out1, source_label))
            loss_adv_target2 = adw * bce_loss(
                D_out2, torch.full_like(D_out2, source_label))

            loss = (args.lambda_adv_target1 * loss_adv_target1
                    + args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D: bring back requires_grad ----
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so G receives no gradient)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, torch.full_like(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, torch.full_like(D_out2, source_label))

            # The factor 1/2 splits the D loss between source and target.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, torch.full_like(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, torch.full_like(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }

            # Scalars are written every 10th iteration only.
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2))

        # Final checkpoint, then stop early.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic checkpoint (never at iteration 0).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
示例#6
0
def main():
    """Create the model and start the training.

    Parses the command-line arguments, builds the segmentation network and
    two discriminators, and runs single-level adversarial adaptation: the
    segmentation network is trained with a cross-entropy loss on source
    batches and an adversarial loss on target batches, while ``model_D1`` is
    trained to separate source predictions from target predictions.
    ``model_D2`` is created (and broadcast in distributed mode) but is never
    run or stepped, so the ``*2`` loss values that are logged stay at zero.

    When ``args.dist`` is set, the process group is initialised, parameters
    are broadcast with ``broadcast_params`` and ``average_gradients`` is
    called on every model before each optimizer step. Only rank 0 prints,
    writes TensorBoard scalars and saves checkpoints.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    global args
    args = get_arguments()
    if args.dist:
        init_dist(args.launcher, backend=args.backend)
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes are given on the command line as "width,height".
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # ``strict`` is an option of ``load_state_dict``, not of
            # ``torch.load``; passing it to ``torch.load`` raises TypeError.
            saved_state_dict = torch.load(args.restore_from)

        # Start from the model's own parameters and overwrite them with the
        # checkpoint entries, dropping the first component of every key.
        # With 19 classes the checkpoint's ``layer5`` entries are skipped, so
        # that layer keeps its fresh initialisation.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params, strict=False)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        # Load the weights for both sources; previously a checkpoint fetched
        # over http was downloaded but never applied to the model.
        model.load_state_dict(saved_state_dict, strict=False)
        del saved_state_dict
    else:
        raise ValueError(
            "Unsupported model {!r}; expected 'Deeplab', 'DeeplabVGG' or "
            "'DeeplabVGGBN'".format(args.model))

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)

    if rank == 0:
        print(model)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D)
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict

    model_D2.train()
    if args.dist:
        broadcast_params(model_D2)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_data = GTA5BDDDataSet(args.data_dir,
                                args.data_list,
                                max_iters=args.num_steps * args.iter_size *
                                args.batch_size,
                                crop_size=input_size,
                                scale=args.random_scale,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)
    # DataLoader does not allow ``shuffle`` together with a sampler.
    trainloader = data.DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=train_sampler is None,
                                  num_workers=args.num_workers,
                                  pin_memory=False,
                                  sampler=train_sampler)

    # NOTE(review): if ``cycle`` is ``itertools.cycle`` it keeps a copy of
    # every batch from the first pass and replays them - confirm which
    # helper is imported at the top of the file.
    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target,
                             args.data_list_target,
                             max_iters=args.num_steps * args.iter_size *
                             args.batch_size,
                             crop_size=input_size_target,
                             scale=False,
                             mirror=args.random_mirror,
                             mean=IMG_MEAN,
                             set=args.set)
    target_sampler = None
    if args.dist:
        target_sampler = DistributedSampler(target_data)
    targetloader = data.DataLoader(target_data,
                                   batch_size=args.batch_size,
                                   shuffle=target_sampler is None,
                                   num_workers=args.num_workers,
                                   pin_memory=False,
                                   sampler=target_sampler)

    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # The source upsampler is built per batch inside the loop from the size
    # returned by the dataset; only the target one has a fixed size.
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_target(logits, label):
        """Return a tensor shaped like ``logits`` and filled with ``label``."""
        return torch.full_like(logits, float(label))

    # set up tensor board
    if args.tensorboard and rank == 0:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()
    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over ``iter_size`` micro-batches.
        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            interp = nn.Upsample(size=(size[1], size[0]),
                                 mode='bilinear',
                                 align_corners=True)

            pred1 = model(images)
            pred1 = interp(pred1)

            loss_seg1 = seg_loss(pred1, labels)

            loss = loss_seg1

            # proper normalization
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            # train with target
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = model(images)
            pred_target1 = interp_target(pred_target1)

            # The upsampled prediction is 4-D, so ``dim=1`` is the dimension
            # the implicit (deprecated) form would pick as well.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            # The generator is rewarded when D labels target output as source.
            loss_adv_target1 = bce_loss(D_out1,
                                        domain_target(D_out1, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size

            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached: no gradient reaches the generator)
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(D_out1, domain_target(D_out1, source_label))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(D_out1, domain_target(D_out1, target_label))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

        # Synchronise once per optimizer step, after all micro-batches have
        # been accumulated. Doing this inside the accumulation loop would
        # reduce again the gradients that earlier micro-batches had already
        # synchronised.
        if args.dist:
            average_gradients(model)
            average_gradients(model_D1)
            average_gradients(model_D2)

        optimizer.step()
        optimizer_D1.step()

        # Every rank has to leave the loop at the same iteration; if only
        # rank 0 stopped, the other ranks would keep iterating (presumably
        # blocking in the gradient synchronisation once rank 0 is gone).
        reached_stop = i_iter >= args.num_steps_stop - 1

        if rank == 0:
            if args.tensorboard:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    # Undo the division by world_size applied before backward.
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }

                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))

            if reached_stop:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            elif i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_D1.pth'))

        if reached_stop:
            break

    print(args.snapshot_dir)
    if args.tensorboard and rank == 0:
        writer.close()
示例#7
0
def main():
    """Create the model and start the training.

    Trains the segmentation network on source data only: each step
    accumulates the weighted cross-entropy loss of both network outputs over
    ``args.iter_size`` batches and then takes one SGD step. No discriminator
    and no target-domain data are used.

    When the module-level ``RESTART`` flag is set, the argument values saved
    at ``RESTART_ITER`` are read back from the snapshot directory and the
    model weights are loaded from ``args.restart_from``; otherwise the model
    is initialised from ``args.restore_from``.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'``.
    """
    import json

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # ``vars`` returns the namespace's own dict, so writing to ``args_dict``
    # also changes the corresponding attribute on ``args``.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        # Only the keys are iterated and no key is added, so assigning while
        # looping is safe.
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes are given on the command line as "width,height".
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    # Fail early with a clear message instead of hitting an unbound ``model``
    # further down.
    if args.model != 'DeepLab':
        raise ValueError(
            "Unsupported model {!r}; expected 'DeepLab'".format(args.model))
    model = DeeplabMulti(num_classes=args.num_classes)

    #### restore model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Checkpoint keys look like ``Scale.layer5.conv2d_list.3.weight``:
        # the first component is dropped to match the model's own keys. With
        # 19 classes the ``layer5`` entries are skipped, so that layer keeps
        # its fresh initialisation.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # ``nn.Upsample`` takes (height, width); ``input_size`` is (width, height).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        # Gradients are accumulated over ``iter_size`` batches.
        for _ in range(args.iter_size):

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

        optimizer.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_noadapt_' + str(args.num_steps_stop) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))

            ###### also record latest saved iteration #######
            # NOTE(review): these writes also overwrite args.learning_rate
            # and args.start_steps for the rest of this run (see ``vars``
            # above). If ``adjust_learning_rate`` derives the schedule from
            # args.learning_rate, the base rate shrinks after every
            # snapshot - confirm this is intended.
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = osp.join(args.snapshot_dir,
                                      'args_dict_{}.json'.format(i_iter))
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

    writer.close()
示例#8
0
def main():
    """Create the model and start the training.

    Runs two-level adversarial adaptation: the segmentation network is
    trained with cross-entropy on source batches plus an adversarial loss on
    target batches for each of its two outputs, while ``model_D1`` and
    ``model_D2`` are trained to separate source predictions from target
    predictions of the matching output.

    When the module-level ``RESTART`` flag is set, the argument values saved
    at ``RESTART_ITER`` are read back from the snapshot directory and the
    weights of all three networks are loaded from ``args.restart_from``;
    otherwise only the segmentation network is initialised from
    ``args.restore_from`` and the discriminators keep their fresh
    initialisation. Training starts at ``args.start_steps``.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'`` or ``args.gan`` is
            neither ``'Vanilla'`` nor ``'LS'``.
    """
    import json

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # ``vars`` returns the namespace's own dict, so writing to ``args_dict``
    # also changes the corresponding attribute on ``args``.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        # Snapshots write this file as ``<snapshot_dir>/args_dict_<iter>.json``;
        # joining the path here finds it whether or not the directory name
        # ends with a separator.
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        # Only the keys are iterated and no key is added, so assigning while
        # looping is safe.
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes are given on the command line as "width,height".
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    # Fail early with a clear message instead of hitting an unbound ``model``
    # further down.
    if args.model != 'DeepLab':
        raise ValueError(
            "Unsupported model {!r}; expected 'DeepLab'".format(args.model))
    model = DeeplabMulti(num_classes=args.num_classes)

    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #### restore model_D1, D2 and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D1 parameters
        restart_from_D1 = args.restart_from + 'GTA5_{}_D1.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D1)
        model_D1.load_state_dict(saved_state_dict)

        # model_D2 parameters
        restart_from_D2 = args.restart_from + 'GTA5_{}_D2.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D2)
        model_D2.load_state_dict(saved_state_dict)

    #### discriminators stay freshly initialised, model is restored
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Checkpoint keys look like ``Scale.layer5.conv2d_list.3.weight``:
        # the first component is dropped to match the model's own keys. With
        # 19 classes the ``layer5`` entries are skipped, so that layer keeps
        # its fresh initialisation.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    # Both discriminators were moved to ``device`` when they were created.
    model_D1.train()
    model_D2.train()

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Despite its name, ``bce_loss`` is the mean-squared error when the
    # least-squares GAN objective is selected.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError(
            "Unsupported gan {!r}; expected 'Vanilla' or 'LS'".format(
                args.gan))
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # ``nn.Upsample`` takes (height, width); the sizes are (width, height).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_target(logits, label):
        """Return a tensor shaped like ``logits`` and filled with ``label``."""
        return torch.full_like(logits, float(label))

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.start_steps, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over ``iter_size`` micro-batches.
        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # The upsampled predictions are 4-D, so ``dim=1`` is the
            # dimension the implicit (deprecated) form would pick as well.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when D labels target output as source.
            loss_adv_target1 = bce_loss(D_out1,
                                        domain_target(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2,
                                        domain_target(D_out2, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached: no gradient reaches the generator)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_target(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, domain_target(D_out2, source_label))

            # Halved because each discriminator sees a source and a target
            # batch per micro-step.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_target(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, domain_target(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'loss_adv_target1': loss_adv_target_value1,
            'loss_adv_target2': loss_adv_target_value2,
            'loss_D1': loss_D_value1,
            'loss_D2': loss_D_value2,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

            ###### also record latest saved iteration #######
            # NOTE(review): these writes also overwrite args.learning_rate,
            # args.learning_rate_D and args.start_steps for the rest of this
            # run (see ``vars`` above). If the ``adjust_learning_rate*``
            # helpers derive the schedule from those attributes, the base
            # rates shrink after every snapshot - confirm this is intended.
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D1.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = osp.join(args.snapshot_dir,
                                      'args_dict_{}.json'.format(i_iter))
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

    writer.close()
示例#9
0
def main():
    """Create the model and start the evaluation process.

    Reads the options from ``get_arguments()``, restores the segmentation
    network from ``args.restore_from`` (URL or local file), runs it over the
    ``cityscapesDataSet`` selected by ``args.set`` and, for every image,
    writes the label-index map and its colorized version to
    ``<args.save>/<blend index>/``.

    Raises:
        ValueError: if ``args.model`` is neither ``'DeeplabMulti'`` nor
            ``'DeeplabVGG'``.
    """

    args = get_arguments()

    # NOTE(review): the CUDA device index is hard-coded; confirm that device 2
    # exists on the evaluation machine or expose it as an option.
    gpu0 = 2
    torch.cuda.manual_seed(1337)
    torch.cuda.set_device(gpu0)

    if not os.path.exists(args.save):
        os.makedirs(args.save)
    # One output sub-directory per blend index.
    for i in range(5):
        sub_dir = args.save + '/' + str(i)
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously an unknown name fell through and failed later on an
        # unbound ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))
    print("begin")

    if args.restore_from[:4] == 'http':
        print("1112222")
        # The URL branch used to skip loading entirely, so the network was
        # evaluated with its random initialisation.
        from torch.utils.model_zoo import load_url
        saved_state_dict = load_url(args.restore_from)
    else:
        print("2222222", gpu0)
        print(args.restore_from)
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)
    model.cuda(gpu0)
    # NOTE(review): ``model.eval()`` is commented out in the original, so the
    # network runs in training mode here. Left unchanged because enabling it
    # alters the outputs -- confirm whether this is intentional.
    # model.eval()

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = Upsample_function
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, _, name = batch
            image = Variable(image).cuda(gpu0)

            # Per-image list of upsampled score maps, one entry per blend.
            final = []
            if args.model == 'DeeplabMulti':
                output1, output2 = model(image)
                output1 = F.softmax(output1, 1)
                output2 = F.softmax(output2, 1)
                for i in [0, 3, 7, 10]:
                    final_output = (i / 10.0 * output1 +
                                    (10.0 - i) / 10.0 * output2)
                    final.append(interp(final_output).cpu().data[0].numpy())
                    # Only the first blend weight is evaluated; the remaining
                    # weights are kept for reference, as in the original.
                    break
            elif args.model == 'DeeplabVGG':
                # ``image`` is already on the GPU and gradients are disabled
                # by the surrounding no_grad block. The result must be put in
                # ``final`` or the save loop below has nothing to read.
                output = model(image)
                final.append(interp(output).cpu().data[0].numpy())

            name = name[0].split('/')[-1]
            for i, scores in enumerate(final):
                # (C, H, W) scores -> (H, W) map of arg-max class indices.
                prediction = scores.transpose(1, 2, 0)
                prediction = np.asarray(np.argmax(prediction, axis=2),
                                        dtype=np.uint8)
                output_col = colorize_mask(prediction)

                Image.fromarray(prediction).save(
                    '%s/%s/%s' % (args.save, str(i), name))
                output_col.save('%s/%s/%s_color.png' %
                                (args.save, str(i), name.split('.')[0]))
def main(args):
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` segmentation network with a supervised loss on
    labelled source batches and labelled target batches, plus an adversarial
    loss on unlabelled target batches, alternating with updates of one
    ``FCDiscriminator``. Losses are logged through a ``SummaryWriter`` under
    ``<args.snapshot_dir>/tb-logs`` and checkpoints are written to
    ``args.snapshot_dir``.

    Args:
        args: parsed options (data locations and lists, input sizes, device,
            optimiser settings, schedule and snapshot settings).

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """

    mkdir_check(args.snapshot_dir)

    # Sizes arrive as "<a>,<b>". Index 0 is handed to nn.Upsample as the second
    # (width) entry and index 1 as the first (height) entry further down.
    src_w, src_h = map(int, args.src_input_size.split(','))
    src_input_size = (src_w, src_h)

    tgt_w, tgt_h = map(int, args.tgt_input_size.split(','))
    tgt_input_size = (tgt_w, tgt_h)

    cudnn.enabled = True

    # Create network
    if args.model != 'DeepLab':
        # Previously an unknown name fell through and failed later on an
        # unbound ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    model = DeeplabMulti(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight"; the
        # leading component is dropped when copying. Entries whose second
        # component is 'layer5' are skipped when num_classes is 19.
        key_parts = key.split('.')
        if not args.num_classes == 19 or not key_parts[1] == 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D -- sized from the same class count as the segmentation network
    # (the original hard-coded 19).
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    def make_loader(data_dir, img_list, lbl_list, crop_size):
        """Build a shuffled DataLoader over one ListDataSet."""
        dataset = ListDataSet(data_dir, img_list, lbl_list,
                              max_iters=max_iters, crop_size=crop_size,
                              mean=IMG_MEAN)
        return data.DataLoader(dataset, batch_size=args.batch_size,
                               shuffle=True, num_workers=args.num_workers,
                               pin_memory=True)

    # Loader and iterator creation order is kept as in the original.
    trainloader = make_loader(args.src_data_dir, args.src_img_list,
                              args.src_lbl_list, src_input_size)
    trainloader_iter = iter(trainloader)

    targetloader = make_loader(args.tgt_data_dir, args.tgt_img_list,
                               args.tgt_lbl_list, tgt_input_size)
    targetloader_nolabel = make_loader(args.tgt_data_dir,
                                       args.tgt_img_nolabel_list, None,
                                       tgt_input_size)

    targetloader_iter = iter(targetloader)
    targetloader_nolabel_iter = iter(targetloader_nolabel)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(src_input_size[1], src_input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(tgt_input_size[1], tgt_input_size[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_target(d_out, domain_label):
        """Return a tensor shaped like ``d_out`` filled with ``domain_label``."""
        filled = torch.FloatTensor(d_out.data.size()).fill_(domain_label)
        return Variable(filled).cuda(args.gpu)

    # Created right before the loop so the try/finally below always closes it
    # (the original never closed the writer).
    writer = SummaryWriter(log_dir=os.path.join(args.snapshot_dir, 'tb-logs'))
    try:
        for i_iter in range(args.num_steps):

            loss_seg_value2 = 0
            loss_tgt_seg_value2 = 0
            loss_adv_target_value2 = 0
            loss_D_value2 = 0

            optimizer.zero_grad()
            adjust_learning_rate(args, optimizer, i_iter)

            optimizer_D2.zero_grad()
            adjust_learning_rate_D(args, optimizer_D2, i_iter)

            for _ in range(args.iter_size):

                # ---- train G ----

                # don't accumulate grads in D
                for param in model_D2.parameters():
                    param.requires_grad = False

                # supervised step on a source batch
                images, labels, _, _ = next(trainloader_iter)
                images = Variable(images).cuda(args.gpu)

                _, pred2 = model(images)
                pred2 = interp(pred2)

                loss_seg2 = loss_calc(pred2, labels, args.gpu)
                # proper normalization
                (loss_seg2 / args.iter_size).backward()
                loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

                # supervised step on a labelled target batch
                images, labels, _, _ = next(targetloader_iter)
                images = Variable(images).cuda(args.gpu)

                _, pred_target2 = model(images)
                pred_target2 = interp_target(pred_target2)

                loss_tgt_seg2 = loss_calc(pred_target2, labels, args.gpu)
                # This graph is never back-propagated a second time (the
                # prediction is recomputed below), so it is not retained.
                (loss_tgt_seg2 / args.iter_size).backward()
                loss_tgt_seg_value2 += (loss_tgt_seg2.data.cpu().numpy() /
                                        args.iter_size)

                # adversarial step on an unlabelled target batch
                images, _, _, _ = next(targetloader_nolabel_iter)
                images = Variable(images).cuda(args.gpu)

                _, pred_target2 = model(images)
                pred_target2 = interp_target(pred_target2)

                # Bilinear nn.Upsample with a 2-D size works on 4-D
                # (N, C, H, W) input, so class scores sit on dim 1. The
                # original normalised over dim=-1 (image width).
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                # Target predictions are pushed towards the source label.
                loss_adv_target2 = bce_loss(
                    D_out2, domain_target(D_out2, source_label))

                loss = args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size
                loss.backward()
                loss_adv_target_value2 += (
                    loss_adv_target2.data.cpu().numpy() / args.iter_size)

                # ---- train D ----

                # bring back requires_grad
                for param in model_D2.parameters():
                    param.requires_grad = True

                # train with source
                pred2 = pred2.detach()
                D_out2 = model_D2(F.softmax(pred2, dim=1))

                loss_D2 = bce_loss(D_out2,
                                   domain_target(D_out2, source_label))
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D2.backward()
                loss_D_value2 += loss_D2.data.cpu().numpy()

                # train with target
                pred_target2 = pred_target2.detach()
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_D2 = bce_loss(D_out2,
                                   domain_target(D_out2, target_label))
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D2.backward()
                loss_D_value2 += loss_D2.data.cpu().numpy()

            optimizer.step()
            optimizer_D2.step()

            print(
                'iter = {:5d}/{:8d}, loss_seg2 = {:.3f} loss_tgt_seg2 = {:.3f} loss_adv2 = {:.3f} loss_D2 = {:.3f}'.format(
                    i_iter, args.num_steps_stop, loss_seg_value2,
                    loss_tgt_seg_value2, loss_adv_target_value2,
                    loss_D_value2))

            writer.add_scalars('loss/seg', {
                'src2': loss_seg_value2,
                'tgt2': loss_tgt_seg_value2,
            }, i_iter)

            writer.add_scalar('loss/adv', loss_adv_target_value2, i_iter)
            writer.add_scalar('loss/d', loss_D_value2, i_iter)

            model_path = osp.join(args.snapshot_dir,
                                  'model_{}.pth'.format(i_iter))
            model_d_path = osp.join(args.snapshot_dir,
                                    'model_d_{}.pth'.format(i_iter))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(model.state_dict(), model_path)
                torch.save(model_D2.state_dict(), model_d_path)
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(model.state_dict(), model_path)
                torch.save(model_D2.state_dict(), model_d_path)
    finally:
        # Flush and release the event file even if training aborts.
        writer.close()
def main():
    """Create the model and start the training.

    Builds a ``DeeplabMulti`` segmentation network and two
    ``FCDiscriminator`` instances, then for up to ``args.num_steps``
    iterations alternates between updating the segmentation network
    (segmentation losses plus adversarial losses against both discriminators)
    and updating the discriminators. Checkpoints are written to
    ``args.snapshot_dir``.

    ``args`` and ``cut_size`` are not defined inside this function; they are
    resolved from an enclosing (presumably module-level) scope.
    """

    # The segmentation network is placed on CUDA device ``gpu_id_2`` and both
    # discriminators on ``gpu_id_1`` (see the ``.cuda(...)`` calls below).
    gpu_id_2 = 1
    gpu_id_1 = 0

    # Sizes are given as "<w>,<h>" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    # gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            print("from restore")
            # NOTE(review): this branch ignores args.restore_from and always
            # loads a hard-coded absolute path -- confirm this is intended.
            saved_state_dict = torch.load(
                '/data2/zhangjunyi/snapshots/snapshots_syn/onlysyn/GTA5_80000.pth',
                map_location={"cuda:3": "cuda:0"})

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            # Copy every entry except those whose second component is
            # 'layer5' or whose first component is 'fc'. The key is re-joined
            # from all parts, i.e. it is used unchanged.
            if (not i_parts[1] == 'layer5') and (not i_parts[0] == 'fc'):
                new_params['.'.join(i_parts)] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    # NOTE(review): if args.model is not 'DeepLab', ``model`` is unbound here.
    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D
    # model_D1 is fed channel-wise concatenations of two prediction maps
    # below; model_D2 is fed a single prediction map.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes, match=False)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(SynthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    # Consume one source batch up front so that a "previous" batch exists for
    # the first iteration.
    _, batch_last = trainloader_iter.__next__()

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        cut_size=cut_size,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    # Same priming for the target domain.
    _, batch_last_target = targetloader_iter.__next__()

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # bce_loss is used with model_D2 outputs, mse_loss with model_D1 outputs.
    bce_loss = torch.nn.BCEWithLogitsLoss()
    mse_loss = torch.nn.MSELoss()

    def upsample_(input_):
        # Bilinear resize to the source size; ``size`` is (input_size[1],
        # input_size[0]), i.e. (h, w).
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        # Bilinear resize to the target size, again passed as (h, w).
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    # labels for adversarial training
    # NOTE(review): target_label (-1) is only used with mse_loss; the
    # bce_loss terms below use source_label (1) and mix_label (0).
    source_label = 1
    target_label = -1
    mix_label = 0

    for i_iter in range(args.num_steps):

        # Per-iteration accumulators used only for the log line below.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        number1 = 0
        number2 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            def result_model(batch, interp_):
                # Forward one batch through the segmentation network on
                # gpu_id_2, resize both outputs with ``interp_`` and return
                # them, together with the labels, on gpu_id_1.
                images, labels, _, name = batch
                images = Variable(images).cuda(gpu_id_2)
                labels = Variable(labels.long()).cuda(gpu_id_1)
                pred1, pred2 = model(images)
                pred1 = interp_(pred1)
                pred2 = interp_(pred2)
                pred1_ = pred1.cuda(gpu_id_1)
                pred2_ = pred2.cuda(gpu_id_1)
                return pred1_, pred2_, labels

            beta = args.beta
            if i_iter == 0:
                print(beta)
            _, batch = trainloader_iter.__next__()
            _, batch_target = targetloader_iter.__next__()

            # Segmentation loss on the source batch. loss_calc returns a
            # (loss, labels) pair; the returned labels replace the originals
            # before the second call.
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1, beta)
            labels = new_labels
            # Number of label entries equal to 255 (presumably the ignore
            # value -- TODO confirm against loss_calc).
            number1 = torch.sum(labels == 255).item()
            loss_seg2, new_labels = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()
            # NOTE(review): loss_seg_1 / loss_seg_2 are not read again in this
            # function; the logged loss_seg_value1/2 only contain the target
            # losses accumulated below.
            loss_seg_1 = loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_2 = loss_seg2.data.cpu().numpy() / args.iter_size

            # Segmentation loss on the target batch, same scheme.
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1, beta)
            labels = new_labels
            number2 = torch.sum(labels == 255).item()
            # NOTE(review): ``new_lables`` looks like a typo of new_labels; it
            # is not read afterwards.
            loss_seg2, new_lables = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # Adversarial term 1 (model_D1): current target predictions
            # concatenated with the previous target batch's predictions,
            # pushed towards source_label.
            pred1_last_target, pred2_last_target, labels_last_target = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)
            # exit()

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            fake1_D = torch.cat((pred1_target_D, pred1_last_target_D), dim=1)
            fake2_D = torch.cat((pred2_target_D, pred2_last_target_D), dim=1)
            # Both output heads go through model_D1.
            D_out_fake_1 = model_D1(fake1_D)
            D_out_fake_2 = model_D1(fake2_D)

            loss_adv_fake1 = mse_loss(
                D_out_fake_1,
                Variable(
                    torch.FloatTensor(D_out_fake_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_fake2 = mse_loss(
                D_out_fake_2,
                Variable(
                    torch.FloatTensor(D_out_fake_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_target1 = loss_adv_fake1
            loss_adv_target2 = loss_adv_fake2
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()

            # Adversarial term 2 (model_D1): target predictions concatenated
            # with source predictions, also pushed towards source_label and
            # weighted by 2. Fresh forward passes are made for this term.
            pred1, pred2, labels = result_model(batch, interp)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            mix1_D = torch.cat((pred1_target_D, pred1_D), dim=1)
            mix2_D = torch.cat((pred2_target_D, pred2_D), dim=1)

            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D1(mix2_D)

            loss_adv_mix1 = mse_loss(
                D_out_mix_1,
                Variable(
                    torch.FloatTensor(D_out_mix_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_mix2 = mse_loss(
                D_out_mix_2,
                Variable(
                    torch.FloatTensor(D_out_mix_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_target1 = loss_adv_mix1 * 2
            loss_adv_target2 = loss_adv_mix2 * 2

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            # Both heads are added to the same accumulator, which is printed
            # as "loss_adv1"; only this mix term is logged, not the
            # previous-target term above.
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value1 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # train D2
            # Adversarial term 3 (model_D2, BCE): single target prediction
            # maps pushed towards source_label. The discriminator parameters
            # still have requires_grad False here.
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            D_out_target_1 = model_D2(pred1_target_D)
            D_out_target_2 = model_D2(pred2_target_D)

            loss_adv_target1 = bce_loss(
                D_out_target_1,
                Variable(
                    torch.FloatTensor(D_out_target_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_target2 = bce_loss(
                D_out_target_2,
                Variable(
                    torch.FloatTensor(D_out_target_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            # Both heads are added to the accumulator printed as "loss_adv2".
            loss_adv_target_value2 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # Predictions for the previous source batch; they are detached
            # right below, so no gradient flows back to the segmentation
            # network from the discriminator updates.
            pred1_last, pred2_last, labels_last = result_model(
                batch_last, interp)

            # train with source

            # result_model already returned these tensors on gpu_id_1.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last = pred1_last.detach().cuda(gpu_id_1)
            pred2_last = pred2_last.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_D = F.softmax((pred1_last), dim=1)
            pred2_last_D = F.softmax((pred2_last), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            # model_D1 update, part 1: (source, previous source) pairs are
            # regressed to source_label and (previous source, target) pairs to
            # mix_label.
            real1_D = torch.cat((pred1_D, pred1_last_D), dim=1)
            real2_D = torch.cat((pred2_D, pred2_last_D), dim=1)
            mix1_D_ = torch.cat((pred1_last_D, pred1_target_D), dim=1)
            mix2_D_ = torch.cat((pred2_last_D, pred2_target_D), dim=1)

            D_out1_real = model_D1(real1_D)
            D_out2_real = model_D1(real2_D)
            D_out1_mix = model_D1(mix1_D_)
            D_out2_mix = model_D1(mix2_D_)

            loss_D1 = mse_loss(
                D_out1_real,
                Variable(
                    torch.FloatTensor(D_out1_real.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D2 = mse_loss(
                D_out2_real,
                Variable(
                    torch.FloatTensor(D_out2_real.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D3 = mse_loss(
                D_out1_mix,
                Variable(
                    torch.FloatTensor(D_out1_mix.data.size()).fill_(
                        mix_label)).cuda(gpu_id_1))

            loss_D4 = mse_loss(
                D_out2_mix,
                Variable(
                    torch.FloatTensor(D_out2_mix.data.size()).fill_(
                        mix_label)).cuda(gpu_id_1))

            # From here on loss_D1 / loss_D2 name the per-head sums, not
            # losses of the two discriminators.
            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # Both heads go into the accumulator printed as "loss_D1".
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # train with target

            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last_target = pred1_last_target.detach().cuda(gpu_id_1)
            pred2_last_target = pred2_last_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            # model_D1 update, part 2: (previous target, target) pairs are
            # regressed to target_label and (source, previous target) pairs to
            # mix_label.
            fake1_D_ = torch.cat((pred1_last_target_D, pred1_target_D), dim=1)
            fake2_D_ = torch.cat((pred2_last_target_D, pred2_target_D), dim=1)
            mix1_D__ = torch.cat((pred1_D, pred1_last_target_D), dim=1)
            mix2_D__ = torch.cat((pred2_D, pred2_last_target_D), dim=1)

            D_out1 = model_D1(fake1_D_)
            D_out2 = model_D1(fake2_D_)
            D_out3 = model_D1(mix1_D__)
            D_out4 = model_D1(mix2_D__)

            loss_D1 = mse_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(D_out1.data.size()).fill_(
                        target_label)).cuda(gpu_id_1))

            loss_D2 = mse_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(D_out2.data.size()).fill_(
                        target_label)).cuda(gpu_id_1))

            loss_D3 = mse_loss(
                D_out3,
                Variable(
                    torch.FloatTensor(
                        D_out3.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D4 = mse_loss(
                D_out4,
                Variable(
                    torch.FloatTensor(
                        D_out4.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # The current batches become the "previous" batches for the next
            # sub-iteration.
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # train model-D2
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            # model_D2 update: single source maps get source_label, single
            # target maps get mix_label (0), both through bce_loss.
            D_out1 = model_D2(pred1_D)
            D_out2 = model_D2(pred2_D)
            D_out3 = model_D2(pred1_target_D)
            D_out4 = model_D2(pred2_target_D)

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(D_out1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(D_out2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D3 = bce_loss(
                D_out3,
                Variable(
                    torch.FloatTensor(
                        D_out3.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D4 = bce_loss(
                D_out4,
                Variable(
                    torch.FloatTensor(
                        D_out4.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # NOTE(review): repeats the assignment made a few statements
            # above; it has no additional effect.
            batch_last, batch_last_target = batch, batch_target
            # Both heads go into the accumulator printed as "loss_D2".
            loss_D_value2 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        # Apply the gradients accumulated over the iter_size sub-iterations.
        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, number1 = {8}, number2 = {9}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, number1, number2))

        # Final checkpoint once num_steps_stop is reached, then stop.
        # NOTE(review): files are prefixed 'GTA5_' although the source data
        # comes from SynthiaDataSet.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot. There is no ``i_iter != 0`` guard, so a snapshot
        # is also taken at iteration 0.
        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
示例#12
0
def main():
    """Evaluate every saved snapshot and write its predictions to disk.

    For each multiple of ``args.save_pred_every`` up to
    ``args.num_steps_stop`` the checkpoint
    ``./snapshots/<variant>/GTA5_<step>.pth`` is loaded into the model, the
    whole test loader is run through it, and for every image both the raw
    class-index map and its colorized version are saved under
    ``<args.save>/<variant>/step<step>/``.

    ``<variant>`` is ``source_only`` when the module-level ``SOURCE_ONLY``
    flag is set, otherwise ``single_level`` / ``multi_level`` depending on
    ``args.level``.

    Raises:
        NotImplementedError: if ``args.model`` or (when ``SOURCE_ONLY`` is
            false) ``args.level`` is not one of the supported choices.
    """
    seed = 1338
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = get_arguments()

    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise NotImplementedError(
            'model choice {} is not implemented'.format(args.model))

    # The same sub-directory name is used below ./snapshots (input) and
    # below args.save (output). args.level is ignored when SOURCE_ONLY is set.
    if SOURCE_ONLY:
        variant = 'source_only'
    elif args.level == 'single-level':
        variant = 'single_level'
    elif args.level == 'multi-level':
        variant = 'multi_level'
    else:
        raise NotImplementedError(
            'level choice {} is not implemented'.format(args.level))

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)
    model.eval()
    # Keep the bare model for state-dict handling and wrap it only once for
    # the forward pass. Re-wrapping inside the snapshot loop would nest
    # DataParallel and prefix every state-dict key with 'module.'.
    net = nn.DataParallel(model) if args.multi_gpu else model
    net.eval()

    # Loader and upsampler do not depend on the snapshot, so build them once.
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for files in range(int(args.num_steps_stop / args.save_pred_every)):
        step = (files + 1) * args.save_pred_every
        print('Step: ', step)

        saved_state_dict = torch.load('./snapshots/' + variant + '/GTA5_' +
                                      str(step) + '.pth')

        # Tolerate checkpoints written by a different PyTorch version: keep
        # only the entries the model knows about and load the merged dict so
        # that keys absent from the checkpoint retain their current values.
        model_dict = model.state_dict()
        model_dict.update({
            k: v
            for k, v in saved_state_dict.items() if k in model_dict
        })
        model.load_state_dict(model_dict)

        step_dir = os.path.join(args.save, variant, 'step' + str(step))
        os.makedirs(step_dir, exist_ok=True)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                image = image.to(device)

                if args.model == 'DeeplabMulti':
                    # Two outputs are returned; only the second is evaluated.
                    _, output2 = net(image)
                    output = interp(output2).cpu().data[0].numpy()
                else:  # 'DeeplabVGG' or 'Oracle'
                    output = interp(net(image)).cpu().data[0].numpy()

                # (C, H, W) -> (H, W, C), then per-pixel arg-max over classes.
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save(os.path.join(step_dir, name))
                output_col.save(
                    os.path.join(step_dir,
                                 name.split('.')[0] + '_color.png'))
def main():
    """Create the model and start the training.

    Trains the segmentation network on labelled source batches and
    adversarially aligns its target predictions with a single discriminator
    that is fed through ``CDAN(..., cdan_implement='concat')``.

    The function relies on module-level state: the parsed ``args`` namespace
    and the ``RESTART`` / ``RESTART_FROM`` / ``RESTART_ITER`` switches. When
    ``RESTART`` is set, the arguments and both networks are restored from the
    files written at iteration ``RESTART_ITER``.

    Checkpoints for both networks plus a JSON dump of the arguments are
    written to ``args.snapshot_dir`` every ``args.save_pred_every``
    iterations; training stops after ``args.num_steps_stop`` iterations.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'DeepLab'``.
    """
    import json

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # vars() exposes the namespace's own dict, so writes to args_dict are
    # visible through args as well.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
    else:
        raise NotImplementedError(
            'model choice {} is not implemented'.format(args.model))

    # The discriminator consumes the two prediction maps concatenated,
    # hence twice the number of class channels.
    model_D = FCDiscriminator(num_classes=2 * args.num_classes).to(device)

    if RESTART:
        # Resume both networks from the snapshot taken at RESTART_ITER.
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D)
        model_D.load_state_dict(saved_state_dict)
    else:
        # Fresh run: the discriminator keeps its random initialisation, the
        # segmentation network starts from the pre-trained weights.
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Keys look like 'Scale.layer5.conv2d_list.3.weight'; the leading
            # component is dropped. With 19 classes the 'layer5' weights are
            # skipped and keep their initial values.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = iter(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = iter(targetloader)

    # model.optim_parameters(args) handles the per-model lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # NOTE(review): args.gan is not consulted; the adversarial and
    # discriminator losses have always been BCE-with-logits in this function.
    adv_criterion = nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    os.makedirs(args.log_dir, exist_ok=True)
    writer = SummaryWriter(args.log_dir)

    try:
        for i_iter in range(args.num_steps):
            loss_seg_value1 = 0
            loss_seg_value2 = 0
            adv_loss_value = 0
            d_loss_value = 0

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)
            optimizer_D.zero_grad()
            # NOTE(review): the discriminator uses the same schedule helper
            # as the segmentation network; confirm this is intended.
            adjust_learning_rate(optimizer_D, i_iter)

            for sub_i in range(args.iter_size):
                # ---- train G: freeze D so it accumulates no gradients ----
                for param in model_D.parameters():
                    param.requires_grad = False

                # train with source
                images, labels, _, _ = next(trainloader_iter)
                images = images.to(device)
                labels = labels.long().to(device)

                pred1, pred2 = model(images)
                pred1 = interp(pred1)
                pred2 = interp(pred2)

                loss_seg1 = seg_loss(pred1, labels)
                loss_seg2 = seg_loss(pred2, labels)
                loss = loss_seg2 + args.lambda_seg * loss_seg1
                # normalise for gradient accumulation over iter_size
                loss = loss / args.iter_size
                loss.backward()
                loss_seg_value1 += loss_seg1.item() / args.iter_size
                loss_seg_value2 += loss_seg2.item() / args.iter_size

                # train with target
                images, _, _ = next(targetloader_iter)
                images = images.to(device)

                pred_target1, pred_target2 = model(images)
                pred_target1 = interp_target(pred_target1)
                pred_target2 = interp_target(pred_target2)

                D_out_target = CDAN([
                    F.softmax(pred_target1, dim=1),
                    F.softmax(pred_target2, dim=1)
                ],
                                    model_D,
                                    cdan_implement='concat')
                # Fool D: target predictions are labelled as source.
                adv_loss = adv_criterion(
                    D_out_target,
                    torch.full_like(D_out_target, float(source_label)))
                adv_loss = adv_loss / args.iter_size
                adv_loss = args.lambda_adv * adv_loss
                adv_loss.backward()
                adv_loss_value += adv_loss.item()

                # ---- train D on detached predictions ----
                for param in model_D.parameters():
                    param.requires_grad = True

                for first, second, domain_label in (
                    (pred1, pred2, source_label),
                    (pred_target1, pred_target2, target_label),
                ):
                    D_out = CDAN([
                        F.softmax(first.detach(), dim=1),
                        F.softmax(second.detach(), dim=1)
                    ],
                                 model_D,
                                 cdan_implement='concat')
                    d_loss = adv_criterion(
                        D_out, torch.full_like(D_out, float(domain_label)))
                    d_loss = d_loss / args.iter_size
                    d_loss.backward()
                    d_loss_value += d_loss.item()

            optimizer.step()
            optimizer_D.step()

            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'generator_loss': adv_loss_value,
                'discriminator_loss': d_loss_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, adv_loss_value, d_loss_value))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_D.pth'))

                save_path = args.snapshot_dir + '/eval_{}'.format(i_iter)
                os.makedirs(save_path, exist_ok=True)

                ###### also record latest saved iteration #######
                args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
                args_dict['learning_rate_D'] = optimizer_D.param_groups[0][
                    'lr']
                args_dict['start_steps'] = i_iter

                # Must match the path read in the restart branch above.
                args_dict_file = osp.join(
                    args.snapshot_dir, 'args_dict_{}.json'.format(i_iter))
                with open(args_dict_file, 'w') as f:
                    json.dump(args_dict, f)
    finally:
        # Flush TensorBoard events even if training aborts.
        writer.close()
示例#14
0
def main():
    """Create the model and start the evaluation process.

    The model type and normalisation style are read from the ``opts.yaml``
    file next to ``args.restore_from``. For ``'DeepLab'`` the prediction is
    the average of four softmax maps: original and horizontally flipped
    input, each at scale 1.0 and 1.25. Per image the function writes the
    label map and its colorized version to ``<args.save>/<dir_name>/`` and,
    for the ``test`` / ``val`` sets, the submission files under
    ``<args.save>/submit/`` (labels, labels with low-score pixels set to 255,
    and a 16-bit confidence map).

    Returns:
        ``args.save``, the root output directory.

    Raises:
        NotImplementedError: if the configured model type is not supported.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: plain yaml.load() without a Loader is rejected by
        # current PyYAML and can construct arbitrary objects on older ones.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    os.makedirs(args.save, exist_ok=True)
    confidence_path = os.path.join(args.save, 'submit/confidence')
    label_path = os.path.join(args.save, 'submit/labelTrainIds')
    label_invalid_path = os.path.join(args.save,
                                      'submit/labelTrainIds_invalid')
    for path in [confidence_path, label_path, label_invalid_path]:
        os.makedirs(path, exist_ok=True)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise NotImplementedError(
            'model choice {} is not implemented'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch: retry with the DataParallel wrapper, whose state-dict
        # keys carry a 'module.' prefix.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(DarkZurichDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(h, w),
                                                   resize_size=(w, h),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second loader over the same list at 1.25x resolution for multi-scale
    # testing; both loaders are unshuffled so zip() pairs the same images.
    scale = 1.25
    testloader2 = data.DataLoader(DarkZurichDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(h * scale), round(w * scale)),
        resize_size=(round(w * scale), round(h * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1080, 1920),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1080, 1920), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)

    # Pixels whose best score is below this are marked invalid (255).
    threshold = 0.3274

    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, name = batch
        image2, _, _ = batch2

        # Inputs must live on the same GPU as the model.
        inputs = image.cuda(gpu0)
        inputs2 = image2.cuda(gpu0)
        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, args.batchsize * len(testloader)),
              end='')

        with torch.no_grad():
            if args.model == 'DeepLab':
                # Sum the fused softmax of: scale 1.0, scale 1.0 flipped,
                # scale 1.25, scale 1.25 flipped; then average.
                output_batch = None
                for view in (inputs, inputs2):
                    for flipped in (False, True):
                        out1, out2 = model(fliplr(view) if flipped else view)
                        if flipped:
                            # Undo the flip so predictions align again.
                            out1, out2 = fliplr(out1), fliplr(out2)
                        probs = interp(sm(0.5 * out1 + out2))
                        if output_batch is None:
                            output_batch = probs
                        else:
                            output_batch = output_batch + probs
                        del out1, out2, probs
                output_batch = (output_batch.cpu() / 4).numpy()
            else:  # 'DeeplabVGG' or 'Oracle': single pass, raw scores
                output_batch = interp(model(inputs)).cpu().numpy()
        del inputs, inputs2

        # (N, C, H, W) -> (N, H, W, C)
        output_batch = output_batch.transpose(0, 2, 3, 1)
        score_batch = np.max(output_batch, axis=3)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output_single = output_batch[i, :, :]
            output_col = colorize_mask(output_single)
            output = Image.fromarray(output_single)

            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            os.makedirs(save_path, exist_ok=True)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

            if args.set in ('test', 'val'):
                # label
                output.save('%s/%s' % (label_path, name_tmp))
                # label invalid: work on a copy so the image created above
                # never observes the masking.
                output_invalid = output_single.copy()
                output_invalid[score_batch[i, :, :] < threshold] = 255
                Image.fromarray(output_invalid).save(
                    '%s/%s' % (label_invalid_path, name_tmp))
                # confidence, stored as 16-bit
                confidence = score_batch[i, :, :] * 65535
                confidence = np.asarray(confidence, dtype=np.uint16)
                print(confidence.min(), confidence.max())
                iio.imwrite('%s/%s' % (confidence_path, name_tmp), confidence)

    return args.save
def main():
    """Create the model and start the evaluation process.

    Runs the selected model over the BDD test loader and, for every image,
    writes the class-index map (``<name>.png``) and its colorized version
    (``<name>_color.png``) into ``args.save``.

    When ``args.save_confidence`` is set, the per-image value returned by
    ``get_confidence`` is also written to ``<name>_c.txt`` and, after the
    loop, the names of the top third of images by that value are written to
    ``list.txt`` in the current working directory.

    Raises:
        NotImplementedError: if ``args.model`` is not a supported choice.
    """
    args = get_arguments()

    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
    else:
        raise NotImplementedError(
            'model choice {} is not implemented'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # strict=False: keys that do not match the model are silently ignored.
    model.load_state_dict(saved_state_dict, strict=False)

    print(model)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(BDDDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(960, 540),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False,
                                            set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)  # 960 540

    interp = nn.Upsample(size=(720, 1280), mode='bilinear', align_corners=True)

    # [confidence, batch-name] pairs, filled only when save_confidence is on.
    c_list = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 10 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            output = model(image)
            if isinstance(output, (tuple, list)):
                # Multi-output models: use the last output, matching how
                # the other evaluation scripts treat DeeplabMulti.
                # NOTE(review): confirm this repo's DeeplabMulti returns a
                # (aux, final) pair.
                output = output[-1]

            file_name = name[0].split('/')[-1]
            stem = file_name.split('.')[0]

            if args.save_confidence:
                confidence = get_confidence(output).cpu().item()
                c_list.append([confidence, name])
                with open('%s/%s_c.txt' % (args.save, stem), 'w') as record:
                    record.write('%.5f' % confidence)

            output = interp(output).cpu().data[0].numpy()

            # (C, H, W) -> (H, W, C), then per-pixel arg-max over classes.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            # The label map drops the last four characters of the file name
            # (presumably a three-letter extension) before adding '.png'.
            output.save('%s/%s' % (args.save, file_name[:-4] + '.png'))
            output_col.save('%s/%s_color.png' % (args.save, stem))

    if args.save_confidence:
        # Highest confidence first; keep the top third.
        c_list.sort(key=lambda entry: entry[0], reverse=True)
        with open('list.txt', 'w') as select:
            for confidence, names in c_list[:len(c_list) // 3]:
                print(confidence)
                print(names)
                select.write(names[0])
                select.write('\n')

    print(args.save)
def main():
    """Create the model and start the evaluation process.

    The model type and normalization style are read from the ``opts.yaml``
    file that sits next to ``args.restore_from``.  The restored network is
    run over the robot dataset and, for each image, the arg-max label map and
    a colorized copy are written to ``<args.save>/<parent dir of the image>/``.

    For the ``'DeepLab'`` model four predictions are summed before the
    arg-max: the original and the horizontally flipped input, each at scale
    1.0 and at scale 1.25.  ``'DeeplabVGG'`` and ``'Oracle'`` use one forward
    pass on the unscaled input.

    Returns:
        str: ``args.save``, the directory the predictions were written to.

    Raises:
        ValueError: if the model type stored in ``opts.yaml`` is not one of
            ``'DeepLab'``, ``'Oracle'`` or ``'DeeplabVGG'``.
    """

    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # PyYAML >= 6 refuses yaml.load() without an explicit Loader.
        # FullLoader is what 5.x picked implicitly; releases older than 5.1
        # do not have it, so fall back to the classic Loader there.
        # NOTE(review): only point this at an opts.yaml you trust -- these
        # loaders can build Python objects, unlike yaml.safe_load.
        yaml_loader = getattr(yaml, 'FullLoader', yaml.Loader)
        config = yaml.load(stream, Loader=yaml_loader)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry through a DataParallel wrapper, which prefixes
        # every parameter name with "module." (presumably how the checkpoint
        # was saved).
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    if not isinstance(model, torch.nn.DataParallel):
        # Wrap exactly once; the fallback above may already have done it.
        model = torch.nn.DataParallel(model)
    model.eval()
    model.cuda(gpu0)

    def make_loader(factor):
        """Return a non-shuffled loader whose images are scaled by *factor*."""
        return data.DataLoader(robotDataSet(
            args.data_dir,
            args.data_list,
            crop_size=(round(960 * factor), round(1280 * factor)),
            resize_size=(round(1280 * factor), round(960 * factor)),
            mean=IMG_MEAN,
            scale=False,
            mirror=False,
            set=args.set),
                               batch_size=batchsize,
                               shuffle=False,
                               pin_memory=True,
                               num_workers=4)

    # Same image list at two input scales; neither loader shuffles, so the
    # batches are expected to line up when zipped below.
    testloader = make_loader(1)
    testloader2 = make_loader(1.25)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(960, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(960, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)

    def fused_prediction(inputs, flip=False):
        """Return the upsampled softmax of ``0.5 * output1 + output2``.

        With ``flip=True`` the input is mirrored with ``fliplr`` before the
        forward pass and both outputs are mirrored back afterwards.
        """
        if flip:
            inputs = fliplr(inputs)
        output1, output2 = model(inputs)
        if flip:
            output1, output2 = fliplr(output1), fliplr(output2)
        return interp(sm(0.5 * output1 + output2))

    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, _, name = batch
        image2, _, _, _ = batch2
        print(image.shape)

        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, NUM_STEPS),
              end='')
        # Inference only: never build an autograd graph, whatever the model.
        with torch.no_grad():
            if args.model == 'DeepLab':
                inputs = image.cuda()
                output_batch = fused_prediction(inputs)
                output_batch += fused_prediction(inputs, flip=True)
                del inputs  # release the 1.0x batch before the larger one

                inputs2 = image2.cuda()
                output_batch += fused_prediction(inputs2)
                output_batch += fused_prediction(inputs2, flip=True)
                del inputs2
                output_batch = output_batch.cpu().data.numpy()
            elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                output_batch = model(Variable(image).cuda())
                output_batch = interp(output_batch).cpu().data.numpy()

        # (N, C, H, W) scores -> (N, H, W) class indices.
        output_batch = np.asarray(np.argmax(output_batch, axis=1),
                                  dtype=np.uint8)

        for output, image_path in zip(output_batch, name):
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            path_parts = image_path.split('/')
            name_tmp = path_parts[-1]
            dir_name = path_parts[-2]
            save_path = args.save + '/' + dir_name
            # exist_ok avoids the check-then-create race of isdir() + mkdir().
            os.makedirs(save_path, exist_ok=True)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

    return args.save
示例#17
0
    def __init__(self, args):
        """Build the generator, the two discriminators, optimizers and losses.

        Args:
            args: Parsed options.  This constructor reads, among others,
                ``model``, ``restore_from``, ``num_classes``, ``crop_size``,
                the learning rates / momentum / weight decay, the ``lambda_*``
                loss weights, ``multi_gpu``, ``sync_bn`` and ``fp16``.

        Raises:
            ValueError: if ``args.model`` is not ``'DeepLab'``, the only
                generator this constructor knows how to build.
        """
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Per-class weight vectors, all initialised to ones on the GPU.
        self.class_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.class_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label

        if args.model != 'DeepLab':
            # Previously this fell through to load_state_dict() on a generator
            # that was never created and died with an unrelated error.
            raise ValueError('Unsupported model type: %s' % args.model)

        self.G = DeeplabMulti(num_classes=args.num_classes, use_se=args.use_se,
                              train_bn=args.train_bn, norm_style=args.norm_style,
                              droprate=args.droprate)
        from_url = args.restore_from[:4] == 'http'
        if from_url:
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Start from the freshly initialised weights and overwrite every entry
        # the checkpoint provides.
        new_params = self.G.state_dict().copy()
        for key in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            parts = key.split('.')
            if from_url:
                # Downloaded weights: drop the leading name component and skip
                # the 'fc' / 'layer5' entries (presumably heads that do not
                # match this network -- TODO confirm).
                if parts[1] == 'fc' or parts[1] == 'layer5':
                    continue
                kept = parts[1:]
            elif parts[0] == 'module':
                # Strip the "module." prefix that DataParallel adds to keys.
                kept = parts[1:]
            else:
                kept = parts
            new_params['.'.join(kept)] = saved_state_dict[key]
            print('%s is loaded from pre-trained weight.\n' % kept)
        self.G.load_state_dict(new_params)

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate, momentum=args.momentum,
                                 nesterov=True, weight_decay=args.weight_decay)

        self.dis1_opt = optim.Adam(self.D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))

        self.dis2_opt = optim.Adam(self.D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))

        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        self.interp = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            # Let apex amp wrap each model/optimizer pair for mixed precision.
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")
示例#18
0
def main():
    """Create the model and start the evaluation process.

    Restores the selected network, runs it over ``SmokeDataset`` one image at
    a time and, for every prediction that contains more than one class:

    * saves the label map and its colorized version under ``args.save``;
    * computes two IoU scores with ``iou_pytorch`` -- the class-0 mask against
      the label, and the non-zero mask against the inverted label;
    * logs the input image, label, prediction and IoU scalars through a
      ``SummaryWriter`` placed in a time-stamped sub-directory of ``args.save``.

    The mean of each IoU score is printed at the end.

    Raises:
        ValueError: if ``args.model`` is not ``'DeeplabMulti'``, ``'Oracle'``
            or ``'DeeplabVGG'``.
    """

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Tolerate checkpoints written by a different PyTorch version: keep only
    # the entries this model knows, and let the model's own values fill in
    # whatever the checkpoint lacks.
    model_dict = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in model_dict
    }
    model_dict.update(saved_state_dict)
    # Load the merged dict; loading the filtered checkpoint alone would still
    # fail the strict key check whenever an entry is missing from it.
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    # args.save is guaranteed to exist at this point (created above).
    exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    log_dir = os.path.join(args.save, exp_name)
    writer = SummaryWriter(log_dir)

    testloader = data.DataLoader(SmokeDataset(image_size=(640, 360),
                                              dataset_mean=IMG_MEAN),
                                 batch_size=1,
                                 shuffle=True,
                                 pin_memory=True)

    # The former version check built the same module in both branches.
    interp = nn.Upsample(size=(640, 360),
                         mode='bilinear',
                         align_corners=True)

    iou_sum_fg = 0
    iou_count_fg = 0

    iou_sum_bg = 0
    iou_count_bg = 0

    try:
        for index, batch in enumerate(testloader):
            if (index + 1) % 100 == 0:
                print('%d processd' % index)

            image, label, name = batch
            with torch.no_grad():
                net_input = Variable(image).cuda(gpu0)
                if args.model == 'DeeplabMulti':
                    # Two outputs are returned; only the second one is scored.
                    _, raw_output = model(net_input)
                else:  # 'DeeplabVGG' or 'Oracle'
                    raw_output = model(net_input)
                # Kept as a tensor for every model type: the IoU code below
                # needs it, not just the DeeplabMulti path.
                scores = interp(raw_output).cpu()

            output = scores.data[0].numpy().transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            classes_seen = set(output.ravel().tolist())
            output_col = colorize_mask(output)
            output = Image.fromarray(output)
            name = name[0]

            if len(classes_seen) <= 1:
                # Single-class predictions are neither saved nor scored.
                continue

            print(classes_seen)
            print(Counter(np.asarray(output).ravel()))

            # Add the dataset mean back, rescale to [0, 1] and reverse the
            # channel order so the image can be logged.
            image = image.squeeze()
            for c in range(3):
                image[c, :, :] += IMG_MEAN[c]
            image = (image - image.min()) / (image.max() - image.min())
            image = image[[2, 1, 0], :, :]
            print(image.shape, image.min(), image.max())
            output.save(os.path.join(args.save, name + '.png'))
            output_col.save(os.path.join(args.save, name + '_color.png'))

            output_argmaxs = torch.argmax(scores.squeeze(), dim=0)
            mask1 = (output_argmaxs == 0).float() * 255
            label = label.squeeze()

            iou_fg = iou_pytorch(mask1, label)
            print("foreground IoU", iou_fg)
            iou_sum_fg += iou_fg
            iou_count_fg += 1

            mask2 = (output_argmaxs > 0).float() * 255
            label2 = label.max() - label

            iou_bg = iou_pytorch(mask2, label2)
            print("IoU for background: ", iou_bg)
            iou_sum_bg += iou_bg
            iou_count_bg += 1

            writer.add_images('input_images',
                              tf.resize(image[[2, 1, 0]], [1080, 1920]),
                              index,
                              dataformats='CHW')

            print("shape of label", label.shape)
            label_reshaped = tf.resize(label.unsqueeze(0),
                                       [1080, 1920]).squeeze()
            print("label reshaped: ", label_reshaped.shape)
            writer.add_images('labels',
                              label_reshaped,
                              index,
                              dataformats='HW')
            writer.add_images(
                'output/1',
                255 - np.asarray(tf.resize(output, [1080, 1920])) * 255,
                index,
                dataformats='HW')
            writer.add_scalar('iou/smoke', iou_fg, index)
            writer.add_scalar('iou/background', iou_bg, index)
            writer.add_scalar('iou/mean', (iou_bg + iou_fg) / 2, index)
            writer.flush()
    finally:
        # Release the event file even if evaluation stops with an error.
        writer.close()

    if iou_count_fg > 0:
        print("Mean IoU, foreground: {}".format(iou_sum_fg / iou_count_fg))
        print("Mean IoU, background: {}".format(iou_sum_bg / iou_count_bg))
        print("Mean IoU, averaged over classes: {}".format(
            (iou_sum_fg + iou_sum_bg) / (iou_count_fg + iou_count_bg)))
def main():
    """Create the model and start the training.

    Trains the network using only the cross-entropy segmentation loss on its
    two outputs (``loss_seg2 + args.lambda_seg * loss_seg1``), one optimizer
    step per iteration, writing checkpoints to ``args.snapshot_dir``.  All
    settings come from the module-level ``args``.
    """

    device = torch.device("cpu" if args.cpu else "cuda")

    # Both options are "W,H" strings.  The source size is parsed (so a
    # malformed value still fails here) but is not used further below.
    _src_w, _src_h = map(int, args.input_size.split(','))
    tgt_w, tgt_h = map(int, args.input_size_target.split(','))

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        fetch = (model_zoo.load_url
                 if args.restore_from.startswith('http') else torch.load)
        checkpoint = fetch(args.restore_from)

        merged = model.state_dict().copy()
        for key in checkpoint:
            # Example key: Scale.layer5.conv2d_list.3.weight
            pieces = key.split('.')
            # With 19 classes the 'layer5' entries are left at their fresh
            # initialisation; everything else is copied with the first name
            # component removed.
            if args.num_classes != 19 or pieces[1] != 'layer5':
                merged['.'.join(pieces[1:])] = checkpoint[key]
        model.load_state_dict(merged)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_set = cityscapesDataSet(
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        scale=False)
    trainloader = data.DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsamples both outputs to (H, W) of the target input size.
    interp = nn.Upsample(size=(tgt_h, tgt_w),
                         mode='bilinear',
                         align_corners=True)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        _, (images, labels, _, _) = next(trainloader_iter)
        images = images.to(device)
        labels = labels.long().to(device)

        pred1, pred2 = (interp(out) for out in model(images))
        loss_seg1 = seg_loss(pred1, labels)
        loss_seg2 = seg_loss(pred2, labels)

        # proper normalization
        total_loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
        total_loss.backward()
        loss_seg_value1 += loss_seg1.item() / args.iter_size
        loss_seg_value2 += loss_seg2.item() / args.iter_size

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2))

        # From num_steps_stop - 1 onwards the final checkpoint is rewritten on
        # every iteration; the loop itself keeps running until num_steps.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            final_path = osp.join(
                args.snapshot_dir,
                'GTA5_' + str(args.num_steps_stop) + '.pth')
            torch.save(model.state_dict(), final_path)

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            snapshot_path = osp.join(args.snapshot_dir,
                                     'GTA5_' + str(i_iter) + '.pth')
            torch.save(model.state_dict(), snapshot_path)
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    tau = torch.ones(1) * args.tau
    tau = tau.cuda(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params, False)
    elif args.model == 'DeepLabVGG':
        model = DeeplabVGG(pretrained=True, num_classes=args.num_classes)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    weak_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Resize(1024),
        transforms.ToTensor(),
        #         transforms.Normalize(mean, std),
        #         RandomCrop(768)
    ])

    target_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Normalize(mean, std)
        #         transforms.Resize(1024),
        #         transforms.ToTensor(),
        #         RandomCrop(768)
    ])

    label_set = GTA5(
        root=args.data_dir,
        num_cls=19,
        split='all',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size,
        #              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )
    unlabel_set = Cityscapes(
        root=args.data_dir_target,
        split=args.set,
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )

    test_set = Cityscapes(
        root=args.data_dir_target,
        split='val',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                       crop_transform=RandomCrop(768)
    )

    label_loader = data.DataLoader(label_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=False)

    unlabel_loader = data.DataLoader(unlabel_set,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=False)

    test_loader = data.DataLoader(test_set,
                                  batch_size=2,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=False)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    [model, model_D2,
     model_D2], [optimizer, optimizer_D1, optimizer_D2
                 ] = amp.initialize([model, model_D2, model_D2],
                                    [optimizer, optimizer_D1, optimizer_D2],
                                    opt_level="O1",
                                    num_losses=7)

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = Interpolate(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = Interpolate(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_test = Interpolate(size=(input_size_target[1],
                                    input_size_target[0]),
                              mode='bilinear',
                              align_corners=True)
    #     interp_test = Interpolate(size=(1024, 2048), mode='bilinear', align_corners=True)

    normalize_transform = transforms.Compose([
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    # labels for adversarial training
    source_label = 0
    target_label = 1

    max_mIoU = 0

    total_loss_seg_value1 = []
    total_loss_adv_target_value1 = []
    total_loss_D_value1 = []
    total_loss_con_value1 = []

    total_loss_seg_value2 = []
    total_loss_adv_target_value2 = []
    total_loss_D_value2 = []
    total_loss_con_value2 = []

    hist = np.zeros((num_cls, num_cls))

    #     for i_iter in range(args.num_steps):
    for i_iter, (batch, batch_un) in enumerate(
            zip(roundrobin_infinite(label_loader),
                roundrobin_infinite(unlabel_loader))):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_con_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_con_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D1.parameters():
            param.requires_grad = False

        for param in model_D2.parameters():
            param.requires_grad = False

        # train with source

        images, labels = batch
        images_orig = images
        images = transform_batch(images, normalize_transform)
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model(images)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization
        loss = loss / args.iter_size

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

#         loss.backward()
        loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
        loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

        # train with target

        images_tar, labels_tar = batch_un
        images_tar_orig = images_tar
        images_tar = transform_batch(images_tar, normalize_transform)
        images_tar = Variable(images_tar).cuda(args.gpu)

        pred_target1, pred_target2 = model(images_tar)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_adv_target1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_adv_target2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()
        loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
        ) / args.iter_size
        loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
        ) / args.iter_size

        # train with consistency loss
        # unsupervise phase
        policies = RandAugment().get_batch_policy(args.batch_size)
        rand_p1 = np.random.random(size=args.batch_size)
        rand_p2 = np.random.random(size=args.batch_size)
        random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2])

        images_aug = aug_batch_tensor(images_tar_orig, policies, rand_p1,
                                      rand_p2, random_dir)

        images_aug_orig = images_aug
        images_aug = transform_batch(images_aug, normalize_transform)
        images_aug = Variable(images_aug).cuda(args.gpu)

        pred_target_aug1, pred_target_aug2 = model(images_aug)
        pred_target_aug1 = interp_target(pred_target_aug1)
        pred_target_aug2 = interp_target(pred_target_aug2)

        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        max_pred1, psuedo_label1 = torch.max(F.softmax(pred_target1, dim=1), 1)
        max_pred2, psuedo_label2 = torch.max(F.softmax(pred_target2, dim=1), 1)

        psuedo_label1 = psuedo_label1.cpu().numpy().astype(np.float32)
        psuedo_label1_thre = psuedo_label1.copy()
        psuedo_label1_thre[(max_pred1 < tau).cpu().numpy().astype(
            np.bool)] = 255  # threshold to don't care
        psuedo_label1_thre = aug_batch_numpy(psuedo_label1_thre, policies,
                                             rand_p1, rand_p2, random_dir)
        psuedo_label2 = psuedo_label2.cpu().numpy().astype(np.float32)
        psuedo_label2_thre = psuedo_label2.copy()
        psuedo_label2_thre[(max_pred2 < tau).cpu().numpy().astype(
            np.bool)] = 255  # threshold to don't care
        psuedo_label2_thre = aug_batch_numpy(psuedo_label2_thre, policies,
                                             rand_p1, rand_p2, random_dir)

        psuedo_label1_thre = Variable(psuedo_label1_thre).cuda(args.gpu)
        psuedo_label2_thre = Variable(psuedo_label2_thre).cuda(args.gpu)

        if (psuedo_label1_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con1 = loss_calc(pred_target_aug1, psuedo_label1_thre,
                                  args.gpu)
            loss_con_value1 += loss_con1.data.cpu().numpy() / args.iter_size
        else:
            loss_con1 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        if (psuedo_label2_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con2 = loss_calc(pred_target_aug2, psuedo_label2_thre,
                                  args.gpu)
            loss_con_value2 += loss_con2.data.cpu().numpy() / args.iter_size
        else:
            loss_con2 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2
        # proper normalization
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()

# train D

# bring back requires_grad
        for param in model_D1.parameters():
            param.requires_grad = True

        for param in model_D2.parameters():
            param.requires_grad = True

        # train with source
        pred1 = pred1.detach()
        pred2 = pred2.detach()

        D_out1 = model_D1(F.softmax(pred1, dim=1))
        D_out2 = model_D2(F.softmax(pred2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=3) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=4) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        # train with target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=5) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=6) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_con_value1,
                    loss_con_value2))

        total_loss_seg_value1.append(loss_seg_value1)
        total_loss_adv_target_value1.append(loss_adv_target_value1)
        total_loss_D_value1.append(loss_D_value1)
        total_loss_con_value1.append(loss_con_value1)

        total_loss_seg_value2.append(loss_seg_value2)
        total_loss_adv_target_value2.append(loss_adv_target_value2)
        total_loss_D_value2.append(loss_D_value2)
        total_loss_con_value2.append(loss_con_value2)

        hist += fast_hist(
            labels.cpu().numpy().flatten().astype(int),
            torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int),
            num_cls)

        if i_iter % 10 == 0:
            print('({}/{})'.format(i_iter + 1, int(args.num_steps)))
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(np.mean(iu)))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_train_acc = acc_overall
            avg_train_loss_seg1 = np.mean(total_loss_seg_value1)
            avg_train_loss_adv1 = np.mean(total_loss_adv_target_value1)
            avg_train_loss_dis1 = np.mean(total_loss_D_value1)
            avg_train_loss_con1 = np.mean(total_loss_con_value1)
            avg_train_loss_seg2 = np.mean(total_loss_seg_value2)
            avg_train_loss_adv2 = np.mean(total_loss_adv_target_value2)
            avg_train_loss_dis2 = np.mean(total_loss_D_value2)
            avg_train_loss_con2 = np.mean(total_loss_con_value2)

            print('avg_train_acc      :', avg_train_acc)
            print('avg_train_loss_seg1 :', avg_train_loss_seg1)
            print('avg_train_loss_adv1 :', avg_train_loss_adv1)
            print('avg_train_loss_dis1 :', avg_train_loss_dis1)
            print('avg_train_loss_con1 :', avg_train_loss_con1)
            print('avg_train_loss_seg2 :', avg_train_loss_seg2)
            print('avg_train_loss_adv2 :', avg_train_loss_adv2)
            print('avg_train_loss_dis2 :', avg_train_loss_dis2)
            print('avg_train_loss_con2 :', avg_train_loss_con2)

            writer['train'].add_scalar('log/mIoU', mIoU, i_iter)
            writer['train'].add_scalar('log/acc', avg_train_acc, i_iter)
            writer['train'].add_scalar('log1/loss_seg', avg_train_loss_seg1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_adv', avg_train_loss_adv1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_dis', avg_train_loss_dis1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_con', avg_train_loss_con1,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_seg', avg_train_loss_seg2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_adv', avg_train_loss_adv2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_dis', avg_train_loss_dis2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_con', avg_train_loss_con2,
                                       i_iter)

            hist = np.zeros((num_cls, num_cls))
            total_loss_seg_value1 = []
            total_loss_adv_target_value1 = []
            total_loss_D_value1 = []
            total_loss_con_value1 = []
            total_loss_seg_value2 = []
            total_loss_adv_target_value2 = []
            total_loss_D_value2 = []
            total_loss_con_value2 = []

            fig = plt.figure(figsize=(15, 15))

            labels = labels[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(331)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(337)
            images = images_orig[0].cpu().numpy().transpose((1, 2, 0))
            #             images += IMG_MEAN
            ax.imshow(images)
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(334)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            labels_tar = labels_tar[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(332)
            ax.imshow(print_palette(Image.fromarray(labels_tar).convert('L')))
            ax.axis("off")
            ax.set_title('tar_labels')

            ax = fig.add_subplot(338)
            ax.imshow(images_tar_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('tar_datas')

            _, pred_target2 = torch.max(pred_target2, dim=1)
            pred_target2 = pred_target2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(335)
            ax.imshow(print_palette(
                Image.fromarray(pred_target2).convert('L')))
            ax.axis("off")
            ax.set_title('tar_predicts')

            print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0],
                  'random_dir', random_dir[0])

            psuedo_label2_thre = psuedo_label2_thre[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(333)
            ax.imshow(
                print_palette(
                    Image.fromarray(psuedo_label2_thre).convert('L')))
            ax.axis("off")
            ax.set_title('psuedo_labels')

            ax = fig.add_subplot(339)
            ax.imshow(images_aug_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('aug_datas')

            _, pred_target_aug2 = torch.max(pred_target_aug2, dim=1)
            pred_target_aug2 = pred_target_aug2[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(336)
            ax.imshow(
                print_palette(Image.fromarray(pred_target_aug2).convert('L')))
            ax.axis("off")
            ax.set_title('aug_predicts')

            #             plt.show()
            writer['train'].add_figure('image/',
                                       fig,
                                       global_step=i_iter,
                                       close=True)

        if i_iter % 500 == 0:
            loss1 = []
            loss2 = []
            for test_i, batch in enumerate(test_loader):

                images, labels = batch
                images_orig = images
                images = transform_batch(images, normalize_transform)
                images = Variable(images).cuda(args.gpu)

                pred1, pred2 = model(images)
                pred1 = interp_test(pred1)
                pred1 = pred1.detach()
                pred2 = interp_test(pred2)
                pred2 = pred2.detach()

                loss_seg1 = loss_calc(pred1, labels, args.gpu)
                loss_seg2 = loss_calc(pred2, labels, args.gpu)
                loss1.append(loss_seg1.item())
                loss2.append(loss_seg2.item())

                hist += fast_hist(
                    labels.cpu().numpy().flatten().astype(int),
                    torch.argmax(pred2,
                                 dim=1).cpu().numpy().flatten().astype(int),
                    num_cls)

            print('test')
            fig = plt.figure(figsize=(15, 15))
            labels = labels[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(311)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(313)
            ax.imshow(images_orig[-1].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(312)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            #             plt.show()

            writer['test'].add_figure('test_image/',
                                      fig,
                                      global_step=i_iter,
                                      close=True)

            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(mIoU))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_test_loss1 = np.mean(loss1)
            avg_test_loss2 = np.mean(loss2)
            avg_test_acc = acc_overall
            print('avg_test_loss2 :', avg_test_loss1)
            print('avg_test_loss1 :', avg_test_loss2)
            print('avg_test_acc   :', avg_test_acc)
            writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter)
            writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter)
            writer['test'].add_scalar('log/acc', avg_test_acc, i_iter)
            writer['test'].add_scalar('log/mIoU', mIoU, i_iter)

            hist = np.zeros((num_cls, num_cls))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if max_mIoU < mIoU:
            max_mIoU = mIoU
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D2.pth'))
def main():
    """Create the model and start the training.

    Builds a ``DeeplabMulti`` segmentation network and two
    ``FCDiscriminator`` networks, then runs an adversarial training loop in
    which each discriminator scores channel-wise concatenated *pairs* of
    softmax predictions (source/source, target/target or mixed).  All
    configuration is read from the module-level ``args`` object and model
    snapshots are written to ``args.snapshot_dir``.
    """

    # The segmentation network lives on ``gpu_id_2``; discriminators, labels
    # and all losses live on ``gpu_id_1``.  Both device ids are hard-coded.
    gpu_id_2 = 3
    gpu_id_1 = 2

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            print("from restore")
            saved_state_dict = torch.load(args.restore_from)
            # NOTE(review): the checkpoint loaded from ``args.restore_from``
            # is immediately replaced by this hard-coded snapshot, and the
            # model weights are only loaded in this (non-URL) branch.
            # Presumably a resume-from-iteration-10000 hack -- confirm.
            saved_state_dict = torch.load(
                'snapshots/GTA2Cityscapes_multi_54/GTA5_10000.pth')
            model.load_state_dict(saved_state_dict)

        # NOTE(review): ``new_params`` is assembled but never loaded into the
        # model (the original ``model.load_state_dict(new_params)`` call was
        # commented out), so this loop currently has no effect on training.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]

    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # The first batch of each loader seeds the "previous batch" used to build
    # prediction pairs in the first iteration.  Builtin ``next`` is used
    # because ``enumerate`` objects have no ``.next()`` method on Python 3.
    trainloader_iter = enumerate(trainloader)
    _, batch_last = next(trainloader_iter)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = next(targetloader_iter)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Despite the name, the adversarial criterion is mean-squared error.
    bce_loss = torch.nn.MSELoss()

    def upsample_(input_):
        """Bilinearly resize ``input_`` to the source crop size."""
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        """Bilinearly resize ``input_`` to the target crop size."""
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    def result_model(batch, interp_):
        """Run the segmentation net on ``batch``.

        Returns both upsampled prediction heads and the labels, all moved to
        the discriminator/loss device (``gpu_id_1``).
        """
        images, labels, _, name = batch
        images = Variable(images).cuda(gpu_id_2)
        labels = Variable(labels.long()).cuda(gpu_id_1)
        pred1, pred2 = model(images)
        pred1 = interp_(pred1)
        pred2 = interp_(pred2)
        return pred1.cuda(gpu_id_1), pred2.cuda(gpu_id_1), labels

    def label_like(d_out, value):
        """Return a tensor shaped like ``d_out`` filled with ``value``."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(value)).cuda(gpu_id_1)

    def save_snapshot(tag):
        """Save the segmentation net and both discriminators under ``tag``."""
        prefix = 'GTA5_' + str(tag)
        torch.save(model.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '.pth'))
        torch.save(model_D1.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '_D1.pth'))
        torch.save(model_D2.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '_D2.pth'))

    # labels for adversarial training
    source_label = 1
    target_label = -1
    mix_label = 0

    # NOTE(review): the start iteration is hard-coded to 10000, matching the
    # hard-coded checkpoint above -- confirm before training from scratch.
    for i_iter in range(10000, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            _, batch = next(trainloader_iter)
            _, batch_target = next(targetloader_iter)

            # segmentation loss on the source batch
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()

            # segmentation loss on the target batch; only these values are
            # accumulated into the reported loss_seg figures
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # adversarial loss: (target, previous target) pair should be
            # scored as source by the discriminators
            pred1_last_target, pred2_last_target, labels_last_target = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax(pred1_target, dim=1)
            pred2_target_D = F.softmax(pred2_target, dim=1)
            pred1_last_target_D = F.softmax(pred1_last_target, dim=1)
            pred2_last_target_D = F.softmax(pred2_last_target, dim=1)
            fake1_D = torch.cat((pred1_target_D, pred1_last_target_D), dim=1)
            fake2_D = torch.cat((pred2_target_D, pred2_last_target_D), dim=1)
            D_out_fake_1 = model_D1(fake1_D)
            # Fixed: D2 must score the second head's pair (was fake1_D).
            D_out_fake_2 = model_D2(fake2_D)

            loss_adv_fake1 = bce_loss(D_out_fake_1,
                                      label_like(D_out_fake_1, source_label))
            loss_adv_fake2 = bce_loss(D_out_fake_2,
                                      label_like(D_out_fake_2, source_label))

            loss_adv_target1 = loss_adv_fake1
            loss_adv_target2 = loss_adv_fake2
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()

            # adversarial loss: (target, source) pair should also be scored
            # as source; fresh forward passes are needed because the earlier
            # graphs were freed by backward()
            pred1, pred2, labels = result_model(batch, interp)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax(pred1_target, dim=1)
            pred2_target_D = F.softmax(pred2_target, dim=1)
            pred1_D = F.softmax(pred1, dim=1)
            pred2_D = F.softmax(pred2, dim=1)
            mix1_D = torch.cat((pred1_target_D, pred1_D), dim=1)
            mix2_D = torch.cat((pred2_target_D, pred2_D), dim=1)

            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D2(mix2_D)

            loss_adv_mix1 = bce_loss(D_out_mix_1,
                                     label_like(D_out_mix_1, source_label))
            loss_adv_mix2 = bce_loss(D_out_mix_2,
                                     label_like(D_out_mix_2, source_label))

            loss_adv_target1 = loss_adv_mix1 * 2
            loss_adv_target2 = loss_adv_mix2 * 2

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            pred1_last, pred2_last, labels_last = result_model(
                batch_last, interp)

            # Detach every prediction so discriminator gradients do not flow
            # back into the segmentation network.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last = pred1_last.detach().cuda(gpu_id_1)
            pred2_last = pred2_last.detach().cuda(gpu_id_1)
            pred1_last_target = pred1_last_target.detach().cuda(gpu_id_1)
            pred2_last_target = pred2_last_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax(pred1, dim=1)
            pred2_D = F.softmax(pred2, dim=1)
            pred1_last_D = F.softmax(pred1_last, dim=1)
            pred2_last_D = F.softmax(pred2_last, dim=1)
            pred1_target_D = F.softmax(pred1_target, dim=1)
            pred2_target_D = F.softmax(pred2_target, dim=1)
            pred1_last_target_D = F.softmax(pred1_last_target, dim=1)
            pred2_last_target_D = F.softmax(pred2_last_target, dim=1)

            # train with source: (source, previous source) -> source_label,
            # (previous source, target) -> mix_label
            real1_D = torch.cat((pred1_D, pred1_last_D), dim=1)
            real2_D = torch.cat((pred2_D, pred2_last_D), dim=1)
            mix1_D_ = torch.cat((pred1_last_D, pred1_target_D), dim=1)
            mix2_D_ = torch.cat((pred2_last_D, pred2_target_D), dim=1)

            D_out1_real = model_D1(real1_D)
            D_out2_real = model_D2(real2_D)
            D_out1_mix = model_D1(mix1_D_)
            D_out2_mix = model_D2(mix2_D_)

            loss_D1 = bce_loss(D_out1_real,
                               label_like(D_out1_real, source_label))
            loss_D2 = bce_loss(D_out2_real,
                               label_like(D_out2_real, source_label))
            loss_D3 = bce_loss(D_out1_mix, label_like(D_out1_mix, mix_label))
            loss_D4 = bce_loss(D_out2_mix, label_like(D_out2_mix, mix_label))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target: (target, target) -> target_label,
            # (source, previous target) -> mix_label
            # NOTE(review): the fake pair concatenates the current target
            # prediction with itself, whereas the generator step pairs it
            # with the previous target batch -- confirm this is intended.
            fake1_D_ = torch.cat((pred1_target_D, pred1_target_D), dim=1)
            fake2_D_ = torch.cat((pred2_target_D, pred2_target_D), dim=1)
            mix1_D__ = torch.cat((pred1_D, pred1_last_target_D), dim=1)
            mix2_D__ = torch.cat((pred2_D, pred2_last_target_D), dim=1)

            D_out1 = model_D1(fake1_D_)
            D_out2 = model_D2(fake2_D_)
            D_out3 = model_D1(mix1_D__)
            D_out4 = model_D2(mix2_D__)

            loss_D1 = bce_loss(D_out1, label_like(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, label_like(D_out2, target_label))
            loss_D3 = bce_loss(D_out3, label_like(D_out3, mix_label))
            loss_D4 = bce_loss(D_out4, label_like(D_out4, mix_label))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # current batches become the "previous" batches of the next step
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_snapshot(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            save_snapshot(i_iter)
示例#22
0
def main():
    """Train a segmentation network with an adversarial discriminator.

    Configuration is read from the module-level ``args`` object. The function
    builds a ``DeeplabMulti`` segmentation network and an ``FCDiscriminator``,
    creates the data loaders (optionally restricted to a labeled subset when
    ``args.partial_data`` is set), and then alternates, for every
    sub-iteration, between

    * a generator step (optional semi-supervised losses on the remaining
      subset, then segmentation + adversarial loss on a labeled batch), and
    * a discriminator step (predictions labeled ``pred_label``, one-hot
      ground truth labeled ``gt_label``).

    Snapshots of both networks are written to ``args.snapshot_dir``. The
    elapsed time is printed at the end using the module-level ``start``
    timestamp.
    """

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = DeeplabMulti(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cuda:0')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                      args.iter_size * args.batch_size,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                         args.iter_size * args.batch_size,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
        # NOTE(review): no "remain" loader exists in this mode, so the
        # semi-supervised branch below requires args.partial_data to be set.
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle data must be read in binary mode on Python 3.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the split so that a later run can reuse it via partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             sampler=train_remain_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         sampler=train_gt_sampler,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers,
                                         pin_memory=True)
        # The semi-supervised branch must draw from the *remaining* subset,
        # not from the labeled loader.
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    # align_corners is only accepted by torch >= 0.4.0
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # loader exhausted: start a new pass over the subset
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                try:
                    pred = interp(model(images))
                except RuntimeError as exception:
                    if "out of memory" not in str(exception):
                        raise
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    # No prediction was produced for this batch; skip the
                    # rest of the sub-iteration instead of reusing a stale
                    # (or undefined) `pred`.
                    continue

                pred_remain = pred.detach()

                # dim=1 matches the implicit choice F.softmax makes for 4-D input
                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy(
                ).squeeze(axis=1)

                # builtin bool: the np.bool alias was removed in NumPy 1.24
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask: pixels the discriminator scores
                    # below mask_T are excluded from the pseudo labels
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(
                        semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(
                            pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            try:
                pred = interp(model(images))
            except RuntimeError as exception:
                if "out of memory" not in str(exception):
                    raise
                print("WARNING: out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
                # Skip this sub-iteration: there is no fresh prediction to
                # compute losses on.
                continue

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            if args.D_remain:
                # NOTE(review): relies on the semi branch having produced
                # pred_remain / ignore_mask_remain earlier -- confirm that
                # D_remain is only enabled together with that branch.
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '_D.pth'))

    # `start` is expected to be a module-level timestamp
    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Create the model and start the evaluation process.

    Reads ``opts.yaml`` next to ``args.restore_from`` to pick the model type,
    restores the checkpoint, and runs inference over the dataset at two input
    scales (1.0 and 1.25), each with and without horizontal flipping. For
    every image the predicted label map and a score map are handed to the
    ``save`` / ``save_scoremap`` helpers; for the ``DeepLab`` model a heatmap
    derived from the KL divergence between the two classifier outputs is
    additionally handed to ``save_heatmap``.

    Returns:
        The output directory (``args.save`` with the model name appended).

    Raises:
        ValueError: if the model type in the config is not supported.
    """
    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        # NOTE(review): assumes opts.yaml only holds plain scalars/containers.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch: retry with the model wrapped in DataParallel
        # (presumably the checkpoint was saved from a wrapped model).
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(512, 1024),
                                                   resize_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    scale = 1.25
    testloader2 = data.DataLoader(cityscapesDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(512 * scale), round(1024 * scale)),
        resize_size=(round(1024 * scale), round(512 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)
    # A third loader at scale 0.9 used to be built and iterated here, but its
    # batches were never consumed; it was dropped to avoid decoding every
    # image a third time for nothing.

    # align_corners is only accepted by torch >= 0.4.0
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')

    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, _, name = batch
        image2, _, _, _ = batch2

        # Move inputs to the same device the model was placed on.
        inputs = image.cuda(gpu0)
        inputs2 = image2.cuda(gpu0)
        print('\r>>>>Extracting feature...%03d/%03d' %
              (index * batchsize, NUM_STEPS),
              end='')

        # Only the DeepLab branch yields two classifier outputs to compare.
        heatmap_batch = None

        if args.model == 'DeepLab':
            with torch.no_grad():
                # scale 1.0, original orientation
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                heatmap_output1, heatmap_output2 = output1, output2

                # scale 1.0, horizontally flipped (flipped back after)
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                heatmap_output1 = heatmap_output1 + output1
                heatmap_output2 = heatmap_output2 + output2
                del output1, output2, inputs

                # scale 1.25, original orientation
                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))

                # scale 1.25, horizontally flipped (flipped back after)
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2

                output_batch = output_batch.cpu().data.numpy()
                heatmap_batch = torch.sum(kl_distance(log_sm(heatmap_output1),
                                                      sm(heatmap_output2)),
                                          dim=1)
                heatmap_batch = torch.log(
                    1 + 10 * heatmap_batch)  # for visualization
                heatmap_batch = heatmap_batch.cpu().data.numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        output_batch = output_batch.transpose(0, 2, 3, 1)
        scoremap_batch = np.asarray(np.max(output_batch, axis=3))
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)
        output_iterator = []
        heatmap_iterator = []
        scoremap_iterator = []

        for i in range(output_batch.shape[0]):
            output_iterator.append(output_batch[i, :, :])
            if heatmap_batch is not None:
                heatmap_iterator.append(heatmap_batch[i, :, :] /
                                        np.max(heatmap_batch[i, :, :]))
            scoremap_iterator.append(1 - scoremap_batch[i, :, :] /
                                     np.max(scoremap_batch[i, :, :]))
            # rewrite the sample name into its output path
            name_tmp = name[i].split('/')[-1]
            name[i] = '%s/%s' % (args.save, name_tmp)
        with Pool(4) as p:
            p.map(save, zip(output_iterator, name))
            if heatmap_iterator:
                p.map(save_heatmap, zip(heatmap_iterator, name))
            p.map(save_scoremap, zip(scoremap_iterator, name))

        del output_batch

    return args.save
示例#24
0
def main():
    """Create the model and start the training.

    Configuration comes from the module-level ``args`` object. When the
    module-level ``RESTART`` flag is set, the arguments saved alongside the
    snapshot for iteration ``RESTART_ITER`` in ``RESTART_FROM`` are loaded
    back into ``args`` first. The function then builds an ``MDD`` model,
    source (``GTA5DataSet``) and target (``cityscapesDataSet``) loaders, and
    runs the optimisation loop from ``args.start_steps`` to
    ``args.num_steps``, logging the loss to a ``SummaryWriter`` and writing
    snapshots of ``model.c_net`` (plus the current arguments as JSON) to
    ``args.snapshot_dir``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """
    import json

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # vars() returns the namespace's own dict, so writes below update `args`.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        # Build the path the same way the snapshot code below writes it, so
        # it works whether or not snapshot_dir ends with a separator.
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            # Arguments missing from the saved file keep their current value.
            if arg in args_dict_last:
                args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model != 'DeepLab':
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError('Unsupported model type: {}'.format(args.model))

    width = 1024
    srcweight = 3
    model = MDD(width=width,
                use_bottleneck=False,
                use_gpu=not args.cpu,
                class_num=args.num_classes,
                srcweight=srcweight,
                args=args)
    model.set_train(True)

    model.c_net.to(device)

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.get_parameter_list(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    lr_scheduler = INVScheduler(gamma=0.001, decay_rate=0.75, init_lr=0.004)

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    # try/finally so the event file is closed even if training fails
    try:
        for i_iter in range(args.start_steps, args.num_steps):

            param_groups = model.get_parameter_list()
            group_ratios = [group['lr'] for group in param_groups]
            optimizer = lr_scheduler.next_optimizer(group_ratios, optimizer,
                                                    i_iter / 5)
            optimizer.zero_grad()

            total_loss = 0
            for sub_i in range(args.iter_size):

                # load src
                _, batch = next(trainloader_iter)
                src_images, src_labels, _, _ = batch
                src_images = src_images.to(device)
                src_labels = src_labels.long().to(device)

                # load target
                _, batch = next(targetloader_iter)
                tgt_images, _, _ = batch
                tgt_images = tgt_images.to(device)

                # source and target images go through the model as one batch
                inputs = torch.cat((src_images, tgt_images), dim=0)

                loss = model.get_loss(inputs, src_labels)
                loss = loss / args.iter_size
                loss.backward()
                total_loss += loss.item()

            optimizer.step()

            scalar_info = {
                'loss': total_loss,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print('iter = {0:8d}/{1:8d}, loss = {2:.3f}'.format(
                i_iter, args.num_steps, total_loss))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(
                    model.c_net.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.c_net.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))

                ###### also record latest saved iteration #######
                args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
                args_dict['start_steps'] = i_iter

                args_dict_file = osp.join(
                    args.snapshot_dir, 'args_dict_{}.json'.format(i_iter))
                with open(args_dict_file, 'w') as f:
                    json.dump(args_dict, f)

                ###### also record latest saved iteration #######
    finally:
        writer.close()
示例#25
0
def main():
    """Create the model and start the training.

    Trains a segmentation network (``DeeplabMulti`` when ``args.model`` is
    ``'ResNet'``, ``DeeplabVGG`` when it is ``'VGG'``) on labelled batches
    from ``GTA5DataSet`` while an ``FCDiscriminator`` is trained on the
    network's feature output for source batches and for image batches from
    ``cityscapesDataSet``.

    All settings are read from the module-level ``args``. Checkpoints of both
    networks are written to ``args.snapshot_dir``; when ``args.tensorboard``
    is set, loss scalars are written to ``args.log_dir`` every 10 iterations.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Crop sizes arrive as "width,height" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    # NOTE(review): `model` stays unbound (NameError at model.train()) when
    # args.model is neither 'ResNet' nor 'VGG'.
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Start from the freshly initialised weights and overwrite them with
        # the checkpoint's, dropping the first dotted component of each key
        # (example key: Scale.layer5.conv2d_list.3.weight).
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            # 'layer5' entries are skipped only when num_classes == 19;
            # every other entry is always copied.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    if args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    # presumably num_classes here is the channel count of the feature map
    # each backbone returns — TODO confirm against FCDiscriminator.
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=2048).to(device)
    if args.model == 'VGG':
        model_D = FCDiscriminator(num_classes=1024).to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source-domain loader (images with labels).
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Target-domain loader; only its images are used below.
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample takes (height, width), hence the swapped indices.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    # NOTE(review): interp_target is never used in this function.
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cla_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source

        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        # The model returns a (feature, prediction) pair; the prediction is
        # upsampled to the source crop size before the segmentation loss.
        feature, prediction = model(images)
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        loss.backward()
        loss_seg = loss.item()

        # train with target

        # `images` is rebound to the target batch; `labels` still holds the
        # source labels and is reused in the discriminator step below.
        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)

        feature_target, _ = model(images)
        _, D_out = model_D(feature_target)
        # Adversarial term: target features are scored against the *source*
        # label, so the gradient pushes the segmentation network to make
        # target features look like source features to the discriminator.
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward()
        loss_adv_target_value = loss_adv_target.item()

        # train D

        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source
        # Detached so discriminator losses do not reach the segmentation net.
        feature = feature.detach()
        # The discriminator returns two outputs: `cla`, which is trained with
        # the segmentation loss against the source labels, and the domain
        # logits `D_out`.
        cla, D_out = model_D(feature)
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)

        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        # Source and target halves of the domain loss are each weighted 1/2.
        loss_D = loss_D / 2
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()

        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(target_label).to(device))
        loss_D = loss_D / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        # Both optimizers step once per iteration, after all backward passes.
        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value))

        # Early stop: save the final checkpoints and leave the loop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        # Periodic snapshot (iteration 0 is skipped).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

    if args.tensorboard:
        writer.close()
示例#26
0
def main():
    """Evaluate a restored segmentation model on 1000 reference/query pairs.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` and draws 1000 samples from ``SSDatalayer`` group 1.
    For every sample the reference image, its mask and the query image are
    passed to the model; the arg-max prediction is blended with the query
    image and written under ``save_bins/que_pred``. True-positive,
    false-positive and false-negative counts per class, a 21x21 confusion
    histogram and the ``SegScorer`` statistics are accumulated, and the
    per-class IoU, the mean IoU over the group's five classes, the binary
    IoU and the scorer results are printed at the end.

    Raises:
        ValueError: if ``args.model`` is not a supported model name.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # cv2.imwrite does not create missing directories, so make sure the
    # overlay folder exists before the loop starts.
    pred_dir = 'save_bins/que_pred'
    if not os.path.exists(pred_dir):
        os.makedirs(pred_dir)

    device = torch.device("cuda" if not args.cpu else "cpu")

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = DeepLab(backbone='resnet', output_stride=8)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('unsupported model: {}'.format(args.model))

    # map_location lets a checkpoint saved on GPU be restored on a CPU run.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location=device)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location=device)

    # Merge the checkpoint into the model's own state dict and load the
    # merged result: checkpoint keys unknown to the model are dropped, and
    # keys the checkpoint lacks (e.g. buffers added by a newer PyTorch) keep
    # their initial values instead of failing the strict load.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v
         for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    model = model.to(device)
    model.eval()

    num_classes = 20
    tp_list = [0] * num_classes
    fp_list = [0] * num_classes
    fn_list = [0] * num_classes

    hist = np.zeros((21, 21))
    group = 1
    scorer = SegScorer(num_classes=21)
    datalayer = SSDatalayer(group)
    for count in tqdm(range(1000)):
        dat = datalayer.dequeue()
        # NOTE(review): images come from the 'second'/'first' entries while
        # the labels come from the opposite ones; kept as in the original —
        # confirm against SSDatalayer's key naming.
        ref_img = dat['second_img'][0]
        query_img = dat['first_img'][0]
        query_label = dat['second_label'][0]
        ref_label = dat['first_label'][0]
        deploy_info = dat['deploy_info']
        # Class ids from the data layer start at 1; shift to 0-based.
        class_ind = int(deploy_info['first_semantic_labels'][0][0]) - 1

        ref_img = torch.Tensor(ref_img).to(device)
        ref_label = torch.Tensor(ref_label).to(device)
        query_img = torch.Tensor(query_img).to(device)
        # Drop the leading axis of the query label so it is 2-D.
        query_label = torch.Tensor(query_label[0, :, :]).to(device)

        ref_batch = torch.unsqueeze(ref_img, dim=0)
        ref_mask = torch.unsqueeze(ref_label, dim=1)
        query_batch = torch.unsqueeze(query_img, dim=0)

        samples = torch.cat([ref_batch, query_batch], 0)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            pred = model(samples, ref_mask)
            # Resize the logits to the query label's spatial size.
            # align_corners=False is what the deprecated F.upsample used
            # when the argument was omitted.
            pred = F.interpolate(pred,
                                 size=tuple(query_label.size()),
                                 mode='bilinear',
                                 align_corners=False)
            pred = F.softmax(pred, dim=1).squeeze()
            _, pred = torch.max(pred, dim=0)
        pred = pred.cpu().numpy().astype(np.int32)

        # Query image converted back for visualisation.
        org_img = get_org_img(query_img.squeeze().cpu().numpy())
        # Colour image blending the predicted mask with the query image.
        img = mask_to_img(pred, org_img)
        cv2.imwrite('save_bins/que_pred/query_set_1_%d.png' % (count), img)

        query_label = query_label.cpu().numpy().astype(np.int32)
        scorer.update(pred, query_label, class_ind + 1)
        tp, _, fp, fn = measure(query_label, pred)
        tp_list[class_ind] += tp
        fp_list[class_ind] += fp
        fn_list[class_ind] += fn

        # Relabel foreground pixels in place with the 1-based class id so
        # the confusion histogram is indexed by class. This must happen
        # after the scorer/measure calls above, which see the raw masks.
        pred[pred > 0.5] = class_ind + 1
        query_label[query_label > 0.5] = class_ind + 1

        hist += Metrics.fast_hist(pred, query_label, 21)

    # max(..., 1) guards classes for which prediction and label are empty.
    iou_list = [
        tp_list[ic] / float(max(tp_list[ic] + fp_list[ic] + fn_list[ic], 1))
        for ic in range(num_classes)
    ]

    print("-------------GROUP %d-------------" % (group))
    print(iou_list)
    class_indexes = range(group * 5, (group + 1) * 5)
    print('Mean:', np.mean(np.take(iou_list, class_indexes)))

    # Collapse the 21-class histogram into background vs. foreground.
    binary_hist = np.array((hist[0, 0], hist[0, 1:].sum(), hist[1:, 0].sum(),
                            hist[1:, 1:].sum())).reshape((2, 2))
    bin_iu = np.diag(binary_hist) / (binary_hist.sum(1) + binary_hist.sum(0) -
                                     np.diag(binary_hist))
    print('Bin_iu:', bin_iu)

    scores = scorer.score()
    for k in scores.keys():
        print(k, np.mean(scores[k]), scores[k])
示例#27
0
def main():
    """Create the model and start the evaluation process.

    Builds the network selected by ``args.model`` (``'DeeplabMulti'``,
    ``'Oracle'`` or ``'DeeplabVGG'``), restores its weights from
    ``args.restore_from`` and runs it over every image yielded by
    ``cityscapesDataSet``. For each image the logits are upsampled to
    1024x2048, reduced to an arg-max label map, and saved to ``args.save``
    twice: once as the raw ``uint8`` label image and once colourised as
    ``<name>_color.png``.
    """

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # For running different versions of PyTorch: merge the checkpoint into
    # the model's own state dict and load the merged result. Checkpoint keys
    # unknown to the model are dropped, and keys the checkpoint lacks keep
    # their initial values instead of failing the strict load.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v
         for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # PyTorch 0.4.0 introduced `align_corners` and `torch.no_grad()` and
    # turned `volatile=True` into a no-op, so both code paths are kept.
    modern_torch = version.parse(torch.__version__) >= version.parse('0.4.0')
    if modern_torch:
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    def predict(net_input):
        """Return the upsampled logits of the first batch item as numpy."""
        if args.model == 'DeeplabMulti':
            # The multi-level model returns two outputs; the second is used.
            _, logits = model(net_input)
        else:
            logits = model(net_input)
        return interp(logits).cpu().data[0].numpy()

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            # Function-call form is valid on both Python 2 and Python 3.
            print('%d processd' % index)
        image, _, name = batch
        if modern_torch:
            # Disable autograd so no graph is kept during inference.
            with torch.no_grad():
                output = predict(image.cuda(gpu0))
        else:
            output = predict(Variable(image, volatile=True).cuda(gpu0))

        # (C, H, W) -> (H, W, C), then pick the highest-scoring class.
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        # Keep only the file name of the dataset-provided path.
        name = name[0].split('/')[-1]
        output.save('%s/%s' % (args.save, name))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
示例#28
0
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        """Load the source/target datasets and the checkpointed model.

        Args:
            model_path: path of the checkpoint passed to ``torch.load``; it
                is loaded directly as the model's state dict.
            config: settings object; ``dataset`` (source city),
                ``all_dataset`` (list of all cities) and ``img_root`` are
                read here, and the object is kept as ``self.config``.
            bn: stored as ``self.bn``; not used in this method.
            save_path: stored as ``self.save_path``.
            save_batch: stored as ``self.save_batch``.
            cuda: when true the model is moved to the GPU; otherwise the
                checkpoint is mapped to the CPU on load.

        Raises:
            ValueError: if ``config.dataset`` is not in
                ``config.all_dataset``.
        """
        self.bn = bn
        # Work on a copy: removing the source city from config.all_dataset
        # itself would corrupt the shared config for every later user.
        self.target = list(config.all_dataset)
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains: a test and a train split per remaining city,
        # kept in parallel lists ordered like self.target
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='test',
                                     img_root=config.img_root)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeeplabMulti(num_classes=2)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
示例#29
0
def main():
    """Run one self-training round: pseudo-label generation, then retraining.

    Restores a ``DeeplabMulti`` segmentation network and a two-member
    discriminator list from ``args.restore_from`` / ``args.restore_from_D``,
    evaluates the network, builds class-balanced pseudo-labels for the target
    training list, and then retrains the network on source batches plus
    pseudo-labelled target batches with adversarial losses from both
    discriminators. Settings come from the module-level ``args``; outputs go
    under ``args.save`` and checkpoints under ``args.snapshot_dir``.
    """
    setup_seed(666)
    device = torch.device("cuda")
    save_path = args.save
    save_pseudo_label_path = osp.join(
        save_path,
        'pseudo_label')  # in 'save_path'. Save labelIDs, not trainIDs.
    save_stats_path = osp.join(save_path, 'stats')  # in 'save_path'
    save_lst_path = osp.join(save_path, 'list')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if not os.path.exists(save_pseudo_label_path):
        os.makedirs(save_pseudo_label_path)
    if not os.path.exists(save_stats_path):
        os.makedirs(save_stats_path)
    if not os.path.exists(save_lst_path):
        os.makedirs(save_lst_path)

    cudnn.enabled = True
    cudnn.benchmark = True

    logger = util.set_logger(args.save, args.log_file, args.debug)
    logger.info('start with arguments %s', args)

    model = DeeplabMulti(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)
    model.train()
    model.to(device)

    # init D
    # model_D[0] is an FCDiscriminator fed with the network's feature
    # output; model_D[1] is an OutspaceDiscriminator fed with the softmax of
    # the prediction (see their uses in the loop below).
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([
        FCDiscriminator(num_classes=num_class_list[i]).train().to(device)
        if i < 1 else OutspaceDiscriminator(
            num_classes=num_class_list[i]).train().to(device) for i in range(2)
    ])
    saved_state_dict_D = torch.load(args.restore_from_D)
    model_D.load_state_dict(saved_state_dict_D)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    image_src_list, _, src_num = parse_split_list(args.data_src_list)
    image_tgt_list, image_name_tgt_list, tgt_num = parse_split_list(
        args.data_tgt_train_list)
    # portions
    tgt_portion = args.init_tgt_port

    # training crop size, given as "width,height" strings
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    # NOTE(review): despite the `bce_` names these are mean-squared-error
    # losses. bce_loss2 is the unreduced variant, so it returns an
    # element-wise tensor that the loop reduces with .mean().
    # NOTE(review): `reduce` is a legacy argument; `reduction='none'` alone
    # should be sufficient — confirm against the installed PyTorch.
    bce_loss1 = torch.nn.MSELoss()
    bce_loss2 = torch.nn.MSELoss(reduce=False, reduction='none')
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    # The round index is fixed here; it names the per-round output folder.
    round_idx = 3
    save_round_eval_path = osp.join(args.save, str(round_idx))
    save_pseudo_label_color_path = osp.join(save_round_eval_path,
                                            'pseudo_label_color')
    if not os.path.exists(save_round_eval_path):
        os.makedirs(save_round_eval_path)
    if not os.path.exists(save_pseudo_label_color_path):
        os.makedirs(save_pseudo_label_color_path)
    ########## pseudo-label generation
    # evaluation & save confidence vectors
    # NOTE(review): model.train() was called above; whether test()/val()
    # switch the model to eval mode and restore train mode afterwards is not
    # visible here — verify before relying on the mode used for retraining.
    test(model, model_D, device, save_round_eval_path, round_idx, 500, args,
         logger)
    conf_dict, pred_cls_num, save_prob_path, save_pred_path = val(
        model, model_D, device, save_round_eval_path, round_idx, tgt_num, args,
        logger)
    # class-balanced thresholds
    cls_thresh = kc_parameters(conf_dict, pred_cls_num, tgt_portion, round_idx,
                               save_stats_path, args, logger)
    # pseudo-label maps generation
    label_selection(cls_thresh, tgt_num, image_name_tgt_list, round_idx,
                    save_prob_path, save_pred_path, save_pseudo_label_path,
                    save_pseudo_label_color_path, save_round_eval_path, args,
                    logger)
    src_train_lst, tgt_train_lst, src_num_sel = savelst_SrcTgt(
        image_tgt_list, image_name_tgt_list, image_src_list, save_lst_path,
        save_pseudo_label_path, src_num, tgt_num, args)
    ########### model retraining
    # dataset
    srctrainset = SrcSTDataSet(args.data_src_dir,
                               src_train_lst,
                               max_iters=args.num_steps * args.batch_size,
                               crop_size=input_size,
                               scale=False,
                               mirror=False,
                               mean=IMG_MEAN)
    # Target set reads its labels from the pseudo-label folder written above.
    tgttrainset = TgtSTDataSet(args.data_tgt_dir,
                               tgt_train_lst,
                               pseudo_root=save_pseudo_label_path,
                               max_iters=args.num_steps * args.batch_size,
                               crop_size=input_size_target,
                               scale=False,
                               mirror=False,
                               mean=IMG_MEAN,
                               set='train')
    trainloader = torch.utils.data.DataLoader(srctrainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = torch.utils.data.DataLoader(tgttrainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8,
                                               pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    logger.info(
        '###### Start model retraining dataset in round {}! ######'.format(
            round_idx))

    start = timeit.default_timer()
    # start training
    # nn.Upsample takes (height, width), hence the swapped indices.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        # Weight of the target (pseudo-label) segmentation loss.
        lamb = 1
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        # The model is called with the discriminators and a domain tag and
        # returns a (feature, prediction) pair.
        feat_source, pred_source = model(images, model_D, 'source')
        pred_source = interp(pred_source)

        loss_seg = seg_loss(pred_source, labels)
        loss_seg.backward()

        # train with target
        # `images` and `labels` are rebound to the target batch here.
        _, batch = targetloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feat_target, pred_target = model(images, model_D, 'target')
        pred_target = interp_target(pred_target)

        loss_seg_tgt = seg_loss(pred_target, labels) * lamb

        # Adversarial terms: target outputs are scored against the *source*
        # label by both discriminators, each weighted by 0.01.
        D_out1 = model_D[0](feat_target)
        loss_adv1 = bce_loss1(
            D_out1,
            torch.FloatTensor(
                D_out1.data.size()).fill_(source_label).to(device))
        D_out2 = model_D[1](F.softmax(pred_target, dim=1))
        loss_adv2 = bce_loss2(
            D_out2,
            torch.FloatTensor(
                D_out2.data.size()).fill_(source_label).to(device))
        loss_adv = loss_adv1 * 0.01 + loss_adv2.mean() * 0.01
        loss = loss_seg_tgt + loss_adv
        loss.backward()

        # The segmentation network is updated before the discriminators are
        # trained on the (detached) outputs computed above.
        optimizer.step()

        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source
        D_out_source1 = model_D[0](feat_source.detach())
        loss_D_source1 = bce_loss1(
            D_out_source1,
            torch.FloatTensor(
                D_out_source1.data.size()).fill_(source_label).to(device))
        D_out_source2 = model_D[1](F.softmax(pred_source.detach(), dim=1))
        loss_D_source2 = bce_loss1(
            D_out_source2,
            torch.FloatTensor(
                D_out_source2.data.size()).fill_(source_label).to(device))
        loss_D_source = loss_D_source1 + loss_D_source2
        loss_D_source.backward()

        # train with target
        D_out_target1 = model_D[0](feat_target.detach())
        loss_D_target1 = bce_loss1(
            D_out_target1,
            torch.FloatTensor(
                D_out_target1.data.size()).fill_(target_label).to(device))
        D_out_target2 = model_D[1](F.softmax(pred_target.detach(), dim=1))
        # Element-wise loss map; only its mean is used in this function.
        weight_target = bce_loss2(
            D_out_target2,
            torch.FloatTensor(
                D_out_target2.data.size()).fill_(target_label).to(device))
        loss_D_target2 = weight_target.mean()
        loss_D_target = loss_D_target1 + loss_D_target2
        loss_D_target.backward()

        optimizer_D.step()

        if i_iter % 10 == 0:
            print(
                'iter={0:8d}/{1:8d}, seg={2:.3f} seg_tgt={3:.3f} adv={4:.3f} adv1={5:.3f} adv2={6:.3f} src1={7:.3f} src2={8:.3f} tgt1={9:.3f} tgt2={10:.3f} D1={11:.3f} D2={12:.3f}'
                .format(i_iter, args.num_steps, loss_seg.item(),
                        loss_seg_tgt.item(), loss_adv.item(), loss_adv1.item(),
                        loss_adv2.mean().item(), loss_D_source1.item(),
                        loss_D_source2.item(), loss_D_target1.item(),
                        loss_D_target2.item(), loss_D_source.item(),
                        loss_D_target.item()))

        # Early stop: save the final checkpoints and leave the loop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        # Periodic evaluation and snapshot (iteration 0 is skipped).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            test(model, model_D, device, save_round_eval_path, round_idx, 500,
                 args, logger)
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    logger.info(
        '###### Finish model retraining dataset in round {}! Time cost: {:.2f} seconds. ######'
        .format(round_idx, end - start))
    # test self-trained model in target domain test set
    test(model, model_D, device, save_round_eval_path, round_idx, 500, args,
         logger)
def main():
    """Create the model and start the training.

    All settings are read from the module-level ``args`` namespace.  The
    function trains a two-output ``DeeplabMulti`` segmentation network on the
    source loader (``SyntheticSmokeTrain``) with a supervised loss, and
    adversarially aligns its predictions on the target loader
    (``SimpleSmokeVal``) using one ``FCDiscriminator`` per output.

    Per outer iteration, gradients are accumulated over ``args.iter_size``
    sub-iterations before the three optimizers step.  Losses are printed every
    iteration and, unless ``args.no_logging`` is set, written to TensorBoard.
    Checkpoints go to ``args.snapshot_dir`` every ``args.save_pred_every``
    iterations and once more when ``args.num_steps_stop`` is reached, at which
    point training stops.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` or ``args.gan`` is
            neither ``'Vanilla'`` nor ``'LS'``.
    """

    def next_batch(loader, loader_iter):
        """Return ``(batch, loader_iter)``, restarting an exhausted loader."""
        try:
            _, batch = next(loader_iter)
        except StopIteration:
            # One pass over the dataset is finished: start another one
            # instead of aborting the training run.
            loader_iter = enumerate(loader)
            _, batch = next(loader_iter)
        return batch, loader_iter

    def set_requires_grad(module, flag):
        """Enable or disable gradient tracking for all parameters of *module*."""
        for param in module.parameters():
            param.requires_grad = flag

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    criterion = DiceBCELoss()

    # Create network
    if args.model != 'DeepLab':
        raise ValueError(
            'Unsupported model {!r}; expected \'DeepLab\''.format(args.model))
    model = DeeplabMulti(num_classes=args.num_classes)

    if args.restore_from is not None:
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Keys look like 'Scale.layer5.conv2d_list.3.weight'; the leading
            # component is dropped to match this model's parameter names.
            key_parts = key.split('.')
            # With 19 classes the 'layer5' entries are not restored and keep
            # their freshly initialised values.
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    # TensorBoard writer; stays None when logging is disabled.
    writer = None
    if not args.no_logging:
        log_dir = os.path.join(args.log_dir, args.exp_dir)
        os.makedirs(log_dir, exist_ok=True)
        if args.exp_name == "":
            exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
        else:
            exp_name = args.exp_name
        writer = SummaryWriter(os.path.join(log_dir, exp_name))

    model.train()
    model.cuda(gpu)

    cudnn.benchmark = True

    # init D: one discriminator per segmentation output
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(gpu)

    model_D2.train()
    model_D2.cuda(gpu)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    def save_snapshot(tag):
        """Save the segmentation net and both discriminators under *tag*."""
        prefix = osp.join(args.snapshot_dir, 'lmda_adv_0.1_' + str(tag))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D1.state_dict(), prefix + '_D1.pth')
        torch.save(model_D2.state_dict(), prefix + '_D2.pth')

    trainloader = data.DataLoader(
        SyntheticSmokeTrain(
            args={},
            dataset_limit=args.num_steps * args.iter_size * args.batch_size,
            image_shape=input_size,
            dataset_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    print("Length of train dataloader: ", len(trainloader))

    targetloader = data.DataLoader(
        SimpleSmokeVal(args={},
                       image_size=input_size_target,
                       dataset_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    targetloader_iter = enumerate(targetloader)
    print("Length of target dataloader: ", len(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError(
            'Unsupported GAN loss {!r}; expected \'Vanilla\' or \'LS\''.format(
                args.gan))

    def domain_loss(d_out, domain_label):
        """Discriminator loss of *d_out* against a constant domain label."""
        # full_like builds the label map with d_out's shape, dtype and device.
        return bce_loss(d_out, torch.full_like(d_out, domain_label))

    # input sizes are (w, h); Upsample expects (h, w)
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        # running (already normalised) loss values for this outer iteration
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            set_requires_grad(model_D1, False)
            set_requires_grad(model_D2, False)

            # train with source
            batch, trainloader_iter = next_batch(trainloader,
                                                 trainloader_iter)
            images, labels, _, _ = batch
            images = images.cuda(gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            batch, targetloader_iter = next_batch(targetloader,
                                                  targetloader_iter)
            images, _, _ = batch
            images = images.cuda(gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when target predictions are
            # classified as coming from the source domain.
            loss_adv_target1 = domain_loss(D_out1, source_label)
            loss_adv_target2 = domain_loss(D_out2, source_label)

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            set_requires_grad(model_D1, True)
            set_requires_grad(model_D2, True)

            # train with source (detached: no gradient reaches the generator)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            # halved because each discriminator sees a source and a target
            # batch per sub-iteration
            loss_D1 = domain_loss(D_out1, source_label) / args.iter_size / 2
            loss_D2 = domain_loss(D_out2, source_label) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = domain_loss(D_out1, target_label) / args.iter_size / 2
            loss_D2 = domain_loss(D_out2, target_label) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if writer is not None:
            writer.add_scalar('loss/train/segmentation/1', loss_seg_value1,
                              i_iter)
            writer.add_scalar('loss/train/segmentation/2', loss_seg_value2,
                              i_iter)
            writer.add_scalar('loss/train/adversarial/1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar('loss/train/adversarial/2',
                              loss_adv_target_value2, i_iter)
            writer.add_scalar('loss/train/domain/1', loss_D_value1, i_iter)
            writer.add_scalar('loss/train/domain/2', loss_D_value2, i_iter)
            # Flush before the possible ``break`` below so the scalars of the
            # final iteration are written out as well.
            writer.flush()

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_snapshot(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_snapshot(i_iter)