def main() -> None:
    """Create the model and start the training.

    Adversarial domain adaptation for semantic segmentation: a DeepLab
    model with two output branches is trained on labelled source images
    (GTA5DataSet) and unlabelled target images (IDDDataSet), with one
    discriminator per output branch.  The adversarial objective is
    selected by ``args.gan``: 'Vanilla' (BCE-with-logits), 'LS' (MSE),
    'Hinge' or 'new_Hinge'.  All configuration is read from the
    module-level ``args``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes arrive as "W,H" command-line strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        # if args.restore_from[:4] == 'http' :
        #     saved_state_dict = model_zoo.load_url(args.restore_from)
        # else:
        #     saved_state_dict = torch.load(args.restore_from)
        #
        # new_params = model.state_dict().copy()
        # for i in saved_state_dict:
        #     # Scale.layer5.conv2d_list.3.weight
        #     i_parts = i.split('.')
        #     # print i_parts
        #     if not args.num_classes == 18 or not i_parts[1] == 'layer5':
        #         new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        #         # print i_parts
        # model.load_state_dict(new_params)

        # Partial restore: copy only the checkpoint entries whose keys exist
        # in the current model; mismatched keys are silently skipped.
        saved_state_dict = torch.load(args.restore_from, map_location=device)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            if i in new_params.keys():
                new_params[i] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D: one discriminator per output branch; spectral-norm variants
    # when args.spec is set.
    if args.spec == False:
        model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    else:
        model_D1 = SpectralDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = SpectralDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    if not os.path.exists(os.path.join(args.snapshot_dir, args.dir_name)):
        os.makedirs(os.path.join(args.snapshot_dir, args.dir_name))

    # Source loader: max_iters sizes the dataset so the iterator lasts the
    # whole run (num_steps * iter_size batches) without StopIteration.
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size, ignore_label=args.ignore_label, set=args.set, num_classes=args.num_classes),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    # trainloader = data.DataLoader(
    #     SYNTHIADataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                 crop_size=input_size, ignore_label=args.ignore_label, set=args.set, num_classes=args.num_classes),
    #     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
    #                                                  max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                                                  crop_size=input_size_target,
    #                                                  ignore_label=args.ignore_label,
    #                                                  set=args.set,
    #                                                  num_classes=args.num_classes),
    #                                batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
    #                                pin_memory=True)
    #
    # targetloader_iter = enumerate(targetloader)

    # NOTE(review): target max_iters omits the iter_size factor used for the
    # source loader above -- with iter_size > 1 the target iterator would be
    # exhausted before the run ends; confirm this is intended.
    targetloader = data.DataLoader(IDDDataSet(args.data_dir_target, args.data_list_target,
                                                max_iters=args.num_steps * args.batch_size,
                                                crop_size=input_size_target,
                                                ignore_label=args.ignore_label,
                                                set=args.set, num_classes=args.num_classes),
                                     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                     pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # GAN objective: 'Hinge'/'new_Hinge' wrap each discriminator in a Hinge
    # loss object; the other modes apply a pointwise loss to raw D outputs.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    elif args.gan == 'Hinge' or args.gan == 'new_Hinge':
        adversarial_loss_1 = Hinge(model_D1)
        adversarial_loss_2 = Hinge(model_D2)

    # Pixels labelled 255 are excluded from the segmentation loss.
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample expects (H, W); input_size holds (W, H), hence the swap.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients accumulate over iter_size sub-iterations before each
        # optimizer step; the per-sub-iteration losses are divided by
        # iter_size to compensate.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            # Branch 1 (auxiliary) is down-weighted by lambda_seg.
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            # For 'new_Hinge' the segmentation backward is deferred: the seg
            # loss stays in `loss` and is backpropagated together with the
            # adversarial terms further below.
            if not args.gan == 'new_Hinge':
                loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = targetloader_iter.__next__()

            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # NOTE(review): F.softmax is called without dim= here and below,
            # relying on the deprecated implicit-dim behaviour -- confirm the
            # torch version in use resolves this to dim=1 for 4-D inputs.
            if args.gan == 'Hinge':
                loss_adv_target1 = adversarial_loss_1(F.softmax(pred_target1), generator=True)
                loss_adv_target2 = adversarial_loss_2(F.softmax(pred_target2), generator=True)
            elif args.gan == 'new_Hinge':
                # 'new_Hinge' also feeds the source predictions as the real
                # samples of the hinge objective.
                loss_adv_target1 = adversarial_loss_1(F.softmax(pred_target1), real_samples=F.softmax(pred1),
                                                      generator=True, new_Hinge=True)
                loss_adv_target2 = adversarial_loss_2(F.softmax(pred_target2), real_samples=F.softmax(pred2),
                                                      generator=True, new_Hinge=True)
            else:
                # Fool the discriminators: target outputs labelled as source.
                D_out1 = model_D1(F.softmax(pred_target1))
                D_out2 = model_D2(F.softmax(pred_target2))

                loss_adv_target1 = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))

                loss_adv_target2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))

            # 'new_Hinge' adds the adversarial terms to the retained seg loss;
            # every other mode backpropagates the adversarial loss alone (its
            # seg loss was already backpropagated above).
            if args.gan == 'new_Hinge':
                loss += args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            else:
                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # Detach generator outputs so the D updates below do not
            # backpropagate into the generator.
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            if args.gan == 'Hinge' or args.gan == 'new_Hinge':
                # NOTE(review): unlike the BCE branch below, these D losses
                # are not divided by iter_size (or by 2) -- confirm intended.
                loss_D1 = adversarial_loss_1(F.softmax(pred_target1), real_samples=F.softmax(pred1), generator=False)
                loss_D2 = adversarial_loss_2(F.softmax(pred_target2), real_samples=F.softmax(pred2), generator=False)

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

            else:
                # train with source
                D_out1 = model_D1(F.softmax(pred1))
                D_out2 = model_D2(F.softmax(pred2))

                loss_D1 = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))

                loss_D2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))

                # Halved because D sees two batches (source + target) per
                # sub-iteration.
                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target
                D_out1 = model_D1(F.softmax(pred_target1))
                D_out2 = model_D2(F.softmax(pred_target2))

                loss_D1 = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(target_label).to(device))

                loss_D2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(target_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }

            # Scalars are logged every 10th iteration only.
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(os.path.join(args.snapshot_dir, args.dir_name)))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2))

        # Final snapshot, then stop early at num_steps_stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, args.dir_name, str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, args.dir_name, str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, args.dir_name, str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot of generator and both discriminators.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, args.dir_name, str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, args.dir_name, str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, args.dir_name, str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
# ----- Example 2 -----
def main():
    """Create the model and start the evaluation process.

    Restores a trained segmentation network (DeeplabMulti or DeeplabVGG),
    runs inference over the Cityscapes split selected by ``args``, and
    writes two files per image into ``args.save``: the raw predicted
    label map and a colorized version (``<name>_color.png``).
    """

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Build the architecture matching the checkpoint to be restored.
    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail fast instead of hitting a NameError on `model` below.
        raise ValueError('Unsupported model: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # align_corners is only accepted from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            # Bug fix: this was a Python 2 print statement ("print '%d ...'"),
            # a SyntaxError under Python 3 which the rest of this file targets.
            print('%d processed' % index)
        image, _, name = batch
        # Inference only -- no autograd graph needed.  Replaces the removed
        # `Variable(..., volatile=True)` idiom.
        with torch.no_grad():
            if args.model == 'DeeplabMulti':
                # DeeplabMulti returns (aux, main); evaluate the main branch.
                output1, output2 = model(image.cuda(gpu0))
                output = interp(output2).cpu().data[0].numpy()
            elif args.model == 'DeeplabVGG':
                output = model(image.cuda(gpu0))
                output = interp(output).cpu().data[0].numpy()

        # (C, H, W) -> (H, W, C), then per-pixel argmax over the class axis.
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        name = name[0].split('/')[-1]
        output.save('%s/%s' % (args.save, name))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
# ----- Example 3 -----
def main() -> None:
    """Create the model and start the training.

    GTA5 -> Cityscapes adversarial domain adaptation: a DeepLab model
    with two output branches is trained on labelled GTA5 images and
    unlabelled Cityscapes images, with one FCDiscriminator per branch
    driven by a BCE-with-logits objective.  Every 1000 iterations the
    model is evaluated on the Cityscapes validation split and the mIoU
    logged.  All configuration is read from the module-level ``args``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes arrive as "W,H" command-line strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        #model.load_state_dict(saved_state_dict)

        # Remap checkpoint keys by dropping their first dotted component
        # (e.g. 'Scale.layer5.conv2d_list.3.weight' -> 'layer5....').
        # NOTE(review): a weight is copied unless num_classes == 19 AND it
        # belongs to 'layer5' (the classifier) -- i.e. the classifier is
        # skipped only for 19 classes.  Verify this condition is intended
        # (skipping when class counts differ is the more common pattern).
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D: one fully-convolutional discriminator per output branch.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #     model_D1.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D1.pth'))
    #     model_D2.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D2.pth'))

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Both loaders size their dataset via max_iters so the iterators last
    # the whole run (num_steps * iter_size batches) without StopIteration.
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # Load VGG
    #vgg19 = torchvision.models.vgg19(pretrained=True)
    #vgg19.to(device)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Pixels labelled 255 are excluded from the segmentation loss.
    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample expects (H, W); input_size holds (W, H), hence the swap.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(0, args.num_steps):

        loss_seg_value = 0
        #         loss_seg_local_value = 0
        loss_adv_target_value = 0
        #         loss_adv_local_value = 0
        loss_D_value = 0
        #         loss_D_local_value = 0
        loss_local_match_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients accumulate over iter_size sub-iterations before each
        # optimizer step; losses below are divided by iter_size to match.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            # Model returns (branch1, branch2, feat1, feat2); the features
            # are only used by the commented-out local-patch experiment.
            pred_s1, pred_s2, _, _ = model(images)
            #f_s2 = normalize(f_s2)
            pred_s1, pred_s2 = interp(pred_s1), interp(pred_s2)

            # Branch 1 (auxiliary) is down-weighted by lambda_seg; the seg
            # loss is backpropagated later as part of the combined `loss`.
            loss_seg = args.lambda_seg * seg_loss(pred_s1, labels) + seg_loss(
                pred_s2, labels)
            del labels

            # proper normalization
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_t1, pred_t2, _, _ = model(images)
            #f_t2 = normalize(f_t2)
            pred_t1, pred_t2 = interp_target(pred_t1), interp_target(pred_t2)
            del images

            # Fool the discriminators: target outputs labelled as source.
            D_out_1 = model_D1(F.softmax(pred_t1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_t2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out_1,
                torch.FloatTensor(
                    D_out_1.data.size()).fill_(source_label).to(device))
            # NOTE(review): branch 2 is weighted by args.lambda_adv_target
            # (no trailing digit) while branch 1 uses lambda_adv_target1 --
            # confirm this is not a typo for lambda_adv_target2.
            loss_adv_target2 = bce_loss(
                D_out_2,
                torch.FloatTensor(
                    D_out_2.data.size()).fill_(source_label).to(device))
            loss_adv_target_value += (
                args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target *
                loss_adv_target2).item() / args.iter_size

            # Single combined G objective: segmentation + adversarial terms.
            loss = loss_seg + args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target * loss_adv_target2

            del D_out_1, D_out_2

            #             #< Local patch part>#
            #             corres_id2 = get_correspondance(f_s2, f_t2, pred_s2, pred_t2)

            #             #loss_local1 = local_feature_loss(corres_id1, f_s1, f_t1, model, seg_loss)
            #             loss_local2 = local_feature_loss(corres_id2, labels, f_t2, model, seg_loss)
            #             loss_local = args.lambda_match_target2 * loss_local2 #+args.lambda_match_target1 * loss_local1
            #             loss += loss_local
            #             if corres_id2.nelement() > 0:
            #                 loss_local_match_value += loss_local.item()/ args.iter_size

            loss /= args.iter_size
            loss.backward()

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            # Detach so D updates do not backpropagate into the generator.
            # NOTE(review): F.softmax is called without dim= here (unlike
            # dim=1 above), relying on deprecated implicit-dim behaviour --
            # confirm the torch version resolves this to dim=1.
            pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_s1)), model_D2(
                F.softmax(pred_s2))
            # Halved because D sees two batches (source + target) per step.
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

            # train with target
            pred_t1, pred_t2 = pred_t1.detach(), pred_t2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_t1)), model_D2(
                F.softmax(pred_t2))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        # Periodic evaluation.  mIoU is assigned only on these iterations,
        # but the loop starts at i_iter == 0 (0 % 1000 == 0), so the later
        # references to mIoU always see the most recent evaluation.
        if i_iter % 1000 == 0:
            val_dir = '../dataset/Cityscapes/leftImg8bit_trainvaltest/'
            val_list = './dataset/cityscapes_list/val.txt'
            save_dir = './results/tmp'
            gt_dir = '../dataset/Cityscapes/gtFine_trainvaltest/gtFine/val'
            evaluate_cityscapes.test_model(model, device, val_dir, val_list,
                                           save_dir)
            mIoU = compute_iou.mIoUforTest(gt_dir, save_dir)

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                #'loss_seg_local': loss_seg_local_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_local_match': loss_local_match_value,
                'loss_D': loss_D_value,
                'mIoU': mIoU
                #'loss_D_local': loss_D_local_value
            }

            # Scalars are logged every 10th iteration only.
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f}, loss_local_match = {5:.3f}, mIoU = {6:3f} '
        #'loss_seg_local = {5:.3f} loss_adv_local = {6:.3f}, loss_D_local = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_target_value, loss_D_value, loss_local_match_value, mIoU)
                    #loss_seg_local_value, loss_adv_local_value, loss_D_local_value)
        )

        # Final snapshot, then stop early at num_steps_stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot of generator and both discriminators.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
# ----- Example 4 -----
def main():
    """Create the model and start the training.

    Adversarial domain adaptation from GTA5 (source) to Cityscapes
    (target): a two-head DeepLab segmentation network is trained with a
    supervised loss on source batches while two FC discriminators (one
    per head) provide an adversarial loss that pushes target predictions
    to look like source predictions.  All configuration is read from the
    module-level ``args``.
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create the segmentation network and restore pre-trained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy checkpoint weights, skipping the classifier head ('layer5')
        # when training with the standard 19 Cityscapes classes.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # Init discriminators, one per segmentation head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # 'Vanilla' GAN -> sigmoid cross-entropy; 'LS' GAN -> least-squares loss.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # Domain labels for adversarial training.
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # Don't accumulate grads in D while updating the generator.
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # Train with source (supervised segmentation loss).
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization over gradient-accumulation steps
            loss = loss / args.iter_size
            loss.backward()
            # BUGFIX: the losses are 0-dim tensors, so .data.cpu().numpy()[0]
            # raises IndexError on PyTorch >= 0.4 — use .item() instead
            # (matching the other trainers in this file).
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # Train with target (generator tries to fool the discriminators).
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 makes the previously implicit class-axis softmax explicit.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Generator wants target outputs classified as "source".
            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D ----

            # Bring back requires_grad for the discriminator update.
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # Train with source; detach so no grads flow into the generator.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            # The extra /2 averages the two D passes (source + target batch).
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # Train with target.
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        # Final checkpoint, then stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot of model and both discriminators.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
    def __init__(self, args):
        """Build generator G, discriminators D1/D2, optimizers and losses.

        All configuration comes from ``args``; weights for G are restored
        from ``args.restore_from`` (URL or local checkpoint).  Optionally
        converts G to apex synced BN and wraps everything in apex amp
        mixed precision when ``args.fp16`` is set.
        """
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Per-class weight vectors, initialised to all-ones on the GPU.
        self.class_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label
        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes,
                                  use_se=args.use_se,
                                  train_bn=args.train_bn,
                                  norm_style=args.norm_style,
                                  droprate=args.droprate)
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)

            # Remap checkpoint keys onto the model's own state dict.
            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                # print i_parts
                if args.restore_from[:4] == 'http':
                    # Downloaded checkpoint: skip the 'fc'/'layer5' heads and
                    # strip the leading scope component from each key.
                    if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[1:])
                else:
                    #new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                    # Local checkpoint: strip a 'module.' prefix left behind
                    # by DataParallel, otherwise copy the key unchanged.
                    if i_parts[0] == 'module':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[1:])
                    else:
                        new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[0:])
        # NOTE(review): new_params is only bound inside the 'DeepLab' branch
        # above, so this line raises NameError for any other args.model —
        # confirm whether other model types are ever used.
        self.G.load_state_dict(new_params)

        # Multi-scale image discriminators operating on the class maps.
        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate,
                                 momentum=args.momentum,
                                 nesterov=True,
                                 weight_decay=args.weight_decay)

        self.dis1_opt = optim.Adam(self.D1.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        self.dis2_opt = optim.Adam(self.D2.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        # 255 is the conventional 'ignore' label in segmentation datasets.
        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        self.kl_loss = nn.KLDivLoss(size_average=False)
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        # Upsample predictions back to the training crop resolution.
        self.interp = nn.Upsample(size=args.crop_size,
                                  mode='bilinear',
                                  align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size,
                                         mode='bilinear',
                                         align_corners=True)
        # Loss-weighting hyper-parameters copied from args.
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G,
                                                  self.gen_opt,
                                                  opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1,
                                                    self.dis1_opt,
                                                    opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2,
                                                    self.dis2_opt,
                                                    opt_level="O1")
def main():
    """Create the model and start the training.

    IRM-style variant: the two-head DeepLab model is trained with a
    segmentation loss on both source (GTA5) and target (Cityscapes)
    batches plus an IRM gradient penalty computed through a dummy scalar
    multiplier on the model output.  There are no discriminators in this
    variant.  All configuration is read from the module-level ``args``.
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create the segmentation network and restore pre-trained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy checkpoint weights, skipping the classifier head ('layer5')
        # when training with the standard 19 Cityscapes classes.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    def _forward_single(_iter):
        """Run one batch through the model and backprop seg loss + IRM penalty.

        Returns the two per-head segmentation losses for logging.
        """
        # Dummy multiplicative scale: the IRM penalty is the squared
        # gradient of the loss with respect to this scalar.
        _scalar = torch.tensor(1.).cuda().requires_grad_()

        _, batch = _iter.__next__()
        # NOTE(review): the source loader yields 4-tuples, but the target
        # loader in the sibling trainers of this file yields 3-tuples —
        # confirm the dataset used here really returns labels for target
        # batches before calling this with targetloader_iter.
        images, labels, _, _ = batch
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model.forward_irm(images, _scalar)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization over gradient-accumulation steps
        loss = loss / args.iter_size

        # IRM penalty: squared gradient w.r.t. the dummy scale
        # (create_graph=True so the penalty itself is differentiable).
        grad = autograd.grad(loss, [_scalar], create_graph=True)[0]
        penalty = torch.sum(grad**2)

        loss += args.lambda_irm * penalty

        loss.backward()

        return loss_seg1, loss_seg2

    for i_iter in range(args.num_steps):

        loss_seg1_src = 0
        loss_seg2_src = 0

        loss_seg1_tgt = 0
        loss_seg2_tgt = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):

            # train with source
            loss_seg1, loss_seg2 = _forward_single(trainloader_iter)
            # BUGFIX: the losses are 0-dim tensors, so .data.cpu().numpy()[0]
            # raises IndexError on PyTorch >= 0.4 — use .item() instead.
            loss_seg1_src += loss_seg1.item() / args.iter_size
            loss_seg2_src += loss_seg2.item() / args.iter_size

            # train with target
            loss_seg1, loss_seg2 = _forward_single(targetloader_iter)
            loss_seg1_tgt += loss_seg1.item() / args.iter_size
            loss_seg2_tgt += loss_seg2.item() / args.iter_size

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        # BUGFIX: the original print referenced loss_seg_value1,
        # loss_adv_target_value1, loss_D_value1 etc., none of which are
        # defined in this trainer (NameError on the first iteration);
        # report the accumulators this loop actually tracks.
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1_src = {2:.3f} loss_seg2_src = {3:.3f} loss_seg1_tgt = {4:.3f} loss_seg2_tgt = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg1_src, loss_seg2_src,
                    loss_seg1_tgt, loss_seg2_tgt))

        # Final checkpoint, then stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            # BUGFIX: the original also saved model_D1/model_D2 here, which
            # are never created in this discriminator-free trainer
            # (NameError at save time); only the segmentation model exists.
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            break

        # Periodic snapshot of the segmentation model.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
# ---- 示例#7 (Example #7, score: 0) — snippet separator from the original scrape ----
def main():
    """Create the model and start the training."""
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        # pdb.set_trace()
        args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    model_D = FCDiscriminatorTest(num_classes=2 * args.num_classes).to(device)

    #### restore model_D and model
    if RESTART:
        # pdb.set_trace()
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D parameters
        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D)
        model_D.load_state_dict(saved_state_dict)

    #### model_D1, D2 are randomly initialized, model is pre-trained ResNet on ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    """
    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()
    """

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        # pdb.set_trace()
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        adv_loss_value = 0
        d_loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate(optimizer_D, i_iter)
        """
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)
        """

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False
            """
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False
            """

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # pdb.set_trace()
            # images.size() == [1, 3, 720, 1280]
            pred1, pred2 = model(images)
            # pred1, pred2 size == [1, 19, 91, 161]
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            # size (1, 19, 720, 1280)
            # pdb.set_trace()

            # feature = nn.Softmax(dim=1)(pred1)
            # softmax_out = nn.Softmax(dim=1)(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            # pdb.set_trace()
            # proper normalization
            loss = loss / args.iter_size
            # TODO: uncomment
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size
            # pdb.set_trace()
            # train with target

            _, batch = targetloader_iter.__next__()

            for params in model_D.parameters():
                params.requires_grad_(requires_grad=False)

            images, _, _ = batch
            images = images.to(device)
            # pdb.set_trace()
            # images.size() == [1, 3, 720, 1280]
            pred_target1, pred_target2 = model(images)

            # pred_target1, 2 == [1, 19, 91, 161]
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            # pred_target1, 2 == [1, 19, 720, 1280]
            # pdb.set_trace()

            # feature_target = nn.Softmax(dim=1)(pred_target1)
            # softmax_out_target = nn.Softmax(dim=1)(pred_target2)

            # features = torch.cat((pred1, pred_target1), dim=0)
            # outputs = torch.cat((pred2, pred_target2), dim=0)
            # features.size() == [2, 19, 720, 1280]
            # softmax_out.size() == [2, 19, 720, 1280]
            # pdb.set_trace()
            # transfer_loss = CDAN([features, softmax_out], model_D, None, None, random_layer=None)
            D_out_target = CDAN(
                [F.softmax(pred_target1),
                 F.softmax(pred_target2)],
                model_D,
                random_layer=None)
            dc_target = torch.FloatTensor(D_out_target.size()).fill_(0).cuda()
            # pdb.set_trace()
            adv_loss = args.lambda_adv * nn.BCEWithLogitsLoss()(D_out_target,
                                                                dc_target)
            adv_loss = adv_loss / args.iter_size
            # pdb.set_trace()
            # classifier_loss = nn.BCEWithLogitsLoss()(pred2,
            #        torch.FloatTensor(pred2.data.size()).fill_(source_label).cuda())
            # pdb.set_trace()
            adv_loss.backward()
            adv_loss_value += adv_loss.item()
            # optimizer_D.step()
            #TODO: normalize loss?

            for params in model_D.parameters():
                params.requires_grad_(requires_grad=True)

            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out = CDAN([F.softmax(pred1), F.softmax(pred2)],
                         model_D,
                         random_layer=None)

            dc_source = torch.FloatTensor(D_out.size()).fill_(1).cuda()
            # d_loss = CDAN(D_out, dc_source, None, None, random_layer=None)
            d_loss = nn.BCEWithLogitsLoss()(D_out, dc_source)
            d_loss = d_loss / args.iter_size
            # pdb.set_trace()
            d_loss.backward()
            d_loss_value += d_loss.item()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out_target = CDAN(
                [F.softmax(pred_target1),
                 F.softmax(pred_target2)],
                model_D,
                random_layer=None)

            dc_target = torch.FloatTensor(D_out_target.size()).fill_(0).cuda()
            d_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_target)
            d_loss = d_loss / args.iter_size
            # pdb.set_trace()
            d_loss.backward()
            d_loss_value += d_loss.item()

            continue

        optimizer.step()
        optimizer_D.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'generator_loss': adv_loss_value,
            'discriminator_loss': d_loss_value,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)
        # pdb.set_trace()
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    adv_loss_value, d_loss_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

            check_original_discriminator(args, pred_target1, pred_target2,
                                         i_iter)

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

            ###### also record latest saved iteration #######

    writer.close()
Example #8
0
def _evaluate_dataset(model, loader, interp, device, args, name_classes, tag):
    """Run *model* over *loader*, accumulate a confusion matrix, print IoUs.

    Args:
        model: segmentation network returning ``(aux_pred, main_pred)``.
        loader: DataLoader yielding ``(images, labels, _)`` triples.
        interp: upsampling module mapping logits to label resolution.
        device: torch device the model lives on.
        args: parsed CLI namespace (uses ``num_classes``, ``mIoUs_per_class``).
        name_classes: array of human-readable class names.
        tag: dataset name used in the summary line (e.g. ``'GTA5'``).
    """
    hist = np.zeros((args.num_classes, args.num_classes))
    # Evaluation only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        for data in loader:
            images_val, labels, _ = data
            images_val, labels = images_val.to(device), labels.to(device)
            # Second output is the main classifier head.
            _, pred = model(images_val)
            pred = interp(pred)
            _, pred = pred.max(dim=1)

            labels = labels.cpu().numpy()
            pred = pred.cpu().detach().numpy()

            hist += fast_hist(labels.flatten(), pred.flatten(),
                              args.num_classes)
    mIoUs = per_class_iu(hist)
    if args.mIoUs_per_class:
        for ind_class in range(args.num_classes):
            print('==>' + name_classes[ind_class] + ':\t' +
                  str(round(mIoUs[ind_class] * 100, 2)))
    print('===> mIoU (' + tag + '): ' +
          str(round(np.nanmean(mIoUs) * 100, 2)))
    print('=' * 50)


def main():
    """Evaluate every saved snapshot on the datasets selected via CLI flags.

    For each checkpoint under ``./snapshots/<dir_name>/`` (saved every
    ``save_pred_every`` steps up to ``num_steps_stop``), loads the weights
    and reports per-class IoU / mIoU on each dataset enabled by the
    ``--gta5`` / ``--synthia`` / ``--cityscapes`` / ``--idd`` flags.
    """
    args = get_arguments()

    # Seed every RNG so evaluation is reproducible.
    seed = args.random_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    input_size = (1024, 512)
    # input_size is (W, H) while nn.Upsample expects (H, W) -- hence the swap.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    if args.num_classes == 13:
        name_classes = np.asarray([
            "road", "sidewalk", "building", "light", "sign", "vegetation",
            "sky", "person", "rider", "car", "bus", "motorcycle", "bicycle"
        ])
    elif args.num_classes == 18:
        name_classes = np.asarray([
            "road", "sidewalk", "building", "wall", "fence", "pole", "light",
            "sign", "vegetation", "sky", "person", "rider", "car", "truck",
            "bus", "train", "motorcycle", "bicycle"
        ])
    else:
        # Fix: the exception was previously instantiated but never raised,
        # which let execution continue and fail later with an unrelated
        # NameError on name_classes.
        raise NotImplementedError("Unavailable number of classes")

    # Device is loop-invariant; hoisted out of the checkpoint loop.
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    # Create the model and start the evaluation process
    model = DeeplabMulti(num_classes=args.num_classes)
    for files in range(int(args.num_steps_stop / args.save_pred_every)):
        step = (files + 1) * args.save_pred_every
        print('Step: ', step)
        saved_state_dict = torch.load('./snapshots/' + args.dir_name + '/' +
                                      str(step) + '.pth')
        model.load_state_dict(saved_state_dict)
        model = model.to(device)
        model.eval()

        if args.gta5:
            gta5_loader = torch.utils.data.DataLoader(
                GTA5DataSet(args.data_dir_gta5,
                            args.data_list_gta5,
                            crop_size=input_size,
                            ignore_label=args.ignore_label,
                            set=args.set,
                            num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)
            _evaluate_dataset(model, gta5_loader, interp, device, args,
                              name_classes, 'GTA5')

        if args.synthia:
            synthia_loader = torch.utils.data.DataLoader(
                SYNTHIADataSet(args.data_dir_synthia,
                               args.data_list_synthia,
                               crop_size=input_size,
                               ignore_label=args.ignore_label,
                               set=args.set,
                               num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)
            _evaluate_dataset(model, synthia_loader, interp, device, args,
                              name_classes, 'SYNTHIA')

        if args.cityscapes:
            cityscapes_loader = torch.utils.data.DataLoader(
                cityscapesDataSet(args.data_dir_cityscapes,
                                  args.data_list_cityscapes,
                                  crop_size=input_size,
                                  ignore_label=args.ignore_label,
                                  set=args.set,
                                  num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)
            _evaluate_dataset(model, cityscapes_loader, interp, device, args,
                              name_classes, 'CityScapes')

        if args.idd:
            idd_loader = torch.utils.data.DataLoader(
                IDDDataSet(args.data_dir_idd,
                           args.data_list_idd,
                           crop_size=input_size,
                           ignore_label=args.ignore_label,
                           set=args.set,
                           num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)
            _evaluate_dataset(model, idd_loader, interp, device, args,
                              name_classes, 'IDD')
Example #9
0
def main():
    """Create the model and start the evaluation process.

    Loads the requested segmentation model, restores a checkpoint (URL or
    local path), runs it over the Cityscapes validation loader on CPU, and
    writes per-image label maps plus colorized masks to ``args.save``.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Build the requested architecture; Oracle/VGG variants fall back to
    # their dedicated default checkpoints when the generic one was given.
    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fix: an unknown model name previously fell through every branch
        # and crashed later with NameError; fail fast with a clear message.
        raise ValueError('Unknown model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location='cpu')
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')
    ### for running different versions of pytorch
    # Keep only checkpoint keys present in this build of the model so
    # snapshots from slightly different PyTorch versions still load.
    model_dict = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in model_dict
    }
    model_dict.update(saved_state_dict)
    ###
    model.load_state_dict(saved_state_dict)

    device = torch.device("cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % index)
        image, _, name = batch
        image = image.to(device)

        if args.model == 'DeeplabMulti':
            # Multi-head model: the second output is the main prediction.
            output1, output2 = model(image)
            output = interp(output2).cpu().data[0].numpy()
        elif args.model in ('DeeplabVGG', 'Oracle'):
            output = model(image)
            output = interp(output).cpu().data[0].numpy()

        # (C, H, W) logits -> (H, W) argmax label map.
        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        name = name[0].split('/')[-1]
        output.save('%s/%s' % (args.save, name))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
Example #10
0
def main():
    """Create the model and start the evaluation process.

    Sweeps checkpoints saved every 2000 steps (2000..120000), restores the
    generator and its two-level discriminator, and writes predicted label
    maps plus colorized masks for each checkpoint to its own result folder.
    """

    for i in range(1, 61):
        model_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}.pth'.format(i *
                                                                        2000)
        model_D_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}_D.pth'.format(
            i * 2000)
        save_path = './result/GTA2Cityscapes_{0:d}'.format(i * 2000)
        args = get_arguments()

        gpu0 = args.gpu

        if not os.path.exists(save_path):
            os.makedirs(save_path)

        model = DeeplabMulti(num_classes=args.num_classes)
        saved_state_dict = torch.load(model_path)
        model.load_state_dict(saved_state_dict)
        model.eval()
        model.cuda(gpu0)

        # Discriminator pair: feature-space (2048-ch) and output-space (19-ch).
        num_class_list = [2048, 19]
        # Fix: the comprehension variable used to shadow the outer loop
        # variable ``i`` (only safe by accident of Python 3 comprehension
        # scoping); renamed to ``idx`` to remove the hazard.
        model_D = nn.ModuleList([
            FCDiscriminator(num_classes=num_class_list[idx])
            if idx < 1 else OutspaceDiscriminator(num_classes=num_class_list[idx])
            for idx in range(2)
        ])
        model_D.load_state_dict(torch.load(model_D_path))
        model_D.eval()
        model_D.cuda(gpu0)

        testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                       args.data_list,
                                                       crop_size=(1024, 512),
                                                       mean=IMG_MEAN,
                                                       scale=False,
                                                       mirror=False,
                                                       set=args.set),
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)

        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)

        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                feat, pred = model(
                    Variable(image).cuda(gpu0), model_D, 'target')

                # (C, H, W) logits -> (H, W) argmax label map.
                output = interp(pred).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save('%s/%s' % (save_path, name))

                output_col.save('%s/%s_color.png' %
                                (save_path, name.split('.')[0]))

        print(save_path)
Example #11
0
def main():
    """Create the model and start the training.

    Skeleton for multi-level adversarial domain adaptation (GTA5 ->
    Cityscapes). The STEP1-STEP4 placeholders are left as an exercise;
    NOTE(review): ``optimizer``, ``optimizer_D1/D2``, ``model_D1/D2`` and
    the per-iteration losses are expected to be created in those
    placeholders -- confirm before running.
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like Scale.layer5.conv2d_list.3.weight;
            # strip the leading scope, and skip the classifier (layer5)
            # when training with 19 classes so it is re-initialised.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)


    targetloader_iter = enumerate(targetloader)

    """
    initialize the discriminator
    model_D1 = 
    model_D2 = 
    """

    """
    set the optimizer
    please refer to the main paper section 5 (Network Training)
    use CLASS METHOD optim_parameters to get the parameters of segmentation model (i.e., model.optim_parameters)
    """

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        """
        bce_loss = 
        """
        pass

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G
            """STEP1: train the segmentation model with source GTA5 data
            You should freeze the discriminator during training G
            """

            """STEP2: learn target to source alignment with target Cityscape data            
            """

            # train D
            """STEP3: train the discriminator with source
            You should unfreeze the discriminator
            """

            """STEP4: train the discriminator with target
            """

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Fix: converted Python-2 print statements to function calls so
            # the file parses under Python 3 (the rest already uses print()).
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
def train(gpu, args):
    """Create the model and start the training."""

    rank = args.nr * args.num_gpus + gpu
    if gpu == 1:
        gpu = 3

    dist.init_process_group(backend="nccl",
                            world_size=args.world_size,
                            rank=rank)

    if args.batch_size == 1 and args.use_bn is True:
        raise Exception

    torch.autograd.set_detect_anomaly(True)
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed(args.cuda_seed)

    torch.cuda.set_device(gpu)

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = gpu

    criterion = DiceBCELoss()
    # criterion = nn.CrossEntropyLoss(ignore_index=253)
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from is None:
            pass
        elif args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from is not None:
            saved_state_dict = torch.load(args.restore_from)
            model.load_state_dict(saved_state_dict)
            print("Loaded state dicts for model")
        # if args.restore_from is not None:
        #     new_params = model.state_dict().copy()
        #     for i in saved_state_dict:
        #         # Scale.layer5.conv2d_list.3.weight
        #         i_parts = i.split('.')
        #         # print i_parts
        #         if not args.num_classes == 19 or not i_parts[1] == 'layer5':
        #             new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        #             # print i_parts
        #     model.load_state_dict(new_params)

    if not args.no_logging:
        if not os.path.isdir(args.log_dir):
            os.mkdir(args.log_dir)
        log_dir = os.path.join(args.log_dir, args.exp_dir)
        if not os.path.isdir(log_dir):
            os.mkdir(log_dir)
        if args.exp_name == "":
            exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
        else:
            exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
        writer = SummaryWriter(log_dir)

    model.train()
    # model.cuda(gpu)
    model = model.cuda(device=gpu)

    if args.num_gpus > 0 or torch.cuda.device_count() > 0:
        model = DistributedDataParallel(model,
                                        device_ids=[gpu],
                                        find_unused_parameters=True)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    start_epoch = 0
    if "http" not in args.restore_from and args.restore_from is not None:
        root, extension = args.restore_from.strip().split(".")
        D1pth = root + "_D1." + extension
        D2pth = root + "_D2." + extension
        saved_state_dict = torch.load(D1pth)
        model_D1.load_state_dict(saved_state_dict)
        saved_state_dict = torch.load(D2pth)
        model_D2.load_state_dict(saved_state_dict)
        start_epoch = int(re.findall(r'[\d]+', root)[-1])
        print("Loaded state dict for models D1 and D2")

    model_D1.train()
    # model_D1.cuda(gpu)
    model_D2.train()
    # model_D2.cuda(gpu)

    model_D1 = model_D1.cuda(device=gpu)
    model_D2 = model_D2.cuda(device=gpu)

    if args.num_gpus > 0 or torch.cuda.device_count() > 0:
        model_D1 = DistributedDataParallel(model_D1,
                                           device_ids=[gpu],
                                           find_unused_parameters=True)
        model_D2 = DistributedDataParallel(model_D2,
                                           device_ids=[gpu],
                                           find_unused_parameters=True)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = SyntheticSmokeTrain(args={},
                                        dataset_limit=args.num_steps *
                                        args.iter_size * args.batch_size,
                                        image_shape=input_size,
                                        dataset_mean=IMG_MEAN)

    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=args.world_size,
                                       rank=rank,
                                       shuffle=True)
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  sampler=train_sampler)

    # trainloader = data.DataLoader(
    #     GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                 crop_size=input_size,
    #                 scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
    #     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    print("Length of train dataloader: ", len(trainloader))
    target_dataset = SimpleSmokeVal(args={},
                                    image_size=input_size_target,
                                    dataset_mean=IMG_MEAN)
    target_sampler = DistributedSampler(target_dataset,
                                        num_replicas=args.world_size,
                                        rank=rank,
                                        shuffle=True)
    targetloader = data.DataLoader(target_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   sampler=target_sampler)

    # targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
    #                                                  max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                                                  crop_size=input_size_target,
    #                                                  scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
    #                                                  set=args.set),
    #                                batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
    #                                pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    print("Length of train dataloader: ", len(targetloader))
    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(start_epoch, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            # try:
            _, batch = next(trainloader_iter)  #.next()
            # except StopIteration:
            # trainloader = data.DataLoader(
            #     SyntheticSmokeTrain(args={}, dataset_limit=args.num_steps * args.iter_size * args.batch_size,
            #                 image_shape=input_size, dataset_mean=IMG_MEAN),
            #     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
            # trainloader_iter = iter(trainloader)
            # _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(gpu)
            # print("Shape of labels", labels.shape)
            # print("Are labels all zero? ")
            # for i in range(labels.shape[0]):
            #     print("{}: All zero? {}".format(i, torch.all(labels[i]==0)))
            #     print("{}: All 255? {}".format(i, torch.all(labels[i]==255)))
            #     print("{}: Mean = {}".format(i, torch.mean(labels[i])))

            pred1, pred2 = model(images)
            # print("Pred1 and Pred2 original size: {}, {}".format(pred1.shape, pred2.shape))
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            # print("Pred1 and Pred2 upsampled size: {}, {}".format(pred1.shape, pred2.shape))
            # for pred, name in zip([pred1, pred2], ['pred1', 'pred2']):
            #     print(name)
            #     for i in range(pred.shape[0]):
            #         print("{}: All zero? {}".format(i, torch.all(pred[i]==0)))
            #         print("{}: All 255? {}".format(i, torch.all(pred[i]==255)))
            #         print("{}: Mean = {}".format(i, torch.mean(pred[i])))

            loss_seg1 = loss_calc(pred1, labels, gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # print("Seg1 loss: ",loss_seg1, args.iter_size)
            # print("Seg2 loss: ",loss_seg2, args.iter_size)

            loss_seg_value1 += loss_seg1.data.cpu().item() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().item() / args.iter_size

            # train with target
            # try:
            _, batch = next(targetloader_iter)  #.next()
            # except StopIteration:
            #     targetloader = data.DataLoader(
            #         SimpleSmokeVal(args = {}, image_size=input_size_target, dataset_mean=IMG_MEAN),
            #                         batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            #                         pin_memory=True)
            #     targetloader_iter = iter(targetloader)
            #     _, batch = next(targetloader_iter)

            images, _, _ = batch
            images = Variable(images).cuda(gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().item(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().item(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().item()
            loss_D_value2 += loss_D2.data.cpu().item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().item()
            loss_D_value2 += loss_D2.data.cpu().item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))
        writer.add_scalar(f'loss/train/segmentation/1', loss_seg_value1,
                          i_iter)
        writer.add_scalar(f'loss/train/segmentation/2', loss_seg_value2,
                          i_iter)
        writer.add_scalar(f'loss/train/adversarial/1', loss_adv_target_value1,
                          i_iter)
        writer.add_scalar(f'loss/train/adversarial/2', loss_adv_target_value2,
                          i_iter)
        writer.add_scalar(f'loss/train/domain/1', loss_D_value1, i_iter)
        writer.add_scalar(f'loss/train/domain/2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'smoke_cross_entropy_multigpu_' +
                    str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(
                    args.snapshot_dir, 'smoke_cross_entropy_multigpu_' +
                    str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(
                    args.snapshot_dir, 'smoke_cross_entropy_multigpu_' +
                    str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'smoke_cross_entropy_multigpu_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'smoke_cross_entropy_multigpu_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'smoke_cross_entropy_multigpu_' + str(i_iter) + '_D2.pth'))
        writer.flush()
def main():
    """Create the model and start the training.

    Adversarial domain adaptation for segmentation: a two-headed DeepLab
    generator is trained with a supervised loss on labelled source batches
    and an adversarial loss on unlabelled target batches, against two
    FCDiscriminators (one per prediction head).  Periodically validates
    via ValHelper and snapshots all three networks.
    """

    # "W,H" command-line strings -> (width, height) tuples.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network and restore pretrained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()

        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight";
            # drop the leading scope ("Scale.") when copying into the model.
            i_parts = i.split('.')
            # Copy every parameter, skipping only the classifier head
            # ("layer5") when the class count matches the checkpoint (19).
            # BUG FIX: the original used `and`, which skipped the *entire*
            # restore whenever num_classes == 19; `or` matches the
            # reference AdaptSegNet implementation.
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
    writer_dir = os.path.join("./logs/", args.name, start_time)
    writer = tensorboard.SummaryWriter(writer_dir)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # Init the two discriminators, one per prediction level.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source-domain loader (labelled; view 1 of 2).
    trainloader = data.DataLoader(
        MulitviewSegLoader(
            num_classes=args.num_classes,
            root=args.data_dir,
            number_views=2,
            view_idx=1,
            img_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    # data_loader_cycle restarts the loader when exhausted, so the
    # next() calls in the training loop never raise StopIteration.
    trainloader_iter = iter(data_loader_cycle(trainloader))

    # Target-domain loader (labels unused; feeds the adversarial term).
    targetloader = data.DataLoader(
        MulitviewSegLoader(
            root=args.data_dir_target,
            num_classes=args.num_classes,
            number_views=1,
            view_idx=0,
            img_mean=IMG_MEAN,
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    # Upsample head outputs back to input resolution; nn.Upsample takes
    # size as (H, W) while input_size is stored as (W, H).
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    def mdl_val_func(x):
        # Validation scores only the final (second) prediction head.
        return interp_target(model(x)[1])

    targetloader_iter = iter(data_loader_cycle(targetloader))
    val_loader = data.DataLoader(
        MulitviewSegLoader(
            root=args.data_dir_val,
            number_views=1,
            view_idx=0,
            num_classes=args.num_classes,
            img_mean=IMG_MEAN,
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)

    criterion = CrossEntropyLoss2d().cuda(args.gpu)
    valhelper = ValHelper(gpu=args.gpu,
                          model=mdl_val_func,
                          val_loader=val_loader,
                          loss=criterion,
                          writer=writer)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # GAN objective: 'Vanilla' = BCE on logits, 'LS' = least-squares GAN.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients accumulate over iter_size sub-iterations before the
        # single optimizer step below; each loss is divided by iter_size.
        for sub_i in range(args.iter_size):

            # ---------------- train G ----------------

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source (supervised segmentation loss)
            batch = next(trainloader_iter)
            images, labels, *_ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target: push D towards predicting "source" so G
            # learns domain-invariant outputs.
            batch = next(targetloader_iter)
            images, *_ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # Softmax over the class channel.  dim=1 is what the legacy
            # implicit-dim call resolved to for 4-D input; making it
            # explicit silences the deprecation warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # ---------------- train D ----------------

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source; detach so no gradients flow into G.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            # Halved because D sees two batches (source + target) per step.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()
        # Periodic validation (skipped at iteration 0).
        if i_iter % args.val_steps == 0 and i_iter:
            model.eval()
            log = valhelper.valid_epoch(i_iter)
            print('log: {}'.format(log))
            model.train()

        if i_iter % 10 == 0 and i_iter:
            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))
            writer.add_scalar(f'train/loss_seg_value1', loss_seg_value1,
                              i_iter)
            writer.add_scalar(f'train/loss_seg_value2', loss_seg_value2,
                              i_iter)
            writer.add_scalar(f'train/loss_adv_target_value1',
                              loss_adv_target_value1, i_iter)
            # Consistency fix: value2 was computed but never logged.
            writer.add_scalar(f'train/loss_adv_target_value2',
                              loss_adv_target_value2, i_iter)
            writer.add_scalar(f'train/loss_D_value1', loss_D_value1, i_iter)
            writer.add_scalar(f'train/loss_D_value2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
def main():
    """Semi-supervised adversarial segmentation training.

    Trains DeepLab against a single FCDiscriminator.  The labelled set may
    be split into a supervised subset and a "remain" subset; the remain
    subset drives the semi-supervised adversarial / self-training terms.
    Snapshots both networks periodically.
    """

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # Define network
    model = DeeplabMulti(num_classes=args.num_classes)

    # Load pretrained parameters from a URL or a local checkpoint.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Only copy the params that exist in the current model and match in
    # size (caffe-like partial restore).
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # Init discriminator, optionally restoring its weights too.
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = cityscapesDataSet(max_iters=args.num_steps * args.iter_size * args.batch_size,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=args.num_steps * args.iter_size * args.batch_size,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
    else:
        # Sample a labelled subset of the training data.
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # BUG FIX: pickle requires binary mode; the original
            # text-mode open() would fail on load, and the handle leaked.
            with open(args.partial_id, 'rb') as f:
                train_ids = pickle.load(f)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the split so the run is reproducible.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as f:
            pickle.dump(train_ids, f)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size, shuffle=False,
                                      num_workers=args.num_workers, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset, sampler=train_remain_sampler, batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset, sampler=train_gt_sampler, batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): "BCEWithLogitsLoss2dd" (double d) looks like a typo of
    # BCEWithLogitsLoss2d — confirm this class actually exists in the repo.
    bce_loss_1 = BCEWithLogitsLoss2dd()

    # align_corners is only accepted by nn.Upsample from torch 0.4.0 on.
    # (The original also built an unconditional interp first; that dead
    # assignment was removed.)
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # Labels for adversarial training: D should say 1 for ground-truth
    # one-hot maps and 0 for network predictions.
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---------------- train G ----------------

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # Semi-supervised pass first (unlabelled "remain" data), once
            # the adversarial warm-up point is reached.
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = trainloader_remain_iter.__next__()
                except StopIteration:
                    # Restart the exhausted loader (narrowed from a bare
                    # except:, which also swallowed unrelated errors).
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = trainloader_remain_iter.__next__()

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                # NOTE(review): DeeplabMulti returns two heads elsewhere in
                # this file; here the raw model output is upsampled directly
                # — confirm this variant returns a single tensor.
                pred = interp(model(images))
                pred_remain = pred.detach()

                # dim=1 = class channel; explicit dim avoids the implicit-
                # softmax deprecation warning.  F.sigmoid -> torch.sigmoid
                # (F.sigmoid is deprecated).
                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # np.bool was removed in NumPy 1.24; plain bool is the
                # documented replacement.
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)

                l, _, _ = bce_loss_1(D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = args.lambda_semi_adv * -l
                loss_semi_adv = loss_semi_adv / args.iter_size

                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy() / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # Self-training: use D's confidence to mask pseudo-labels.
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255  # 255 = ignore index

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy() / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source (labelled data)
            try:
                _, batch = trainloader_iter.__next__()
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            l, _, _ = bce_loss_1(D_out, make_D_label(gt_label, ignore_mask))
            loss_adv_pred = -l

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy() / args.iter_size

            # ---------------- train D ----------------

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with predictions; detach so no grads flow into G.
            pred = pred.detach()

            if args.D_remain:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            l, p, t = bce_loss_1(D_out, make_D_label(pred_label, ignore_mask))
            gradient_penalty = compute_gradient_penalty(torch.ones_like(pred) * pred_label, pred, model_D)
            loss_D = l + 0.05 * gradient_penalty
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with ground-truth one-hot label maps
            try:
                _, batch = trainloader_gt_iter.__next__()
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = trainloader_gt_iter.__next__()

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            l, p, t = bce_loss_1(D_out, make_D_label(gt_label, ignore_mask_gt))
            gradient_penalty = compute_gradient_penalty(torch.ones_like(D_gt_v) * gt_label, D_gt_v, model_D)
            loss_D = -l + 0.05 * gradient_penalty
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        # NOTE(review): `sub_i / 5 == 0` is true division, so this only
        # holds when the final sub_i is 0 (i.e. iter_size == 1);
        # `sub_i % 5 == 0` was probably intended — confirm before changing.
        if sub_i / 5 == 0:
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '_D.pth'))

    # `start` is expected to be set at module level before main() runs.
    end = timeit.default_timer()
    print(end - start, 'seconds')
# 示例#15 ("Example #15" — separator left over from the scraped source page;
# commented out: the bare tokens were a NameError waiting to happen)
# 0
def main():
    """Create the model and start the evaluation process.

    Restores a segmentation network (DeeplabMulti / Res_Deeplab /
    DeeplabVGG) from ``args.restore_from``, runs it over the GTA5 test
    split, and writes the predicted label maps (raw index image plus a
    colorized version) under ``args.save``.

    Returns:
        str: the directory the predictions were written to.
    """
    args = get_arguments()

    batchsize = args.batchsize

    # Mirror the checkpoint's directory name in the output path so results
    # from different snapshots do not overwrite each other.
    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes,
                             train_bn=False,
                             norm_style='in')
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail fast instead of hitting a NameError on `model` below.
        raise ValueError('Unknown model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # The checkpoint was saved from a DataParallel wrapper, so its keys
        # carry a 'module.' prefix; wrap the model so the keys line up.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda()

    testloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                             args.data_list,
                                             crop_size=(640, 1280),
                                             resize_size=(1280, 640),
                                             mean=IMG_MEAN,
                                             scale=False,
                                             mirror=False),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True)

    # align_corners only exists from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(640, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(640, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, batch in enumerate(testloader):
        if (index * batchsize) % 100 == 0:
            print('%d processd' % (index * batchsize))
        image, _, _, name = batch
        print(image.shape)

        inputs = Variable(image).cuda()
        if args.model == 'DeeplabMulti':
            # Fuse the auxiliary head (weight 0.5) with the main head
            # before upsampling to full resolution.
            output1, output2 = model(inputs)
            output_batch = interp(sm(0.5 * output1 +
                                     output2)).cpu().data.numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            output_batch = model(Variable(image).cuda())
            output_batch = interp(output_batch).cpu().data.numpy()

        # NCHW -> NHWC, then argmax over the class axis to get label maps.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name_tmp = name[i].split('/')[-1]
            output.save('%s/%s' % (args.save, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (args.save, name_tmp.split('.')[0]))

    return args.save
def main():
    """Create the model and start the evaluation process.

    Reads the training configuration (``opts.yaml``) stored next to the
    checkpoint, rebuilds the network exactly as trained, and runs
    multi-scale plus horizontal-flip test-time augmentation over a
    Cityscapes split.  Pixels whose summed softmax score falls below a
    confidence threshold are set to 255 (ignore), which makes the output
    directly usable as pseudo labels.

    Returns:
        str: the directory the predictions were written to.
    """
    args = get_arguments()

    # The training run stores its options next to the checkpoint; reuse
    # them so model type / norm style match the saved weights.
    config_path = os.path.join(os.path.dirname(args.restore_from),'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: the config holds plain scalars/maps only, and
        # yaml.load without an explicit Loader is unsafe and deprecated.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s'%args.model)
    print('NormType:%s'%config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    model_name = os.path.basename( os.path.dirname(args.restore_from) )

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes, use_se = config['use_se'], train_bn = False, norm_style = config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail fast instead of hitting a NameError on `model` below.
        raise ValueError('Unknown model type: %s' % args.model)

    if args.restore_from[:4] == 'http' :
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # The checkpoint was saved from a DataParallel wrapper, so its keys
        # carry a 'module.' prefix; wrap the model so the keys line up.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    # Wrap for multi-GPU inference, but never double-wrap: the except
    # branch above may already have produced a DataParallel model.
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.eval()
    model.cuda(gpu0)

    # Two loaders over the same list at different scales (1.0 and 1.25)
    # feed the multi-scale ensemble below.
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(512, 1024), resize_size=(1024, 512), mean=IMG_MEAN, scale=False, mirror=False, set=args.set),
                                    batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4)

    scale = 1.25
    testloader2 = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(round(512*scale), round(1024*scale) ), resize_size=( round(1024*scale), round(512*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set),
                                    batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4)

    # align_corners only exists from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    sm = torch.nn.Softmax(dim = 1)
    for index, img_data in enumerate(zip(testloader, testloader2) ):
        batch, batch2 = img_data
        image, _, _, name = batch
        image2, _, _, name2 = batch2
        print(image.shape)

        inputs = image.cuda()
        inputs2 = image2.cuda()
        print('\r>>>>Extracting feature...%04d/%04d'%(index*batchsize, NUM_STEPS), end='')
        if args.model == 'DeepLab':
            with torch.no_grad():
                # Ensemble of 4 forward passes: {scale 1.0, scale 1.25}
                # x {original, horizontally flipped}; scores are summed.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5* output1 + output2))
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5* output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                output_batch = output_batch.cpu().data.numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            output_batch = model(Variable(image).cuda())
            output_batch = interp(output_batch).cpu().data.numpy()

        # NCHW -> NHWC; keep the max summed score so low-confidence pixels
        # can be marked as ignore for pseudo-labelling.
        output_batch = output_batch.transpose(0,2,3,1)
        score_batch = np.max(output_batch, axis=3)
        output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8)
        output_batch[score_batch<3.6] = 255  #3.6 = 4*0.9

        for i in range(output_batch.shape[0]):
            output = output_batch[i,:,:]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            if not os.path.isdir(save_path):
                os.mkdir(save_path)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' % (save_path, name_tmp.split('.')[0]))

    return args.save
# ===== Example 17 =====
def main():
    """Create the model and start the evaluation process.

    Reads the training configuration (``opts.yaml``) stored next to the
    checkpoint, rebuilds the network as trained, and runs multi-scale
    (1.0 / 0.8) plus horizontal-flip test-time augmentation over the
    robot dataset.  Predictions are written out in parallel via a small
    process pool.

    Returns:
        str: the directory the predictions were written to.
    """
    args = get_arguments()

    # The training run stores its options next to the checkpoint; reuse
    # them so model type / norm style match the saved weights.
    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: the config holds plain scalars/maps only, and
        # yaml.load without an explicit Loader is unsafe and deprecated.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    # Mirror the checkpoint's directory name in the output path so results
    # from different snapshots do not overwrite each other.
    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail fast instead of hitting a NameError on `model` below.
        raise ValueError('Unknown model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # The checkpoint was saved from a DataParallel wrapper, so its keys
        # carry a 'module.' prefix; wrap the model so the keys line up.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    th = 960
    tw = 1280

    # Three loaders over the same list at scales 1.0, 0.8 and 0.9; only
    # the first two are used in the ensemble below (0.9 pass is disabled).
    testloader = data.DataLoader(robotDataSet(args.data_dir,
                                              args.data_list,
                                              crop_size=(th, tw),
                                              resize_size=(tw, th),
                                              mean=IMG_MEAN,
                                              scale=False,
                                              mirror=False,
                                              set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    scale = 0.8
    testloader2 = data.DataLoader(robotDataSet(args.data_dir,
                                               args.data_list,
                                               crop_size=(round(th * scale),
                                                          round(tw * scale)),
                                               resize_size=(round(tw * scale),
                                                            round(th * scale)),
                                               mean=IMG_MEAN,
                                               scale=False,
                                               mirror=False,
                                               set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)
    scale = 0.9
    testloader3 = data.DataLoader(robotDataSet(args.data_dir,
                                               args.data_list,
                                               crop_size=(round(th * scale),
                                                          round(tw * scale)),
                                               resize_size=(round(tw * scale),
                                                            round(th * scale)),
                                               mean=IMG_MEAN,
                                               scale=False,
                                               mirror=False,
                                               set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    # align_corners only exists from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(960, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(960, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, img_data in enumerate(zip(testloader, testloader2,
                                         testloader3)):
        batch, batch2, batch3 = img_data
        image, _, _, name = batch
        image2, _, _, name2 = batch2
        image3, _, _, name3 = batch3

        inputs = image.cuda()
        inputs2 = image2.cuda()
        inputs3 = Variable(image3).cuda()
        print('\r>>>>Extracting feature...%03d/%03d' %
              (index * batchsize, NUM_STEPS),
              end='')
        if args.model == 'DeepLab':
            with torch.no_grad():
                # Ensemble of 4 forward passes: {scale 1.0, scale 0.8}
                # x {original, horizontally flipped}; scores are summed.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2

                output_batch = output_batch.cpu().data.numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            output_batch = model(Variable(image).cuda())
            output_batch = interp(output_batch).cpu().data.numpy()

        # NCHW -> NHWC, then argmax over the class axis to get label maps.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)
        # Hand the (mask, path) pairs to a worker pool so disk writes
        # overlap with the next batch's GPU work.
        output_iterator = []
        for i in range(output_batch.shape[0]):
            output_iterator.append(output_batch[i, :, :])
            name_tmp = name[i].split('/')[-1]
            name[i] = '%s/%s' % (args.save, name_tmp)
        with Pool(4) as p:
            p.map(save, zip(output_iterator, name))
        del output_batch

    return args.save
def main():
    """Create the model and start the training.

    Dual-GPU adversarial domain adaptation (GTA5 -> Cityscapes): the
    segmentation network lives on GPU 1, the two discriminators (one per
    prediction head) on GPU 0.  Each discriminator sees concatenated
    softmax maps of an image *pair* and is trained with least-squares
    labels: source pair = 1, target pair = -1, mixed pair = 0.
    Snapshots are written to ``args.snapshot_dir``.
    """

    # Segmentation model on GPU 1, discriminators (and losses) on GPU 0.
    gpu_id_2 = 1
    gpu_id_1 = 0

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    # gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http' :
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            print("from restore")
            saved_state_dict = torch.load(args.restore_from)
            #saved_state_dict = torch.load('/home/zhangjunyi/hjy_code/AdaptSegNet/snapshots/GTA2Cityscapes_multi_result1/GTA5_60000.pth')
            #model.load_state_dict(saved_state_dict)

        # Strip the leading module name from each checkpoint key and skip
        # the classifier ('layer5') when the class count is 19.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        # model.load_state_dict(new_params)
        # NOTE(review): the line above is commented out, so the filtered
        # pretrained weights in `new_params` are never actually loaded and
        # the model starts from its default initialization — confirm this
        # is intentional.

    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D: two independent discriminators, one per prediction head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    #model_D2 = model_D1
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    # Prime the "previous batch" caches used to form image pairs below.
    _, batch_last = trainloader_iter.__next__()

    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = targetloader_iter.__next__()

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # NOTE(review): despite the name, this is MSELoss (least-squares GAN
    # objective), not binary cross-entropy.
    bce_loss = torch.nn.MSELoss()

    def upsample_(input_):
        # Upsample to the source crop size (input_size is (w, h)).
        return nn.functional.interpolate(input_, size=(input_size[1], input_size[0]), mode='bilinear', align_corners=False)

    def upsample_target(input_):
        # Upsample to the target crop size.
        return nn.functional.interpolate(input_, size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    # labels for adversarial training (least-squares targets)
    source_label = 1
    target_label = -1
    mix_label = 0

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        # number1/number2 count pixels relabelled as ignore (255) by
        # loss_calc this iteration, for source and target respectively.
        number1 = 0
        number2 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            def result_model(batch, interp_):
                # Forward a batch on GPU 1 and move both upsampled head
                # outputs (and labels) to GPU 0 where the losses live.
                images, labels, _, name = batch
                images = Variable(images).cuda(gpu_id_2)
                labels = Variable(labels.long()).cuda(gpu_id_1)
                pred1, pred2 = model(images)
                pred1 = interp_(pred1)
                pred2 = interp_(pred2)
                pred1_ = pred1.cuda(gpu_id_1)
                pred2_ = pred2.cuda(gpu_id_1)
                return pred1_, pred2_, labels

            beta = 0.95
            # train with source
            _, batch = trainloader_iter.__next__()
            _, batch_target = targetloader_iter.__next__()
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1, beta)
            labels = new_labels
            number1 = torch.sum(labels==255).item()
            loss_seg2, new_labels = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_1 = loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_2 = loss_seg2.data.cpu().numpy() / args.iter_size
            # print(loss_seg_1, loss_seg_2)

            # Segmentation loss on the target batch as well — presumably
            # loss_calc returns pseudo/relabelled labels here; note no
            # beta is passed for the target pass.
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1)
            labels = new_labels
            number2 = torch.sum(labels==255).item()
            # NOTE(review): `new_lables` is a typo and unused (labels are
            # not updated after this call).
            loss_seg2, new_lables = loss_calc(pred2, labels, gpu_id_1)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # Adversarial pass 1: (current target, previous target) pair
            # should fool D into predicting the source label.
            pred1_last_target, pred2_last_target, labels_last_target = result_model(batch_last_target, interp_target)
            pred1_target, pred2_target, labels_target = result_model(batch_target, interp_target)

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            fake1_D = torch.cat((pred1_target_D, pred1_last_target_D), dim=1)
            fake2_D = torch.cat((pred2_target_D, pred2_last_target_D), dim=1)
            D_out_fake_1 = model_D1(fake1_D)
            D_out_fake_2 = model_D2(fake2_D)

            loss_adv_fake1 = bce_loss(D_out_fake_1,
                                       Variable(torch.FloatTensor(D_out_fake_1.data.size()).fill_(source_label)).cuda(
                                            gpu_id_1))

            loss_adv_fake2 = bce_loss(D_out_fake_2,
                                        Variable(torch.FloatTensor(D_out_fake_2.data.size()).fill_(source_label)).cuda(
                                            gpu_id_1))

            loss_adv_target1 = loss_adv_fake1
            loss_adv_target2 = loss_adv_fake2
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()

            # Adversarial pass 2: (target, source) mixed pair, weighted
            # twice as much, should also be classified as source.
            pred1, pred2, labels = result_model(batch, interp)
            pred1_target, pred2_target, labels_target = result_model(batch_target, interp_target)

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            mix1_D = torch.cat((pred1_target_D, pred1_D), dim=1)
            mix2_D = torch.cat((pred2_target_D, pred2_D), dim=1)


            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D2(mix2_D)

            # D_out1 = D_out1.cuda(gpu_id_1)
            # D_out2 = D_out2.cuda(gpu_id_1)


            loss_adv_mix1 = bce_loss(D_out_mix_1,
                                       Variable(torch.FloatTensor(D_out_mix_1.data.size()).fill_(source_label)).cuda(
                                            gpu_id_1))

            loss_adv_mix2 = bce_loss(D_out_mix_2,
                                        Variable(torch.FloatTensor(D_out_mix_2.data.size()).fill_(source_label)).cuda(
                                            gpu_id_1))

            loss_adv_target1 = loss_adv_mix1*2
            loss_adv_target2 = loss_adv_mix2*2

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            pred1_last, pred2_last, labels_last = result_model(batch_last, interp)

            # train with source: (source, previous source) pair -> label 1,
            # (previous source, target) pair -> label 0.  All predictions
            # are detached so only D receives gradients.

            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last = pred1_last.detach().cuda(gpu_id_1)
            pred2_last = pred2_last.detach().cuda(gpu_id_1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_D = F.softmax((pred1_last), dim=1)
            pred2_last_D = F.softmax((pred2_last), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            real1_D = torch.cat((pred1_D, pred1_last_D), dim=1)
            real2_D = torch.cat((pred2_D, pred2_last_D), dim=1)
            mix1_D_ = torch.cat((pred1_last_D, pred1_target_D), dim=1)
            mix2_D_ = torch.cat((pred2_last_D, pred2_target_D), dim=1)

            D_out1_real = model_D1(real1_D)
            D_out2_real = model_D2(real2_D)
            D_out1_mix = model_D1(mix1_D_)
            D_out2_mix = model_D2(mix2_D_)

            # D_out1 = D_out1.cuda(gpu_id_1)
            # D_out2 = D_out2.cuda(gpu_id_1)

            loss_D1 = bce_loss(D_out1_real,
                              Variable(torch.FloatTensor(D_out1_real.data.size()).fill_(source_label)).cuda(gpu_id_1))

            loss_D2 = bce_loss(D_out2_real,
                               Variable(torch.FloatTensor(D_out2_real.data.size()).fill_(source_label)).cuda(gpu_id_1))

            loss_D3 = bce_loss(D_out1_mix,
                              Variable(torch.FloatTensor(D_out1_mix.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D4 = bce_loss(D_out2_mix,
                               Variable(torch.FloatTensor(D_out2_mix.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target: (target, target) pair -> label -1,
            # (source, previous target) pair -> label 0.

            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last_target = pred1_last_target.detach().cuda(gpu_id_1)
            pred2_last_target = pred2_last_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            fake1_D_ = torch.cat((pred1_target_D, pred1_target_D), dim=1)
            fake2_D_ = torch.cat((pred2_target_D, pred2_target_D), dim=1)
            mix1_D__ = torch.cat((pred1_D, pred1_last_target_D), dim=1)
            mix2_D__ = torch.cat((pred2_D, pred2_last_target_D), dim=1)

            # pred_target1 = pred_target1.detach().cuda(gpu_id_1)
            # pred_target2 = pred_target2.detach().cuda(gpu_id_1)

            D_out1 = model_D1(fake1_D_)
            D_out2 = model_D2(fake2_D_)
            D_out3 = model_D1(mix1_D__)
            D_out4 = model_D2(mix2_D__)


            # D_out1 = D_out1.cuda(gpu_id_1)
            # D_out2 = D_out2.cuda(gpu_id_1)

            loss_D1 = bce_loss(D_out1,
                              Variable(torch.FloatTensor(D_out1.data.size()).fill_(target_label)).cuda(gpu_id_1))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(target_label)).cuda(gpu_id_1))

            loss_D3 = bce_loss(D_out3,
                              Variable(torch.FloatTensor(D_out3.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D4 = bce_loss(D_out4,
                               Variable(torch.FloatTensor(D_out4.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D1 = (loss_D1+loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2+loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # Current batches become the "previous" pair for the next
            # sub-iteration's pairing.
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, number1 = {8}, number2 = {9}'.format(
            i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2, number1, number2))

        if i_iter >= args.num_steps_stop - 1:
            print ('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print ('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
# ===== Example 19 =====
def main():
    """Create the model and start the training.

    GTA5 -> Cityscapes adaptation with two discriminators: one on the
    2048-d backbone features and one on the 19-way output space.
    Relies on module-level ``args`` plus the project classes
    ``DeeplabMulti``, ``FCDiscriminator``, ``OutspaceDiscriminator``,
    ``GTA5DataSet`` and ``cityscapesDataSet``.
    """
    setup_seed(666)
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes arrive as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network and restore the pretrained backbone.
    model = DeeplabMulti(num_classes=args.num_classes)
    # Fix: pass map_location so a CUDA-saved checkpoint can be restored on a
    # CPU-only machine (consistent with the other training entry points).
    saved_state_dict = torch.load(args.restore_from, map_location=device)
    new_params = model.state_dict().copy()
    for i in saved_state_dict:
        # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
        # drop the leading scope prefix, and remap layer4.2 -> layer5.0 so
        # the auxiliary head is initialised from the last residual block.
        i_parts = i.split('.')
        if not args.num_classes == 19 or not i_parts[1] == 'layer5':
            if i_parts[1] == 'layer4' and i_parts[2] == '2':
                i_parts[1] = 'layer5'
                i_parts[2] = '0'
            new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
    model.load_state_dict(new_params)
    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D: model_D[0] discriminates backbone features (2048 channels),
    # model_D[1] discriminates softmax prediction maps (19 channels).
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([
        FCDiscriminator(num_classes=num_class_list[i]).train().to(device) if i < 1
        else OutspaceDiscriminator(num_classes=num_class_list[i]).train().to(device)
        for i in range(2)
    ])

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # LS-GAN objective: despite the conventional name, this is an MSE loss.
    bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source: supervised segmentation loss only.
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feat_source, pred_source = model(images, model_D, 'source')
        pred_source = interp(pred_source)

        loss_seg = seg_loss(pred_source, labels)
        loss_seg.backward()

        # train with target: fool both discriminators into predicting
        # "source" on target features/predictions.
        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)

        feat_target, pred_target = model(images, model_D, 'target')
        pred_target = interp_target(pred_target)

        loss_adv = 0
        D_out = model_D[0](feat_target)
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        D_out = model_D[1](F.softmax(pred_target, dim=1))
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        loss_adv = loss_adv * 0.01  # adversarial weight
        loss_adv.backward()

        optimizer.step()

        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source (real = source_label); detach so G gets no grads.
        loss_D_source = 0
        D_out_source = model_D[0](feat_source.detach())
        loss_D_source += bce_loss(D_out_source, torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        D_out_source = model_D[1](F.softmax(pred_source.detach(), dim=1))
        loss_D_source += bce_loss(D_out_source, torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        loss_D_source.backward()

        # train with target (fake = target_label).
        loss_D_target = 0
        D_out_target = model_D[0](feat_target.detach())
        loss_D_target += bce_loss(D_out_target, torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        D_out_target = model_D[1](F.softmax(pred_target.detach(), dim=1))
        loss_D_target += bce_loss(D_out_target, torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        loss_D_target.backward()

        optimizer_D.step()

        if i_iter % 10 == 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D_s = {4:.3f}, loss_D_t = {5:.3f}'.format(
            i_iter, args.num_steps, loss_seg.item(), loss_adv.item(), loss_D_source.item(), loss_D_target.item()))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
def main():
    """Create the model and start the training.

    OpenRooms -> NYUv2 domain-adaptation training with optional
    adversarial discriminators (disabled for the "baseline" /
    "baseline_tar" modes). Relies on module-level ``args`` plus the
    project classes ``DeeplabMulti``, ``FCDiscriminator``,
    ``OpenRoomsSegmentation``, ``NYU`` and ``NYU_Labelled``.
    """
    # OpenRooms-id -> NYU label lookup; ids absent from the table pass
    # through unchanged, and every value is shifted to 0-based below.
    # NOTE(review): or_nyu_map(0) yields 254, not the CrossEntropyLoss
    # ignore_index 255 — confirm label 0 is remapped/ignored upstream.
    or_nyu_dict = {
        0: 255,
        1: 16,
        2: 40,
        3: 39,
        4: 7,
        5: 14,
        6: 39,
        7: 12,
        8: 38,
        9: 40,
        10: 10,
        11: 6,
        12: 40,
        13: 39,
        14: 39,
        15: 40,
        16: 18,
        17: 40,
        18: 4,
        19: 40,
        20: 40,
        21: 5,
        22: 40,
        23: 40,
        24: 30,
        25: 36,
        26: 38,
        27: 40,
        28: 3,
        29: 40,
        30: 40,
        31: 9,
        32: 38,
        33: 40,
        34: 40,
        35: 40,
        36: 34,
        37: 37,
        38: 40,
        39: 40,
        40: 39,
        41: 8,
        42: 3,
        43: 1,
        44: 2,
        45: 22
    }
    or_nyu_map = lambda x: or_nyu_dict.get(x, x) - 1
    or_nyu_map = np.vectorize(or_nyu_map)

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes arrive as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    args.or_nyu_map = or_nyu_map
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from == "":
            # Train from scratch: nothing to restore.
            saved_state_dict = None
        else:
            # map_location keeps CPU-only runs working with CUDA checkpoints.
            saved_state_dict = torch.load(args.restore_from, map_location=device)

        # Fix: the original iterated saved_state_dict unconditionally and
        # crashed with TypeError when restore_from == "" (None case).
        if saved_state_dict is not None:
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
                # strip the leading scope; skip layer5 when it must be
                # re-initialised for the 40-class head.
                i_parts = i.split('.')
                if not args.num_classes == 40 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True
    if args.mode != "baseline" and args.mode != "baseline_tar":
        # init D: two output-space discriminators, one per prediction head.
        model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

        model_D1.train()
        model_D1.to(device)

        model_D2.train()
        model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Augmentation parameters (mean/std are scaled to 0..255 pixel range).
    scale_min = 0.5
    scale_max = 2.0
    rotate_min = -10
    rotate_max = 10
    ignore_label = 255
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    # w/h here are the TARGET input size parsed above.
    args.width = w
    args.height = h
    train_transform = transforms.Compose([
        transforms.RandScale([scale_min, scale_max]),
        transforms.RandRotate([rotate_min, rotate_max],
                              padding=IMG_MEAN_RGB,
                              ignore_label=ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='rand',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        # Mean-subtraction only; std of 1 keeps raw pixel scale.
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])

    val_transform = transforms.Compose([
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    # Source dataset: synthetic OpenRooms, unless training purely on the
    # labelled target subset ("baseline_tar").
    if args.mode != "baseline_tar":
        src_train_dst = OpenRoomsSegmentation(root=args.data_dir,
                                              opt=args,
                                              split='train',
                                              transform=train_transform,
                                              imWidth=args.width,
                                              imHeight=args.height,
                                              remap_labels=args.or_nyu_map)
    else:
        src_train_dst = NYU_Labelled(root=args.data_dir_target,
                                     opt=args,
                                     split='train',
                                     transform=train_transform,
                                     imWidth=args.width,
                                     imHeight=args.height,
                                     phase="TRAIN",
                                     randomize=True)
    tar_train_dst = NYU(root=args.data_dir_target,
                        opt=args,
                        split='train',
                        transform=train_transform,
                        imWidth=args.width,
                        imHeight=args.height,
                        phase="TRAIN",
                        randomize=True,
                        mode=args.mode)
    # NOTE(review): constructed but never used in this function; kept in
    # case the dataset constructor has required side effects.
    tar_val_dst = NYU(root=args.data_dir,
                      opt=args,
                      split='val',
                      transform=val_transform,
                      imWidth=args.width,
                      imHeight=args.height,
                      phase="TRAIN",
                      randomize=False)
    trainloader = data.DataLoader(src_train_dst,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(tar_train_dst,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    if args.mode != "baseline" and args.mode != "baseline_tar":
        optimizer_D1 = optim.Adam(model_D1.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D1.zero_grad()

        optimizer_D2 = optim.Adam(model_D2.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # Fix: previously an unknown mode left bce_loss undefined and
        # failed later with a confusing NameError.
        raise ValueError('unsupported gan mode: {}'.format(args.gan))
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # +1 matches the Crop([h+1, w+1]) sizes used in the transforms above.
    interp = nn.Upsample(size=(input_size[1] + 1, input_size[0] + 1),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1] + 1,
                                      input_size_target[0] + 1),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_seg_value1_tar = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_seg_value2_tar = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            adjust_learning_rate_D(optimizer_D1, i_iter)
            adjust_learning_rate_D(optimizer_D2, i_iter)
        # Samples kept from the last sub-iteration for tensorboard images.
        sample_src = None
        sample_tar = None
        sample_gt_src = None
        sample_gt_tar = None
        for sub_i in range(args.iter_size):

            # train G
            if args.mode != "baseline" and args.mode != "baseline_tar":
                # don't accumulate grads in D
                for param in model_D1.parameters():
                    param.requires_grad = False

                for param in model_D2.parameters():
                    param.requires_grad = False

            # train with source; restart the loader when it is exhausted.
            # Fix: bare "except:" also swallowed KeyboardInterrupt and
            # real errors — only StopIteration means "exhausted".
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)
            images, labels, _ = batch
            sample_src = images.clone()
            sample_gt_src = labels.clone()

            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            sample_pred_src = pred2.detach().cpu()

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            try:
                _, batch = next(targetloader_iter)
            except StopIteration:
                targetloader_iter = enumerate(targetloader)
                _, batch = next(targetloader_iter)
            images, tar_labels, _, labelled = batch
            n_labelled = labelled.sum().detach().item()
            sample_tar = images.clone()
            sample_gt_tar = tar_labels.clone()
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            # Semi-supervised DA: add a supervised loss on whatever target
            # samples in the batch carry labels.
            if args.mode == "sda" and n_labelled != 0:
                labelled = labelled.to(device) == 1
                tar_labels = tar_labels.to(device)
                loss_seg1_tar = seg_loss(pred_target1[labelled],
                                         tar_labels[labelled])
                loss_seg2_tar = seg_loss(pred_target2[labelled],
                                         tar_labels[labelled])
                loss_tar_labelled = loss_seg2_tar + args.lambda_seg * loss_seg1_tar
                loss_tar_labelled = loss_tar_labelled / args.iter_size
                loss_seg_value1_tar += loss_seg1_tar.item() / args.iter_size
                loss_seg_value2_tar += loss_seg2_tar.item() / args.iter_size
            else:
                loss_tar_labelled = torch.zeros(
                    1, requires_grad=True).float().to(device)
            # proper normalization
            sample_pred_tar = pred_target2.detach().cpu()
            if args.mode != "baseline" and args.mode != "baseline_tar":
                # dim=1 made explicit (class dimension); same value the
                # implicit softmax rule picked for 4-D input.
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_adv_target1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_adv_target2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size + loss_tar_labelled
                loss.backward()
                loss_adv_target_value1 += loss_adv_target1.item(
                ) / args.iter_size
                loss_adv_target_value2 += loss_adv_target2.item(
                ) / args.iter_size
                # train D

                # bring back requires_grad
                for param in model_D1.parameters():
                    param.requires_grad = True

                for param in model_D2.parameters():
                    param.requires_grad = True

                # train with source (real = source_label)
                pred1 = pred1.detach()
                pred2 = pred2.detach()

                D_out1 = model_D1(F.softmax(pred1, dim=1))
                D_out2 = model_D2(F.softmax(pred2, dim=1))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target (fake = target_label)
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()

                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.step()
            optimizer_D2.step()
        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
                'loss_seg1_tar': loss_seg_value1_tar,
                'loss_seg2_tar': loss_seg_value2_tar,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)
            if i_iter % 1000 == 0:
                # Images were stored BGR and mean-subtracted; undo both for
                # visualisation and clamp into uint8.
                img = sample_src.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                img = img.type(torch.uint8)
                writer.add_images("Src/Images", img, i_iter)
                label = tar_train_dst.decode_target(sample_gt_src).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Labels", label, i_iter)
                preds = sample_pred_src.permute(0, 2, 3, 1).cpu().numpy()
                preds = np.asarray(np.argmax(preds, axis=3), dtype=np.uint8)
                preds = tar_train_dst.decode_target(preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Preds", preds, i_iter)

                tar_img = sample_tar.cpu()[:,
                                           [2, 1, 0], :, :] + torch.from_numpy(
                                               np.array(IMG_MEAN_RGB).reshape(
                                                   1, 3, 1, 1)).float()
                tar_img = tar_img.type(torch.uint8)
                writer.add_images("Tar/Images", tar_img, i_iter)
                tar_label = tar_train_dst.decode_target(
                    sample_gt_tar).transpose(0, 3, 1, 2)
                writer.add_images("Tar/Labels", tar_label, i_iter)
                tar_preds = sample_pred_tar.permute(0, 2, 3, 1).cpu().numpy()
                tar_preds = np.asarray(np.argmax(tar_preds, axis=3),
                                       dtype=np.uint8)
                tar_preds = tar_train_dst.decode_target(tar_preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Tar/Preds", tar_preds, i_iter)
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_seg1_tar={8:.3f} loss_seg2_tar={9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_seg_value1_tar,
                    loss_seg_value2_tar))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'OR_' + str(args.num_steps_stop) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'OR_' + str(i_iter) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
def main():
    seed = 1338
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    """Create the model and start the training."""

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    if args.warper == True:
        WarpModel = Warper()
        WarpModel.train()
        WarpModel.to(device)

    model.train()
    model.to(device)
    if args.multi_gpu:
        model = nn.DataParallel(model)

    cudnn.benchmark = True

    if SOURCE_ONLY:
        if not os.path.exists(osp.join(args.snapshot_dir, 'source_only')):
            os.makedirs(osp.join(args.snapshot_dir, 'source_only'))
        # max_iters = None
        max_iters = args.num_steps * args.iter_size * args.batch_size
        trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                                  args.data_list,
                                                  max_iters=max_iters,
                                                  crop_size=input_size,
                                                  scale=args.random_scale,
                                                  mirror=args.random_mirror,
                                                  mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        trainloader_iter = enumerate(trainloader)

        # implement model.optim_parameters(args) to handle different models' lr setting

        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        optimizer.zero_grad()

        seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)

        # set up tensor board
        if args.tensorboard:
            if not os.path.exists(args.log_dir):
                os.makedirs(args.log_dir)

            writer = SummaryWriter(args.log_dir)

        for i_iter in range(args.num_steps):

            loss_seg_value2 = 0

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)

            for sub_i in range(args.iter_size):

                # train G
                # train with source

                _, batch = trainloader_iter.__next__()

                images, labels, _, _ = batch
                images = images.to(device)
                labels = labels.long().to(device)

                warper = None
                if args.warper:
                    warper, warp_list = WarpModel(images)
                _, pred2 = model(images, input_size, warper)

                loss_seg2 = seg_loss(pred2, labels)
                loss = loss_seg2

                # proper normalization
                loss = loss / args.iter_size
                loss.backward()
                loss_seg_value2 += loss_seg2.item() / args.iter_size

            optimizer.step()

            if args.tensorboard:
                scalar_info = {
                    'loss_seg2': loss_seg_value2,
                }

                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print("iter = {0:8d}/{1:8d}, loss_seg2 = {2:.3f}".format(
                i_iter, args.num_steps, loss_seg_value2))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir, 'source_only',
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir, 'source_only',
                             'GTA5_' + str(i_iter) + '.pth'))

        if args.tensorboard:
            writer.close()
    else:
        if args.level == 'single-level':
            # init D
            model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
            model_D2.train()
            model_D2.to(device)

            if not os.path.exists(osp.join(args.snapshot_dir, 'single_level')):
                os.makedirs(osp.join(args.snapshot_dir, 'single_level'))

            trainloader = data.DataLoader(GTA5DataSet(
                args.data_dir,
                args.data_list,
                max_iters=args.num_steps * args.iter_size * args.batch_size,
                crop_size=input_size,
                scale=args.random_scale,
                mirror=args.random_mirror,
                mean=IMG_MEAN),
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          pin_memory=True)

            trainloader_iter = enumerate(trainloader)

            targetloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target,
                max_iters=args.num_steps * args.iter_size * args.batch_size,
                crop_size=input_size_target,
                scale=False,
                mirror=args.random_mirror,
                mean=IMG_MEAN,
                set=args.set),
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           pin_memory=True)

            targetloader_iter = enumerate(targetloader)

            # implement model.optim_parameters(args) to handle different models' lr setting

            optimizer = optim.SGD(model.optim_parameters(args),
                                  lr=args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
            optimizer.zero_grad()

            optimizer_D2 = optim.Adam(model_D2.parameters(),
                                      lr=args.learning_rate_D,
                                      betas=(0.9, 0.99))
            optimizer_D2.zero_grad()

            if args.gan == 'Vanilla':
                bce_loss = torch.nn.BCEWithLogitsLoss()
            elif args.gan == 'LS':
                bce_loss = torch.nn.MSELoss()
            seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

            # labels for adversarial training
            source_label = 0
            target_label = 1

            # set up tensor board
            if args.tensorboard:
                if not os.path.exists(args.log_dir):
                    os.makedirs(args.log_dir)

                writer = SummaryWriter(args.log_dir)

            for i_iter in range(args.num_steps):

                loss_seg_value2 = 0
                loss_adv_target_value2 = 0
                loss_D_value2 = 0

                optimizer.zero_grad()
                adjust_learning_rate(optimizer, i_iter)

                optimizer_D2.zero_grad()
                adjust_learning_rate_D(optimizer_D2, i_iter)

                for sub_i in range(args.iter_size):

                    # train G

                    # don't accumulate grads in D
                    for param in model_D2.parameters():
                        param.requires_grad = False

                    # train with source

                    _, batch = trainloader_iter.__next__()

                    images, labels, _, _ = batch
                    images = images.to(device)
                    labels = labels.long().to(device)

                    warper = None
                    if args.warper:
                        warper, warp_list = WarpModel(images)

                    _, pred2 = model(images, input_size, warper)

                    loss_seg2 = seg_loss(pred2, labels)
                    loss = loss_seg2

                    # proper normalization
                    loss = loss / args.iter_size
                    loss.backward()
                    loss_seg_value2 += loss_seg2.item() / args.iter_size

                    # train with target

                    _, batch = targetloader_iter.__next__()
                    images, _, _ = batch
                    images = images.to(device)

                    _, pred_target2 = model(images, input_size, warper)

                    D_out2 = model_D2(F.softmax(pred_target2))

                    loss_adv_target2 = bce_loss(
                        D_out2,
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(source_label).to(device))

                    loss = args.lambda_adv_target2 * loss_adv_target2
                    loss = loss / args.iter_size
                    loss.backward()
                    loss_adv_target_value2 += loss_adv_target2.item(
                    ) / args.iter_size

                    # train D

                    # bring back requires_grad
                    for param in model_D2.parameters():
                        param.requires_grad = True

                    # train with source
                    pred2 = pred2.detach()

                    D_out2 = model_D2(F.softmax(pred2))

                    loss_D2 = bce_loss(
                        D_out2,
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(source_label).to(device))
                    loss_D2 = loss_D2 / args.iter_size / 2

                    loss_D2.backward()

                    loss_D_value2 += loss_D2.item()

                    # train with target
                    pred_target2 = pred_target2.detach()

                    D_out2 = model_D2(F.softmax(pred_target2))

                    loss_D2 = bce_loss(
                        D_out2,
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(target_label).to(device))

                    loss_D2 = loss_D2 / args.iter_size / 2

                    loss_D2.backward()
                    loss_D_value2 += loss_D2.item()

                optimizer.step()
                optimizer_D2.step()

                if args.tensorboard:
                    scalar_info = {
                        'loss_seg2': loss_seg_value2,
                        'loss_adv_target2': loss_adv_target_value2,
                        'loss_D2': loss_D_value2,
                    }

                    if i_iter % 10 == 0:
                        for key, val in scalar_info.items():
                            writer.add_scalar(key, val, i_iter)

                print('exp = {}'.format(args.snapshot_dir))
                print(
                    'iter = {0:8d}/{1:8d}, loss_seg2 = {2:.3f} loss_adv2 = {3:.3f} loss_D2 = {4:.3f}'
                    .format(i_iter, args.num_steps, loss_seg_value2,
                            loss_adv_target_value2, loss_D_value2))

                if i_iter >= args.num_steps_stop - 1:
                    print('save model ...')
                    torch.save(
                        model.state_dict(),
                        osp.join(args.snapshot_dir, 'single_level',
                                 'GTA5_' + str(args.num_steps_stop) + '.pth'))
                    torch.save(
                        model_D2.state_dict(),
                        osp.join(
                            args.snapshot_dir, 'single_level',
                            'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
                    break

                if i_iter % args.save_pred_every == 0 and i_iter != 0:
                    print('taking snapshot ...')
                    torch.save(
                        model.state_dict(),
                        osp.join(args.snapshot_dir, 'single_level',
                                 'GTA5_' + str(i_iter) + '.pth'))
                    torch.save(
                        model_D2.state_dict(),
                        osp.join(args.snapshot_dir, 'single_level',
                                 'GTA5_' + str(i_iter) + '_D2.pth'))

            if args.tensorboard:
                writer.close()

        elif args.level == 'multi-level':
            # init D
            model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
            model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

            model_D1.train()
            model_D1.to(device)

            model_D2.train()
            model_D2.to(device)

            if not os.path.exists(osp.join(args.snapshot_dir, 'multi_level')):
                os.makedirs(osp.join(args.snapshot_dir, 'multi_level'))

            trainloader = data.DataLoader(GTA5DataSet(
                args.data_dir,
                args.data_list,
                max_iters=args.num_steps * args.iter_size * args.batch_size,
                crop_size=input_size,
                scale=args.random_scale,
                mirror=args.random_mirror,
                mean=IMG_MEAN),
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          pin_memory=True)

            trainloader_iter = enumerate(trainloader)

            targetloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target,
                max_iters=args.num_steps * args.iter_size * args.batch_size,
                crop_size=input_size_target,
                scale=False,
                mirror=args.random_mirror,
                mean=IMG_MEAN,
                set=args.set),
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           pin_memory=True)

            targetloader_iter = enumerate(targetloader)

            # implement model.optim_parameters(args) to handle different models' lr setting

            optimizer = optim.SGD(model.optim_parameters(args),
                                  lr=args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
            optimizer.zero_grad()

            optimizer_D1 = optim.Adam(model_D1.parameters(),
                                      lr=args.learning_rate_D,
                                      betas=(0.9, 0.99))
            optimizer_D1.zero_grad()

            optimizer_D2 = optim.Adam(model_D2.parameters(),
                                      lr=args.learning_rate_D,
                                      betas=(0.9, 0.99))
            optimizer_D2.zero_grad()

            if args.gan == 'Vanilla':
                bce_loss = torch.nn.BCEWithLogitsLoss()
            elif args.gan == 'LS':
                bce_loss = torch.nn.MSELoss()
            seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

            # labels for adversarial training
            source_label = 0
            target_label = 1

            # set up tensor board
            if args.tensorboard:
                if not os.path.exists(args.log_dir):
                    os.makedirs(args.log_dir)

                writer = SummaryWriter(args.log_dir)

            for i_iter in range(args.num_steps):

                loss_seg_value1 = 0
                loss_adv_target_value1 = 0
                loss_D_value1 = 0

                loss_seg_value2 = 0
                loss_adv_target_value2 = 0
                loss_D_value2 = 0

                optimizer.zero_grad()
                adjust_learning_rate(optimizer, i_iter)

                optimizer_D1.zero_grad()
                optimizer_D2.zero_grad()
                adjust_learning_rate_D(optimizer_D1, i_iter)
                adjust_learning_rate_D(optimizer_D2, i_iter)

                for sub_i in range(args.iter_size):

                    # train G

                    # don't accumulate grads in D
                    for param in model_D1.parameters():
                        param.requires_grad = False

                    for param in model_D2.parameters():
                        param.requires_grad = False

                    # train with source

                    _, batch = trainloader_iter.__next__()

                    images, labels, _, _ = batch
                    images = images.to(device)
                    labels = labels.long().to(device)

                    pred1, pred2 = model(images)

                    loss_seg1 = seg_loss(pred1, labels)
                    loss_seg2 = seg_loss(pred2, labels)
                    loss = loss_seg2 + args.lambda_seg * loss_seg1

                    # proper normalization
                    loss = loss / args.iter_size
                    loss.backward()
                    loss_seg_value1 += loss_seg1.item() / args.iter_size
                    loss_seg_value2 += loss_seg2.item() / args.iter_size

                    # train with target

                    _, batch = targetloader_iter.__next__()
                    images, _, _ = batch
                    images = images.to(device)

                    pred_target1, pred_target2 = model(images)

                    D_out1 = model_D1(F.softmax(pred_target1))
                    D_out2 = model_D2(F.softmax(pred_target2))

                    loss_adv_target1 = bce_loss(
                        D_out1,
                        torch.FloatTensor(
                            D_out1.data.size()).fill_(source_label).to(device))

                    loss_adv_target2 = bce_loss(
                        D_out2,
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(source_label).to(device))

                    loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                    loss = loss / args.iter_size
                    loss.backward()
                    loss_adv_target_value1 += loss_adv_target1.item(
                    ) / args.iter_size
                    loss_adv_target_value2 += loss_adv_target2.item(
                    ) / args.iter_size

                    # train D

                    # bring back requires_grad
                    for param in model_D1.parameters():
                        param.requires_grad = True

                    for param in model_D2.parameters():
                        param.requires_grad = True

                    # train with source
                    pred1 = pred1.detach()
                    pred2 = pred2.detach()

                    D_out1 = model_D1(F.softmax(pred1))
                    D_out2 = model_D2(F.softmax(pred2))

                    loss_D1 = bce_loss(
                        D_out1,
                        torch.FloatTensor(
                            D_out1.data.size()).fill_(source_label).to(device))

                    loss_D2 = bce_loss(
                        D_out2,
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(source_label).to(device))

                    loss_D1 = loss_D1 / args.iter_size / 2
                    loss_D2 = loss_D2 / args.iter_size / 2

                    loss_D1.backward()
                    loss_D2.backward()

                    loss_D_value1 += loss_D1.item()
                    loss_D_value2 += loss_D2.item()

                    # train with target
                    pred_target1 = pred_target1.detach()
                    pred_target2 = pred_target2.detach()

                    D_out1 = model_D1(F.softmax(pred_target1))
                    D_out2 = model_D2(F.softmax(pred_target2))

                    loss_D1 = bce_loss(
                        D_out1,
                        torch.FloatTensor(
                            D_out1.data.size()).fill_(target_label).to(device))

                    loss_D2 = bce_loss(
                        D_out2,
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(target_label).to(device))

                    loss_D1 = loss_D1 / args.iter_size / 2
                    loss_D2 = loss_D2 / args.iter_size / 2

                    loss_D1.backward()
                    loss_D2.backward()

                    loss_D_value1 += loss_D1.item()
                    loss_D_value2 += loss_D2.item()

                optimizer.step()
                optimizer_D1.step()
                optimizer_D2.step()

                if args.tensorboard:
                    scalar_info = {
                        'loss_seg1': loss_seg_value1,
                        'loss_seg2': loss_seg_value2,
                        'loss_adv_target1': loss_adv_target_value1,
                        'loss_adv_target2': loss_adv_target_value2,
                        'loss_D1': loss_D_value1,
                        'loss_D2': loss_D_value2,
                    }

                    if i_iter % 10 == 0:
                        for key, val in scalar_info.items():
                            writer.add_scalar(key, val, i_iter)

                print('exp = {}'.format(args.snapshot_dir))
                print(
                    'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                    .format(i_iter, args.num_steps, loss_seg_value1,
                            loss_seg_value2, loss_adv_target_value1,
                            loss_adv_target_value2, loss_D_value1,
                            loss_D_value2))

                if i_iter >= args.num_steps_stop - 1:
                    print('save model ...')
                    torch.save(
                        model.state_dict(),
                        osp.join(args.snapshot_dir, 'multi_level',
                                 'GTA5_' + str(args.num_steps_stop) + '.pth'))
                    torch.save(
                        model_D1.state_dict(),
                        osp.join(
                            args.snapshot_dir, 'multi_level',
                            'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
                    torch.save(
                        model_D2.state_dict(),
                        osp.join(
                            args.snapshot_dir, 'multi_level',
                            'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
                    break

                if i_iter % args.save_pred_every == 0 and i_iter != 0:
                    print('taking snapshot ...')
                    torch.save(
                        model.state_dict(),
                        osp.join(args.snapshot_dir, 'multi_level',
                                 'GTA5_' + str(i_iter) + '.pth'))
                    torch.save(
                        model_D1.state_dict(),
                        osp.join(args.snapshot_dir, 'multi_level',
                                 'GTA5_' + str(i_iter) + '_D1.pth'))
                    torch.save(
                        model_D2.state_dict(),
                        osp.join(args.snapshot_dir, 'multi_level',
                                 'GTA5_' + str(i_iter) + '_D2.pth'))

            if args.tensorboard:
                writer.close()
        else:
            raise NotImplementedError(
                'level choice {} is not implemented'.format(args.level))
def main():
    """Create the model and start the evaluation process on the NYU val set.

    Loads a checkpoint (filtered for cross-version pytorch compatibility),
    runs inference over the NYU validation split, and reports StreamSegMetrics
    both on the raw predictions and after remapping rare NYU classes to the
    ignore label (255).
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # Remap selected NYU class ids to the ignore label 255; every other id
    # passes through unchanged.  NOTE(review): the lookup uses x + 1, so the
    # table keys appear to be 1-based class ids — confirm against the dataset.
    nyu_nyu_dict = {11: 255, 13: 255, 15: 255, 17: 255, 19: 255, 20: 255,
                    21: 255, 23: 255, 24: 255, 25: 255, 26: 255, 27: 255,
                    28: 255, 29: 255, 31: 255, 32: 255, 33: 255}
    args.nyu_nyu_map = np.vectorize(lambda x: nyu_nyu_dict.get(x + 1, x))

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # For running different versions of pytorch: keep only checkpoint keys
    # that exist in this model, merge them into the model's own state dict...
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    model_dict.update(saved_state_dict)
    # ...and load the MERGED dict.  (Bug fix: the original loaded the filtered
    # checkpoint directly, so a strict load failed whenever any model key was
    # absent from the checkpoint — exactly the case the filtering targets.)
    model.load_state_dict(model_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    metrics = StreamSegMetrics(args.num_classes)
    metrics_remap = StreamSegMetrics(args.num_classes)
    ignore_label = 255
    value_scale = 255
    # ImageNet mean/std scaled to [0, 255] pixel range (std is unused below;
    # normalization is done with IMG_MEAN and unit std to match training).
    mean = [item * value_scale for item in [0.485, 0.456, 0.406]]
    std = [item * value_scale for item in [0.229, 0.224, 0.225]]
    val_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.Crop([args.height + 1, args.width + 1], crop_type='center',
                        padding=IMG_MEAN, ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    val_dst = NYU(root=args.data_dir, opt=args,
                  split='val', transform=val_transform,
                  imWidth=args.width, imHeight=args.height, phase="TEST",
                  randomize=False)
    print("Dset Length {}".format(len(val_dst)))
    testloader = data.DataLoader(val_dst,
                                 batch_size=1, shuffle=False, pin_memory=True)

    # Upsample logits back to the (padded) crop size before the argmax.
    interp = nn.Upsample(size=(args.height + 1, args.width + 1),
                         mode='bilinear', align_corners=True)
    metrics.reset()
    # Inference only: disable autograd so no graph is kept per batch.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, targets, name = batch
            image = image.to(device)
            if args.model == 'DeeplabMulti':
                # Multi-level head: evaluate the final (second) output only.
                output1, output2 = model(image)
                output = interp(output2).cpu().data[0].numpy()
            elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                output = model(image)
                output = interp(output).cpu().data[0].numpy()
            targets = targets.cpu().numpy()
            # CHW logits -> HWC -> per-pixel argmax class map.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
            preds = output[None, :, :]
            metrics.update(targets, preds)
            # Score a second time with the rare-class remap applied to both
            # predictions and targets.
            targets = args.nyu_nyu_map(targets)
            preds = args.nyu_nyu_map(preds)
            metrics_remap.update(targets, preds)
    print(metrics.get_results())
    print(metrics_remap.get_results())
def main():
    """Create the model and start the evaluation process on BDD.

    Runs inference over the BDD test split, saves colorized and raw
    prediction maps, and (with --save-confidence) records a per-image
    confidence score and writes the top third of images to list.txt.
    """

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # strict=False: tolerate checkpoints with extra/missing keys.
    model.load_state_dict(saved_state_dict, strict=False)

    print(model)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(BDDDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(960, 540),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False,
                                            set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)  # 960 540

    # Upsample logits to the native BDD resolution before the argmax.
    interp = nn.Upsample(size=(720, 1280), mode='bilinear', align_corners=True)

    # (confidence, name) pairs collected for the top-third selection below.
    c_list = []

    # Inference only: disable autograd so no graph is kept per batch.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 10 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            output = model(image)

            if args.save_confidence:
                # Confidence is computed on the raw (pre-interp) logits.
                confidence = get_confidence(output).cpu().item()
                c_list.append([confidence, name])
                name = name[0].split('/')[-1]
                save_path = '%s/%s_c.txt' % (args.save, name.split('.')[0])
                # Context manager guarantees the file is closed even on error.
                with open(save_path, 'w') as record:
                    record.write('%.5f' % confidence)
            else:
                name = name[0].split('/')[-1]

            output = interp(output).cpu().data[0].numpy()

            # CHW logits -> HWC -> per-pixel argmax class map.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            output.save('%s/%s' % (args.save, name[:-4] + '.png'))
            output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))

    if args.save_confidence:
        # Highest confidence first; keep the top third of images.
        c_list.sort(key=lambda entry: entry[0], reverse=True)
        with open('list.txt', 'w') as select:
            for confidence, name in c_list[:len(c_list) // 3]:
                print(confidence)
                print(name)
                select.write(name[0])
                select.write('\n')

    print(args.save)
class AD_Trainer(nn.Module):
    def __init__(self, args):
        """Adversarial domain-adaptation trainer.

        Holds the segmentation generator ``G`` (DeepLab with two prediction
        heads) plus two multi-scale discriminators ``D1``/``D2``, their
        optimizers, the loss modules and all loss-weight hyper-parameters.

        Args:
            args: parsed command-line namespace supplying model options,
                learning rates, lambda weights, crop size, fp16/multi-gpu
                flags, and the checkpoint path ``restore_from``.
        """
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Per-class loss weights; start at 1 and are adapted online by
        # update_class_criterion().
        self.class_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label
        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes,
                                  use_se=args.use_se,
                                  train_bn=args.train_bn,
                                  norm_style=args.norm_style,
                                  droprate=args.droprate)
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)

            # Copy matching weights from the checkpoint, remapping names that
            # carry an extra leading prefix ('module.' from DataParallel, or
            # the first path component of an http/ImageNet checkpoint).
            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # e.g. Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if args.restore_from[:4] == 'http':
                    # ImageNet checkpoint: skip the classifier head.
                    if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[1:])
                else:
                    if i_parts[0] == 'module':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[1:])
                    else:
                        new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[0:])
            # BUGFIX: must stay inside the DeepLab branch -- for any other
            # args.model value neither self.G nor new_params exists and the
            # original top-level call raised a confusing NameError.
            self.G.load_state_dict(new_params)

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate,
                                 momentum=args.momentum,
                                 nesterov=True,
                                 weight_decay=args.weight_decay)

        self.dis1_opt = optim.Adam(self.D1.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        self.dis2_opt = optim.Adam(self.D2.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # reduction='sum' is the modern equivalent of the deprecated
        # size_average=False; callers divide by n*h*w themselves.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        self.interp = nn.Upsample(size=args.crop_size,
                                  mode='bilinear',
                                  align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size,
                                         mode='bilinear',
                                         align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G,
                                                  self.gen_opt,
                                                  opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1,
                                                    self.dis1_opt,
                                                    opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2,
                                                    self.dis2_opt,
                                                    opt_level="O1")

    def update_class_criterion(self, labels):
        """Build a class-balanced CrossEntropyLoss for the current batch.

        Classes covering fewer than 64*64 pixels per image get weight
        ``self.max_value``; when ``self.often_balance`` is set, classes absent
        from the batch additionally boost a smoothed running weight
        (``self.often_weight``, exponential moving average).

        Args:
            labels: LongTensor of shape (n, h, w) with class ids
                (255 = ignore).

        Returns:
            nn.CrossEntropyLoss with per-class weights and ignore_index=255.
        """
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        print(labels.shape)
        n, h, w = labels.shape
        # One-pass per-class pixel histogram instead of num_classes
        # full-tensor comparisons; minlength >= 256 covers the ignore label.
        count = torch.bincount(
            labels.reshape(-1),
            minlength=max(self.num_classes, 256))[:self.num_classes]
        count = count.float().to(weight.device)
        weight[count < 64 * 64 * n] = self.max_value  # small objective
        if self.often_balance:
            often[count == 0] = self.max_value

        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_label(self, labels, prediction):
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255,
                                        reduction='none')
        #criterion = self.seg_loss
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        #mm = torch.median(loss)
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        #print(m.data.cpu(), mm)
        labels[loss < mm] = 255
        return labels

    def update_variance(self, labels, pred1, pred2):
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255,
                                        reduction='none')
        kl_distance = nn.KLDivLoss(reduction='none')
        loss = criterion(pred1, labels)

        #n, h, w = labels.shape
        #labels_onehot = torch.zeros(n, self.num_classes, h, w)
        #labels_onehot = labels_onehot.cuda()
        #labels_onehot.scatter_(1, labels.view(n,1,h,w), 1)

        variance = torch.sum(kl_distance(self.log_sm(pred1), self.sm(pred2)),
                             dim=1)
        exp_variance = torch.exp(-variance)
        #variance = torch.log( 1 + (torch.mean((pred1-pred2)**2, dim=1)))
        #torch.mean( kl_distance(self.log_sm(pred1),pred2), dim=1) + 1e-6
        print(variance.shape)
        print('variance mean: %.4f' % torch.mean(exp_variance[:]))
        print('variance min: %.4f' % torch.min(exp_variance[:]))
        print('variance max: %.4f' % torch.max(exp_variance[:]))
        #loss = torch.mean(loss/variance) + torch.mean(variance)
        loss = torch.mean(loss * exp_variance) + torch.mean(variance)
        return loss

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        """One optimization step for the segmentation generator ``G``.

        Combines (1) the rectified source segmentation losses of both heads,
        (2) adversarial losses that push target predictions towards the
        source distribution, and -- from iteration 15000 on -- (3) an
        entropy-minimisation term and (4) a KL consistency term between the
        two heads on the target domain.

        Args:
            images, labels: source-domain batch and ground truth.
            images_t, labels_t: target-domain batch and labels (labels_t are
                used only for the returned monitoring ``val_loss``).
            i_iter: current training iteration (gates the me/kl terms).

        Returns:
            Tuple of the individual loss terms, the four upsampled prediction
            maps, and ``val_loss`` (target seg-loss, monitoring only).
        """
        self.gen_opt.zero_grad()

        # Source forward pass; both auxiliary and main head, upsampled to
        # crop size.
        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)

        # Optionally rebuild the seg criterion with batch class balancing.
        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)

        # Each head's loss is rectified by its disagreement with the other
        # head (see update_variance).
        loss_seg1 = self.update_variance(labels, pred1, pred2)
        loss_seg2 = self.update_variance(labels, pred2, pred1)

        loss = loss_seg2 + self.lambda_seg * loss_seg1

        # target
        pred_target1, pred_target2 = self.G(images_t)
        pred_target1 = self.interp_target(pred_target1)
        pred_target2 = self.interp_target(pred_target2)

        # NOTE(review): project-specific discriminator API -- calc_gen_loss
        # takes the (possibly DataParallel-wrapped) module as first argument.
        if self.multi_gpu:
            #if self.lambda_adv_target1 > 0 and self.lambda_adv_target2 > 0:
            loss_adv_target1 = self.D1.module.calc_gen_loss(
                self.D1, input_fake=F.softmax(pred_target1, dim=1))
            loss_adv_target2 = self.D2.module.calc_gen_loss(
                self.D2, input_fake=F.softmax(pred_target2, dim=1))
            #else:
            #    print('skip the discriminator')
            #    loss_adv_target1, loss_adv_target2 = 0, 0
        else:
            #if self.lambda_adv_target1 > 0 and self.lambda_adv_target2 > 0:
            loss_adv_target1 = self.D1.calc_gen_loss(self.D1,
                                                     input_fake=F.softmax(
                                                         pred_target1, dim=1))
            loss_adv_target2 = self.D2.calc_gen_loss(self.D2,
                                                     input_fake=F.softmax(
                                                         pred_target2, dim=1))
            #else:
            #loss_adv_target1 = 0.0 #torch.tensor(0).cuda()
            #loss_adv_target2 = 0.0 #torch.tensor(0).cuda()

        loss += self.lambda_adv_target1 * loss_adv_target1 + self.lambda_adv_target2 * loss_adv_target2

        # Warm-up: disable entropy-min and KL consistency for the first 15k
        # iterations (the *_copy values are 0 or equal to the originals).
        if i_iter < 15000:
            self.lambda_kl_target_copy = 0
            self.lambda_me_target_copy = 0
        else:
            self.lambda_kl_target_copy = self.lambda_kl_target
            self.lambda_me_target_copy = self.lambda_me_target

        # Confidence-weighted entropy minimisation on the fused prediction
        # (0.5*head1 + head2); the confidence map is detached.
        loss_me = 0.0
        if self.lambda_me_target_copy > 0:
            confidence_map = torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2)**2, 1).detach()
            loss_me = -torch.mean(confidence_map * torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2) *
                self.log_sm(0.5 * pred_target1 + pred_target2), 1))
            loss += self.lambda_me_target * loss_me

        # KL consistency: pull both heads towards the (fixed) fused
        # distribution; normalised per pixel since kl_loss sums.
        loss_kl = 0.0
        if self.lambda_kl_target_copy > 0:
            n, c, h, w = pred_target1.shape
            with torch.no_grad():
                #pred_target1_flip, pred_target2_flip = self.G(fliplr(images_t))
                #pred_target1_flip = self.interp_target(pred_target1_flip)
                #pred_target2_flip = self.interp_target(pred_target2_flip)
                mean_pred = self.sm(
                    0.5 * pred_target1 + pred_target2
                )  #+ self.sm(fliplr(0.5*pred_target1_flip + pred_target2_flip)) ) /2
            loss_kl = (self.kl_loss(self.log_sm(pred_target2), mean_pred) +
                       self.kl_loss(self.log_sm(pred_target1), mean_pred)) / (
                           n * h * w)
            #loss_kl = (self.kl_loss(self.log_sm(pred_target2) , self.sm(pred_target1) ) ) / (n*h*w) + (self.kl_loss(self.log_sm(pred_target1) , self.sm(pred_target2)) ) / (n*h*w)
            print(loss_kl)
            loss += self.lambda_kl_target * loss_kl

        # apex amp needs the scaled-loss context in fp16 mode.
        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.gen_opt.step()

        # Monitoring only: computed after the step, never backpropagated.
        val_loss = self.seg_loss(pred_target2, labels_t)

        return loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2, val_loss

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        """One optimization step for both discriminators.

        "Real" samples are the softmaxed fusion of the two source-domain
        heads (0.5*pred1 + pred2); "fake" samples are the softmaxed
        target-domain predictions. Returns (loss_D1, loss_D2).
        """
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()

        # Stop gradients from flowing back into the generator.
        real = F.softmax(0.5 * pred1.detach() + pred2.detach(), dim=1)
        fake1 = F.softmax(pred_target1.detach(), dim=1)
        fake2 = F.softmax(pred_target2.detach(), dim=1)

        # Project-specific API: calc_dis_loss lives on the unwrapped module
        # but receives the (possibly DataParallel) module as first argument.
        d1 = self.D1.module if self.multi_gpu else self.D1
        d2 = self.D2.module if self.multi_gpu else self.D2
        loss_D1, reg1 = d1.calc_dis_loss(self.D1,
                                         input_fake=fake1,
                                         input_real=real)
        loss_D2, reg2 = d2.calc_dis_loss(self.D2,
                                         input_fake=fake2,
                                         input_real=real)

        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss,
                                [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2