Example #1
def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=input_size, mean=IMG_MEAN, scale=False, mirror=False, set=args.set),
                                    batch_size=1, shuffle=False, pin_memory=True)

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processed' % index)
        image, _, name = batch
        image = image.to(device)

        if args.model == 'DeeplabMulti':
            output1, output2, _, _ = model(image)
            output = interp(output2).cpu().data[0].numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            output = model(image)
            output = interp(output).cpu().data[0].numpy()

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        name = name[0].split('/')[-1]
        output.save('%s/%s' % (args.save, name))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
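
Example #1 calls colorize_mask without defining it. A minimal sketch, assuming the usual approach of mapping the uint8 label map onto a "P"-mode PIL palette image; the palette values shown are an assumption (only the first three Cityscapes class colors are listed) and would need the project's full 19-class color table:

import numpy as np
from PIL import Image

# Hypothetical palette: road, sidewalk, building; a real implementation
# lists all 19 Cityscapes class colors.
PALETTE = [128, 64, 128, 244, 35, 232, 70, 70, 70]
PALETTE += [0] * (768 - len(PALETTE))  # pad to 256 RGB triplets

def colorize_mask(mask):
    # mask: HxW uint8 array of class indices -> paletted PIL image
    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
    new_mask.putpalette(PALETTE)
    return new_mask
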
Example #2
def main():
    """Create the model and start the training."""

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    if args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=2048).to(device)
    if args.model == 'VGG':
        model_D = FCDiscriminator(num_classes=1024).to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cla_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source

        _, batch = next(trainloader_iter)
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feature, prediction = model(images)
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        loss.backward()
        loss_seg = loss.item()

        # train with target

        _, batch = next(targetloader_iter)
        images, _, _ = batch
        images = images.to(device)

        feature_target, _ = model(images)
        _, D_out = model_D(feature_target)
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        #print(args.lambda_adv_target)
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward()
        loss_adv_target_value = loss_adv_target.item()

        # train D

        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source
        feature = feature.detach()
        cla, D_out = model_D(feature)
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)

        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        loss_D = loss_D / 2
        #print(args.lambda_s)
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()

        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(target_label).to(device))
        loss_D = loss_D / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        #print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

    if args.tensorboard:
        writer.close()
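
adjust_learning_rate and adjust_learning_rate_D are not shown in these snippets. In AdaptSegNet-style training code they usually implement polynomial decay; a sketch under that assumption, relying on a module-level args with learning_rate, learning_rate_D, num_steps and power fields:

def lr_poly(base_lr, i_iter, max_iter, power):
    # polynomial decay: base_lr * (1 - t/T)^power
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)

def adjust_learning_rate(optimizer, i_iter):
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        # 10x rate for the freshly initialized classifier parameters
        optimizer.param_groups[1]['lr'] = lr * 10

def adjust_learning_rate_D(optimizer, i_iter):
    optimizer.param_groups[0]['lr'] = lr_poly(
        args.learning_rate_D, i_iter, args.num_steps, args.power)
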
Example #3
def main():
    """Create the model and start the training."""
    global args
    args = get_arguments()
    if args.dist:
        init_dist(args.launcher, backend=args.backend)
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
        del saved_state_dict

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)

    if rank == 0:
        print(model)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D1.to(device)
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D)
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict

    model_D2.train()
    model_D2.to(device)
    if args.dist:
        broadcast_params(model_D2)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_data = GTA5BDDDataSet(args.data_dir,
                                args.data_list,
                                max_iters=args.num_steps * args.iter_size *
                                args.batch_size,
                                crop_size=input_size,
                                scale=args.random_scale,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)
    trainloader = data.DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=False if train_sampler else True,
                                  num_workers=args.num_workers,
                                  pin_memory=False,
                                  sampler=train_sampler)

    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target,
                             args.data_list_target,
                             max_iters=args.num_steps * args.iter_size *
                             args.batch_size,
                             crop_size=input_size_target,
                             scale=False,
                             mirror=args.random_mirror,
                             mean=IMG_MEAN,
                             set=args.set)
    target_sampler = None
    if args.dist:
        target_sampler = DistributedSampler(target_data)
    targetloader = data.DataLoader(target_data,
                                   batch_size=args.batch_size,
                                   shuffle=False if target_sampler else True,
                                   num_workers=args.num_workers,
                                   pin_memory=False,
                                   sampler=target_sampler)

    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    #interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard and rank == 0:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()
    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            interp = nn.Upsample(size=(size[1], size[0]),
                                 mode='bilinear',
                                 align_corners=True)

            pred1 = model(images)
            pred1 = interp(pred1)

            loss_seg1 = seg_loss(pred1, labels)

            loss = loss_seg1

            # proper normalization
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            _, batch = next(targetloader_iter)
            # train with target
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = model(images)
            pred_target1 = interp_target(pred_target1)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size

            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            if args.dist:
                average_gradients(model)
                average_gradients(model_D1)
                average_gradients(model_D2)

            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()

        if rank == 0:
            if args.tensorboard:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }

                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_D1.pth'))
    print(args.snapshot_dir)
    if args.tensorboard and rank == 0:
        writer.close()
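
broadcast_params and average_gradients, used above when args.dist is set, are external helpers. A plausible minimal version with torch.distributed; note the training loop already divides each loss by world_size, so summing gradients across workers yields the average:

import torch.distributed as dist

def broadcast_params(model):
    # copy rank 0's initial weights to every worker
    for p in model.state_dict().values():
        dist.broadcast(p, 0)

def average_gradients(model):
    # sum gradients across workers; the per-worker loss was pre-divided
    # by world_size, so the sum equals the global average
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data)
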
Example #4
def main():
    """Create the model and start the training."""
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        # pdb.set_trace()
        args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    #### restore model_D and model
    if RESTART:
        # pdb.set_trace()
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        # pdb.set_trace()
        loss_seg_value1 = 0
        loss_seg_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        """
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)
        """

        for sub_i in range(args.iter_size):

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # pdb.set_trace()
            # images.size() == [1, 3, 720, 1280]
            pred1, pred2 = model(images)
            # pred1, pred2 size == [1, 19, 91, 161]
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            # pdb.set_trace()
            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size
            # pdb.set_trace()
            # train with target: skipped in this source-only ("noadapt") variant
            continue

        optimizer.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)
        # pdb.set_trace()
        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_noadapt_' + str(args.num_steps_stop) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))

            # check_original_discriminator(args, pred_target1, pred_target2, i_iter)

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

            ###### also record latest saved iteration #######

    writer.close()
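
Example #4 relies on module-level RESTART, RESTART_FROM, RESTART_ITER and generate_snapshot_name, none of which appear in the snippet. One plausible setup; the names, paths and format are hypothetical:

import time

RESTART = False                       # resume an interrupted run?
RESTART_FROM = './snapshots/run_x/'   # hypothetical snapshot directory to resume from
RESTART_ITER = 25000                  # iteration whose args_dict_<iter>.json is reloaded

def generate_snapshot_name(args):
    # hypothetical: one directory per run, stamped with time and learning rate
    return './snapshots/GTA5_{}_lr{}'.format(
        time.strftime('%m%d-%H%M'), args.learning_rate)
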
Example #5
def main():
    """Create the model and start the training."""

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(cityscapesDataSet(
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        scale=False),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                         mode='bilinear',
                         align_corners=True)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        _, batch = next(trainloader_iter)

        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        pred1, pred2 = model(images)
        pred1 = interp(pred1)
        pred2 = interp(pred2)
        loss_seg1 = seg_loss(pred1, labels)
        loss_seg2 = seg_loss(pred2, labels)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization
        loss = loss / args.iter_size
        loss.backward()
        loss_seg_value1 += loss_seg1.item() / args.iter_size
        loss_seg_value2 += loss_seg2.item() / args.iter_size

        # train with source
        pred1 = pred1.detach()
        pred2 = pred2.detach()

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))

            # break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
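
All of these scripts are presumably launched directly; the usual entry-point guard, omitted from the snippets, would be:

if __name__ == '__main__':
    main()
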
Example #6
def main():
    """Create the model and start the training."""
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        # pdb.set_trace()
        args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #### restore model_D1, D2 and model
    if RESTART:
        # pdb.set_trace()
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D1 parameters
        restart_from_D1 = args.restart_from + 'GTA5_{}_D1.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D1)
        model_D1.load_state_dict(saved_state_dict)

        # model_D2 parameters
        restart_from_D2 = args.restart_from + 'GTA5_{}_D2.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D2)
        model_D2.load_state_dict(saved_state_dict)

    #### model_D1, D2 are randomly initialized, model is pre-trained ResNet on ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)
    # args.snapshot_dir = generate_snapshot_name()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    # pdb.set_trace()
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.start_steps, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)
            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'loss_adv_target1': loss_adv_target_value1,
            'loss_adv_target2': loss_adv_target_value2,
            'loss_D1': loss_D_value1,
            'loss_D2': loss_D_value2,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D1.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

            ###### also record latest saved iteration #######

    writer.close()
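
A side note on the discriminator targets: the torch.FloatTensor(D_out.data.size()).fill_(label).to(device) pattern that recurs in these examples predates modern PyTorch. An equivalent and shorter form on current versions, shown for the first discriminator as an illustration:

# same shape, dtype and device as D_out1, filled with the domain label
loss_D1 = bce_loss(D_out1, torch.full_like(D_out1, source_label))
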
Example #7
def main():
    """Create the model and start the evaluation process."""

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        #model = Res_Deeplab(num_classes=args.num_classes)
        model = DeepLab(backbone='resnet', output_stride=8)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # filter the checkpoint so it loads across different PyTorch versions
    model_dict = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in model_dict
    }
    model_dict.update(saved_state_dict)
    model.load_state_dict(model_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)
    model.eval()

    num_classes = 20
    tp_list = [0] * num_classes
    fp_list = [0] * num_classes
    fn_list = [0] * num_classes
    iou_list = [0] * num_classes

    hist = np.zeros((21, 21))
    group = 1
    scorer = SegScorer(num_classes=21)
    datalayer = SSDatalayer(group)
    cos_similarity_func = nn.CosineSimilarity()
    for count in tqdm(range(1000)):
        dat = datalayer.dequeue()
        ref_img = dat['second_img'][0]  # (3, 457, 500)
        query_img = dat['first_img'][0]  # (3, 375, 500)
        query_label = dat['second_label'][0]  # (1, 375, 500)
        ref_label = dat['first_label'][0]  # (1, 457, 500)
        # query_img = dat['second_img'][0]
        # ref_img = dat['first_img'][0]
        # ref_label = dat['second_label'][0]
        # query_label = dat['first_label'][0]
        deploy_info = dat['deploy_info']
        semantic_label = deploy_info['first_semantic_labels'][0][0] - 1  # 2

        ref_img, ref_label = torch.Tensor(ref_img).cuda(), torch.Tensor(
            ref_label).cuda()
        query_img, query_label = torch.Tensor(query_img).cuda(), torch.Tensor(
            query_label[0, :, :]).cuda()
        #ref_img, ref_label = torch.Tensor(ref_img), torch.Tensor(ref_label)
        #query_img, query_label = torch.Tensor(query_img), torch.Tensor(query_label[0, :, :])

        # ref_img = ref_img*ref_label
        ref_img_var, query_img_var = Variable(ref_img), Variable(query_img)
        query_label_var, ref_label_var = Variable(query_label), Variable(
            ref_label)

        ref_img_var = torch.unsqueeze(ref_img_var, dim=0)  # [1, 3, 457, 500]
        ref_label_var = torch.unsqueeze(ref_label_var,
                                        dim=1)  # [1, 1, 457, 500]
        query_img_var = torch.unsqueeze(query_img_var,
                                        dim=0)  # [1, 3, 375, 500]
        query_label_var = torch.unsqueeze(query_label_var,
                                          dim=0)  # [1, 375, 500]

        samples = torch.cat([ref_img_var, query_img_var], 0)
        pred = model(samples, ref_label_var)
        w, h = query_label.size()
        pred = F.interpolate(pred, size=(w, h), mode='bilinear')  # [2, 416, 416]
        pred = F.softmax(pred, dim=1).squeeze()
        values, pred = torch.max(pred, dim=0)
        #print(pred.shape)
        pred = pred.data.cpu().numpy().astype(np.int32)  # (333, 500)
        #print(pred.shape)
        org_img = get_org_img(
            query_img.squeeze().cpu().data.numpy())  # query-set image, (375, 500, 3)
        #print(org_img.shape)
        img = mask_to_img(pred, org_img)  # (375, 500, 3) color image: mask blended onto the original
        cv2.imwrite('save_bins/que_pred/query_set_1_%d.png' % (count), img)

        query_label = query_label.cpu().numpy().astype(np.int32)  # (333, 500)
        class_ind = int(deploy_info['first_semantic_labels'][0][0]
                        ) - 1  # class indices start from 1 in the data layer
        scorer.update(pred, query_label, class_ind + 1)
        tp, tn, fp, fn = measure(query_label, pred)
        # iou_img = tp/float(max(tn+fp+fn,1))
        tp_list[class_ind] += tp
        fp_list[class_ind] += fp
        fn_list[class_ind] += fn
        # max in case both pred and label are zero
        iou_list = [
            tp_list[ic] /
            float(max(tp_list[ic] + fp_list[ic] + fn_list[ic], 1))
            for ic in range(num_classes)
        ]

        tmp_pred = pred
        tmp_pred[tmp_pred > 0.5] = class_ind + 1
        tmp_gt_label = query_label
        tmp_gt_label[tmp_gt_label > 0.5] = class_ind + 1

        hist += Metrics.fast_hist(tmp_pred, query_label, 21)

    print("-------------GROUP %d-------------" % (group))
    print(iou_list)
    class_indexes = range(group * 5, (group + 1) * 5)
    print('Mean:', np.mean(np.take(iou_list, class_indexes)))
    '''
    for group in range(2):
        datalayer = SSDatalayer(group+1)
        restore(args, model, group+1)

        for count in tqdm(range(1000)):
            dat = datalayer.dequeue()
            ref_img = dat['second_img'][0]#(3, 457, 500)
            query_img = dat['first_img'][0]#(3, 375, 500)
            query_label = dat['second_label'][0]#(1, 375, 500)
            ref_label = dat['first_label'][0]#(1, 457, 500)
            # query_img = dat['second_img'][0]
            # ref_img = dat['first_img'][0]
            # ref_label = dat['second_label'][0]
            # query_label = dat['first_label'][0]
            deploy_info = dat['deploy_info']
            semantic_label = deploy_info['first_semantic_labels'][0][0] - 1#2

            ref_img, ref_label = torch.Tensor(ref_img).cuda(), torch.Tensor(ref_label).cuda()
            query_img, query_label = torch.Tensor(query_img).cuda(), torch.Tensor(query_label[0,:,:]).cuda()
            #ref_img, ref_label = torch.Tensor(ref_img), torch.Tensor(ref_label)
            #query_img, query_label = torch.Tensor(query_img), torch.Tensor(query_label[0, :, :])

            # ref_img = ref_img*ref_label
            ref_img_var, query_img_var = Variable(ref_img), Variable(query_img)
            query_label_var, ref_label_var = Variable(query_label), Variable(ref_label)

            ref_img_var = torch.unsqueeze(ref_img_var,dim=0)#[1, 3, 457, 500]
            ref_label_var = torch.unsqueeze(ref_label_var, dim=1)#[1, 1, 457, 500]
            query_img_var = torch.unsqueeze(query_img_var, dim=0)#[1, 3, 375, 500]
            query_label_var = torch.unsqueeze(query_label_var, dim=0)#[1, 375, 500]

            logits  = model(query_img_var, ref_img_var, ref_label_var,ref_label_var)

            # w, h = query_label.size()
            # outB_side = F.upsample(outB_side, size=(w, h), mode='bilinear')
            # out_side = F.softmax(outB_side, dim=1).squeeze()
            # values, pred = torch.max(out_side, dim=0)
            values, pred = model.get_pred(logits, query_img_var)#values[2, 333, 500]
            pred = pred.data.cpu().numpy().astype(np.int32)#(333, 500)

            query_label = query_label.cpu().numpy().astype(np.int32)#(333, 500)
            class_ind = int(deploy_info['first_semantic_labels'][0][0])-1 # because class indices from 1 in data layer,0
            scorer.update(pred, query_label, class_ind+1)
            tp, tn, fp, fn = measure(query_label, pred)
            # iou_img = tp/float(max(tn+fp+fn,1))
            tp_list[class_ind] += tp
            fp_list[class_ind] += fp
            fn_list[class_ind] += fn
            # max in case both pred and label are zero
            iou_list = [tp_list[ic] /
                        float(max(tp_list[ic] + fp_list[ic] + fn_list[ic],1))
                        for ic in range(num_classes)]


            tmp_pred = pred
            tmp_pred[tmp_pred>0.5] = class_ind+1
            tmp_gt_label = query_label
            tmp_gt_label[tmp_gt_label>0.5] = class_ind+1

            hist += Metrics.fast_hist(tmp_pred, query_label, 21)


        print("-------------GROUP %d-------------"%(group))
        print(iou_list)
        class_indexes = range(group*5, (group+1)*5)
        print('Mean:', np.mean(np.take(iou_list, class_indexes)))

    print('BMVC IOU', np.mean(np.take(iou_list, range(0,20))))

    miou = Metrics.get_voc_iou(hist)
    print('IOU:', miou, np.mean(miou))
    '''

    binary_hist = np.array((hist[0, 0], hist[0, 1:].sum(), hist[1:, 0].sum(),
                            hist[1:, 1:].sum())).reshape((2, 2))
    bin_iu = np.diag(binary_hist) / (binary_hist.sum(1) + binary_hist.sum(0) -
                                     np.diag(binary_hist))
    print('Bin_iu:', bin_iu)

    scores = scorer.score()
    for k in scores.keys():
        print(k, np.mean(scores[k]), scores[k])
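
measure and Metrics.fast_hist above, and fast_hist/per_class_iu in Example #8 below, are external metric helpers. Common minimal implementations, offered as assumptions consistent with how they are called:

import numpy as np

def fast_hist(a, b, n):
    # n x n confusion matrix between label array a and prediction array b
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iu(hist):
    # per-class IoU: diagonal / (row sum + column sum - diagonal)
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))

def measure(y_in, pred_in):
    # binary true/false positives/negatives at a 0.5 threshold
    y, pred = y_in > 0.5, pred_in > 0.5
    tp = np.logical_and(y, pred).sum()
    tn = np.logical_and(~y, ~pred).sum()
    fp = np.logical_and(~y, pred).sum()
    fn = np.logical_and(y, ~pred).sum()
    return tp, tn, fp, fn
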
Example #8
def main():
    args = get_arguments()

    seed = args.random_seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    input_size = (1024, 512)
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    if args.num_classes == 13:
        name_classes = np.asarray([
            "road", "sidewalk", "building", "light", "sign", "vegetation",
            "sky", "person", "rider", "car", "bus", "motorcycle", "bicycle"
        ])
    elif args.num_classes == 18:
        name_classes = np.asarray([
            "road", "sidewalk", "building", "wall", "fence", "pole", "light",
            "sign", "vegetation", "sky", "person", "rider", "car", "truck",
            "bus", "train", "motorcycle", "bicycle"
        ])
    else:
        raise NotImplementedError("Unavailable number of classes")

    # Create the model and start the evaluation process
    model = DeeplabMulti(num_classes=args.num_classes)
    for files in range(int(args.num_steps_stop / args.save_pred_every)):
        print('Step: ', (files + 1) * args.save_pred_every)
        saved_state_dict = torch.load('./snapshots/' + args.dir_name + '/' +
                                      str((files + 1) * args.save_pred_every) +
                                      '.pth')
        # saved_state_dict = torch.load('./snapshots/' + '30000.pth')
        model.load_state_dict(saved_state_dict)

        device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        model.eval()
        if args.gta5:
            gta5_loader = torch.utils.data.DataLoader(
                GTA5DataSet(args.data_dir_gta5,
                            args.data_list_gta5,
                            crop_size=input_size,
                            ignore_label=args.ignore_label,
                            set=args.set,
                            num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)

            hist = np.zeros((args.num_classes, args.num_classes))
            for i, data in enumerate(gta5_loader):
                images_val, labels, _ = data
                images_val, labels = images_val.to(device), labels.to(device)
                _, pred = model(images_val)
                pred = interp(pred)
                _, pred = pred.max(dim=1)

                labels = labels.cpu().numpy()
                pred = pred.cpu().detach().numpy()

                hist += fast_hist(labels.flatten(), pred.flatten(),
                                  args.num_classes)
            mIoUs = per_class_iu(hist)
            if args.mIoUs_per_class:
                for ind_class in range(args.num_classes):
                    print('==>' + name_classes[ind_class] + ':\t' +
                          str(round(mIoUs[ind_class] * 100, 2)))
            print('===> mIoU (GTA5): ' +
                  str(round(np.nanmean(mIoUs) * 100, 2)))
            print('=' * 50)

        if args.synthia:
            synthia_loader = torch.utils.data.DataLoader(
                SYNTHIADataSet(args.data_dir_synthia,
                               args.data_list_synthia,
                               crop_size=input_size,
                               ignore_label=args.ignore_label,
                               set=args.set,
                               num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)

            hist = np.zeros((args.num_classes, args.num_classes))
            for i, data in enumerate(synthia_loader):
                images_val, labels, _ = data
                images_val, labels = images_val.to(device), labels.to(device)
                _, pred = model(images_val)
                pred = interp(pred)
                _, pred = pred.max(dim=1)

                labels = labels.cpu().numpy()
                pred = pred.cpu().detach().numpy()

                hist += fast_hist(labels.flatten(), pred.flatten(),
                                  args.num_classes)
            mIoUs = per_class_iu(hist)
            if args.mIoUs_per_class:
                for ind_class in range(args.num_classes):
                    print('==>' + name_classes[ind_class] + ':\t' +
                          str(round(mIoUs[ind_class] * 100, 2)))
            print('===> mIoU (SYNTHIA): ' +
                  str(round(np.nanmean(mIoUs) * 100, 2)))
            print('=' * 50)

        if args.cityscapes:
            cityscapes_loader = torch.utils.data.DataLoader(
                cityscapesDataSet(args.data_dir_cityscapes,
                                  args.data_list_cityscapes,
                                  crop_size=input_size,
                                  ignore_label=args.ignore_label,
                                  set=args.set,
                                  num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)

            hist = np.zeros((args.num_classes, args.num_classes))
            for i, data in enumerate(cityscapes_loader):
                images_val, labels, _ = data
                images_val, labels = images_val.to(device), labels.to(device)
                _, pred = model(images_val)
                pred = interp(pred)
                _, pred = pred.max(dim=1)

                labels = labels.cpu().numpy()
                pred = pred.cpu().detach().numpy()

                hist += fast_hist(labels.flatten(), pred.flatten(),
                                  args.num_classes)
            mIoUs = per_class_iu(hist)
            if args.mIoUs_per_class:
                for ind_class in range(args.num_classes):
                    print('==>' + name_classes[ind_class] + ':\t' +
                          str(round(mIoUs[ind_class] * 100, 2)))
            print('===> mIoU (CityScapes): ' +
                  str(round(np.nanmean(mIoUs) * 100, 2)))
            print('=' * 50)

        if args.idd:
            idd_loader = torch.utils.data.DataLoader(
                IDDDataSet(args.data_dir_idd,
                           args.data_list_idd,
                           crop_size=input_size,
                           ignore_label=args.ignore_label,
                           set=args.set,
                           num_classes=args.num_classes),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)

            hist = np.zeros((args.num_classes, args.num_classes))
            for i, data in enumerate(idd_loader):
                images_val, labels, _ = data
                images_val, labels = images_val.to(device), labels.to(device)
                _, pred = model(images_val)
                pred = interp(pred)
                _, pred = pred.max(dim=1)

                labels = labels.cpu().numpy()
                pred = pred.cpu().detach().numpy()

                hist += fast_hist(labels.flatten(), pred.flatten(),
                                  args.num_classes)
            mIoUs = per_class_iu(hist)
            if args.mIoUs_per_class:
                for ind_class in range(args.num_classes):
                    print('==>' + name_classes[ind_class] + ':\t' +
                          str(round(mIoUs[ind_class] * 100, 2)))
            print('===> mIoU (IDD): ' + str(round(np.nanmean(mIoUs) * 100, 2)))
            print('=' * 50)
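
The four dataset branches above repeat one evaluation loop verbatim; a sketch of how it could be factored out without changing behavior (the helper name evaluate_split is hypothetical):

def evaluate_split(model, loader, interp, num_classes, name_classes,
                   device, tag, per_class=False):
    # Shared mIoU evaluation used by the GTA5/SYNTHIA/Cityscapes/IDD branches.
    hist = np.zeros((num_classes, num_classes))
    for images_val, labels, _ in loader:
        images_val, labels = images_val.to(device), labels.to(device)
        _, pred = model(images_val)
        _, pred = interp(pred).max(dim=1)
        hist += fast_hist(labels.cpu().numpy().flatten(),
                          pred.cpu().numpy().flatten(), num_classes)
    mIoUs = per_class_iu(hist)
    if per_class:
        for ind_class in range(num_classes):
            print('==>' + name_classes[ind_class] + ':\t' +
                  str(round(mIoUs[ind_class] * 100, 2)))
    print('===> mIoU ({}): {}'.format(tag, round(np.nanmean(mIoUs) * 100, 2)))
    print('=' * 50)

Each branch then reduces to building its DataLoader and calling, e.g., evaluate_split(model, gta5_loader, interp, args.num_classes, name_classes, device, 'GTA5', args.mIoUs_per_class).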
def main():
    """Create the model and start the training."""
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        # pdb.set_trace()
        args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    model_D = FCDiscriminator(num_classes=2 * args.num_classes).to(device)

    #### restore model_D and model
    if RESTART:
        # pdb.set_trace()
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D parameters
        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D)
        model_D.load_state_dict(saved_state_dict)

    #### model_D is randomly initialized; model is an ImageNet-pretrained ResNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    """
    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()
    """

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        # pdb.set_trace()
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        adv_loss_value = 0
        d_loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate(optimizer_D, i_iter)
        """
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)
        """

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False
            """
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False
            """

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # pdb.set_trace()
            # images.size() == [1, 3, 720, 1280]
            pred1, pred2 = model(images)
            # pred1, pred2 size == [1, 19, 91, 161]
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            # size (1, 19, 720, 1280)
            # pdb.set_trace()

            # feature = nn.Softmax(dim=1)(pred1)
            # softmax_out = nn.Softmax(dim=1)(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            # pdb.set_trace()
            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size
            # pdb.set_trace()
            # train with target

            _, batch = targetloader_iter.__next__()

            for params in model_D.parameters():
                params.requires_grad_(requires_grad=False)

            images, _, _ = batch
            images = images.to(device)
            # pdb.set_trace()
            # images.size() == [1, 3, 720, 1280]
            pred_target1, pred_target2 = model(images)

            # pred_target1, 2 == [1, 19, 91, 161]
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            # pred_target1, 2 == [1, 19, 720, 1280]
            # pdb.set_trace()

            # feature_target = nn.Softmax(dim=1)(pred_target1)
            # softmax_out_target = nn.Softmax(dim=1)(pred_target2)

            # features = torch.cat((pred1, pred_target1), dim=0)
            # outputs = torch.cat((pred2, pred_target2), dim=0)
            # features.size() == [2, 19, 720, 1280]
            # softmax_out.size() == [2, 19, 720, 1280]
            # pdb.set_trace()
            # transfer_loss = CDAN([features, softmax_out], model_D, None, None, random_layer=None)
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')
            dc_source = torch.FloatTensor(
                D_out_target.size()).fill_(0).to(device)
            # pdb.set_trace()
            adv_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_source)
            adv_loss = adv_loss / args.iter_size
            adv_loss = args.lambda_adv * adv_loss
            # pdb.set_trace()
            # classifier_loss = nn.BCEWithLogitsLoss()(pred2,
            #        torch.FloatTensor(pred2.data.size()).fill_(source_label).cuda())
            # pdb.set_trace()
            adv_loss.backward()
            adv_loss_value += adv_loss.item()
            # optimizer_D.step()
            #TODO: normalize loss?

            for params in model_D.parameters():
                params.requires_grad_(requires_grad=True)

            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out = CDAN([F.softmax(pred1, dim=1), F.softmax(pred2, dim=1)],
                         model_D,
                         cdan_implement='concat')

            dc_source = torch.FloatTensor(D_out.size()).fill_(0).to(device)
            # d_loss = CDAN(D_out, dc_source, None, None, random_layer=None)
            d_loss = nn.BCEWithLogitsLoss()(D_out, dc_source)
            d_loss = d_loss / args.iter_size
            # pdb.set_trace()
            d_loss.backward()
            d_loss_value += d_loss.item()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')

            dc_target = torch.FloatTensor(
                D_out_target.size()).fill_(1).to(device)
            d_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_target)
            d_loss = d_loss / args.iter_size
            # pdb.set_trace()
            d_loss.backward()
            d_loss_value += d_loss.item()


        optimizer.step()
        optimizer_D.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'generator_loss': adv_loss_value,
            'discriminator_loss': d_loss_value,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)
        # pdb.set_trace()
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    adv_loss_value, d_loss_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

            # check_original_discriminator(args, pred_target1, pred_target2, i_iter)
            save_path = args.snapshot_dir + '/eval_{}'.format(i_iter)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            # evaluate(args, save_path, args.snapshot_dir, i_iter)

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            # pdb.set_trace()
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

            ###### also record latest saved iteration #######

    writer.close()
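
CDAN(...) here is defined elsewhere in the repo; since model_D is built with num_classes=2 * args.num_classes, the 'concat' variant plausibly stacks the two softmax maps channel-wise before scoring them. A hedged sketch of that reading:

def CDAN(softmax_outputs, discriminator, cdan_implement='concat'):
    # Sketch under the 'concat' assumption: fuse two [N, C, H, W]
    # probability maps into [N, 2C, H, W] and let the discriminator score it.
    assert cdan_implement == 'concat'
    fused = torch.cat(softmax_outputs, dim=1)
    return discriminator(fused)
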
def main():
    """Create the model and start the evaluation process."""

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict, strict=False)

    print(model)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(BDDDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(960, 540),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False,
                                            set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)  # 960 540

    interp = nn.Upsample(size=(720, 1280), mode='bilinear', align_corners=True)

    if args.save_confidence:
        select = open('list.txt', 'w')
        c_list = []

    for index, batch in enumerate(testloader):
        if index % 10 == 0:
            print('%d processd' % index)
        image, _, name = batch
        image = image.to(device)

        output = model(image)

        if args.save_confidence:
            confidence = get_confidence(output)
            confidence = confidence.cpu().item()
            c_list.append([confidence, name])
            name = name[0].split('/')[-1]
            save_path = '%s/%s_c.txt' % (args.save, name.split('.')[0])
            record = open(save_path, 'w')
            record.write('%.5f' % confidence)
            record.close()
        else:
            name = name[0].split('/')[-1]

        output = interp(output).cpu().data[0].numpy()

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        output.save('%s/%s' % (args.save, name[:-4] + '.png'))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))

    def takeFirst(elem):
        return elem[0]

    if args.save_confidence:
        c_list.sort(key=takeFirst, reverse=True)
        length = len(c_list)
        for i in range(length // 3):
            print(c_list[i][0])
            print(c_list[i][1])
            select.write(c_list[i][1][0])
            select.write('\n')
        select.close()

    print(args.save)
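
get_confidence is not shown in this excerpt; one common choice that matches its use here (a scalar per image, higher meaning more confident) is the mean max-softmax probability. A sketch, assuming that definition:

import torch.nn.functional as F

def get_confidence(output):
    # Spatial mean of the per-pixel maximum softmax probability.
    prob = F.softmax(output, dim=1)   # [N, C, H, W]
    max_prob, _ = prob.max(dim=1)     # [N, H, W]
    return max_prob.mean()
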
Example #11
0
def main():
    """Create the model and start the training."""
    setup_seed(666)
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    model = DeeplabMulti(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()
    for i in saved_state_dict:
        i_parts = i.split('.')
        if not args.num_classes == 19 or not i_parts[1] == 'layer5':
            if i_parts[1] == 'layer4' and i_parts[2] == '2':
                # remap the final layer4 block onto this model's layer5.0
                i_parts[1] = 'layer5'
                i_parts[2] = '0'
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            else:
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
    model.load_state_dict(new_params)
    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([
        FCDiscriminator(num_classes=num_class_list[i]).train().to(device)
        if i < 1 else
        OutspaceDiscriminator(num_classes=num_class_list[i]).train().to(device)
        for i in range(2)
    ])
    
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)


    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.MSELoss()  # least-squares (LS-GAN) objective in place of BCE
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feat_source, pred_source = model(images, model_D, 'source')
        pred_source = interp(pred_source)

        loss_seg = seg_loss(pred_source, labels)
        loss_seg.backward()

        # train with target
        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)

        feat_target, pred_target = model(images, model_D, 'target')
        pred_target = interp_target(pred_target)

        loss_adv = 0
        D_out = model_D[0](feat_target)
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        D_out = model_D[1](F.softmax(pred_target, dim=1))
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        loss_adv = loss_adv*0.01
        loss_adv.backward()
        
        optimizer.step()

        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source
        loss_D_source = 0
        D_out_source = model_D[0](feat_source.detach())
        loss_D_source += bce_loss(D_out_source, torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        D_out_source = model_D[1](F.softmax(pred_source.detach(),dim=1))
        loss_D_source += bce_loss(D_out_source, torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        loss_D_source.backward()

        # train with target
        loss_D_target = 0
        D_out_target = model_D[0](feat_target.detach())
        loss_D_target += bce_loss(D_out_target, torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        D_out_target = model_D[1](F.softmax(pred_target.detach(),dim=1))
        loss_D_target += bce_loss(D_out_target, torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        loss_D_target.backward()
        
        optimizer_D.step()

        if i_iter % 10 == 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D_s = {4:.3f}, loss_D_t = {5:.3f}'.format(
            i_iter, args.num_steps, loss_seg.item(), loss_adv.item(), loss_D_source.item(), loss_D_target.item()))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
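
adjust_learning_rate and adjust_learning_rate_D are likewise defined outside this excerpt; in these DeepLab adaptation scripts they are conventionally the polynomial ("poly") schedule below (sketch; args.power is typically 0.9, and args is module-level in these scripts):

def lr_poly(base_lr, i_iter, max_iter, power):
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)

def adjust_learning_rate(optimizer, i_iter):
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        # the freshly initialized classifier head usually runs at 10x lr
        optimizer.param_groups[1]['lr'] = lr * 10

def adjust_learning_rate_D(optimizer, i_iter):
    optimizer.param_groups[0]['lr'] = lr_poly(
        args.learning_rate_D, i_iter, args.num_steps, args.power)
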
Example #12
0
def main():
    """Create the model and start the training."""

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        #model.load_state_dict(saved_state_dict)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #     model_D1.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D1.pth'))
    #     model_D2.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D2.pth'))

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # Load VGG
    #vgg19 = torchvision.models.vgg19(pretrained=True)
    #vgg19.to(device)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(0, args.num_steps):

        loss_seg_value = 0
        #         loss_seg_local_value = 0
        loss_adv_target_value = 0
        #         loss_adv_local_value = 0
        loss_D_value = 0
        #         loss_D_local_value = 0
        loss_local_match_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred_s1, pred_s2, _, _ = model(images)
            #f_s2 = normalize(f_s2)
            pred_s1, pred_s2 = interp(pred_s1), interp(pred_s2)

            loss_seg = args.lambda_seg * seg_loss(pred_s1, labels) + seg_loss(
                pred_s2, labels)
            del labels

            # proper normalization
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_t1, pred_t2, _, _ = model(images)
            #f_t2 = normalize(f_t2)
            pred_t1, pred_t2 = interp_target(pred_t1), interp_target(pred_t2)
            del images

            D_out_1 = model_D1(F.softmax(pred_t1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_t2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out_1,
                torch.FloatTensor(
                    D_out_1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = bce_loss(
                D_out_2,
                torch.FloatTensor(
                    D_out_2.data.size()).fill_(source_label).to(device))
            loss_adv_target_value += (
                args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target *
                loss_adv_target2).item() / args.iter_size

            loss = loss_seg + args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target * loss_adv_target2

            del D_out_1, D_out_2

            #             #< Local patch part>#
            #             corres_id2 = get_correspondance(f_s2, f_t2, pred_s2, pred_t2)

            #             #loss_local1 = local_feature_loss(corres_id1, f_s1, f_t1, model, seg_loss)
            #             loss_local2 = local_feature_loss(corres_id2, labels, f_t2, model, seg_loss)
            #             loss_local = args.lambda_match_target2 * loss_local2 #+args.lambda_match_target1 * loss_local1
            #             loss += loss_local
            #             if corres_id2.nelement() > 0:
            #                 loss_local_match_value += loss_local.item()/ args.iter_size

            loss /= args.iter_size
            loss.backward()

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_s1, dim=1)), model_D2(
                F.softmax(pred_s2, dim=1))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

            # train with target
            pred_t1, pred_t2 = pred_t1.detach(), pred_t2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_t1, dim=1)), model_D2(
                F.softmax(pred_t2, dim=1))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 1000 == 0:
            val_dir = '../dataset/Cityscapes/leftImg8bit_trainvaltest/'
            val_list = './dataset/cityscapes_list/val.txt'
            save_dir = './results/tmp'
            gt_dir = '../dataset/Cityscapes/gtFine_trainvaltest/gtFine/val'
            evaluate_cityscapes.test_model(model, device, val_dir, val_list,
                                           save_dir)
            mIoU = compute_iou.mIoUforTest(gt_dir, save_dir)

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                #'loss_seg_local': loss_seg_local_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_local_match': loss_local_match_value,
                'loss_D': loss_D_value,
                'mIoU': mIoU
                #'loss_D_local': loss_D_local_value
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f}, loss_local_match = {5:.3f}, mIoU = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value,
                    loss_local_match_value, mIoU))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
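
The recurring torch.FloatTensor(D_out.data.size()).fill_(label).to(device) pattern builds a constant target map for the GAN loss; an equivalent helper (hypothetical name) makes the intent explicit:

def adv_target(d_out, label_value):
    # Constant map shaped like the discriminator output:
    # source_label (0) or target_label (1).
    return torch.full_like(d_out, float(label_value))

With it, the generator step reads bce_loss(D_out_1, adv_target(D_out_1, source_label)) and the discriminator steps swap in target_label as appropriate.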
def main():
    """Create the model and start the evaluation process."""

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    nyu_nyu_dict = {11:255, 13:255, 15:255, 17:255, 19:255, 20:255, 21: 255, 23: 255, 
            24:255, 25:255, 26:255, 27:255, 28:255, 29:255, 31:255, 32:255, 33:255}
    nyu_nyu_map = lambda x: nyu_nyu_dict.get(x+1,x)
    nyu_nyu_map = np.vectorize(nyu_nyu_map)
    args.nyu_nyu_map = nyu_nyu_map
    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http' :
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    ### for running different versions of pytorch
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    model_dict.update(saved_state_dict)
    ###
    model.load_state_dict(model_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    metrics = StreamSegMetrics(args.num_classes)
    metrics_remap = StreamSegMetrics(args.num_classes)
    ignore_label = 255
    value_scale = 255 
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
    val_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.Crop([args.height + 1, args.width + 1], crop_type='center',
                        padding=IMG_MEAN, ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    val_dst = NYU(root=args.data_dir, opt=args,
                  split='val', transform=val_transform,
                  imWidth=args.width, imHeight=args.height, phase="TEST",
                  randomize=False)
    print("Dset Length {}".format(len(val_dst)))
    testloader = data.DataLoader(val_dst,
                                    batch_size=1, shuffle=False, pin_memory=True)

    interp = nn.Upsample(size=(args.height+1, args.width+1), mode='bilinear', align_corners=True)
    metrics.reset()
    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % index)
        image, targets, name = batch
        image = image.to(device)
        print(index)
        if args.model == 'DeeplabMulti':
            output1, output2 = model(image)
            output = interp(output2).cpu().data[0].numpy()
        elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
            output = model(image)
            output = interp(output).cpu().data[0].numpy()
        targets = targets.cpu().numpy()
        output = output.transpose(1,2,0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        preds = output[None,:,:]
        #input_ = image.cpu().numpy()[0].transpose(1,2,0) + np.array(IMG_MEAN)
        metrics.update(targets, preds)
        targets = args.nyu_nyu_map(targets)
        preds = args.nyu_nyu_map(preds)
        metrics_remap.update(targets,preds)
        #input_ = Image.fromarray(input_.astype(np.uint8))
        #output_col = colorize_mask(output)
        #output = Image.fromarray(output)
        
        #name = name[0].split('/')[-1]
        #input_.save('%s/%s' % (args.save, name))
        #output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
    print(metrics.get_results())
    print(metrics_remap.get_results())
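
np.vectorize, as used for the label remapping above, runs a Python-level call per element; for integer label maps an equivalent lookup table is the usual faster form. A sketch (the +1 offset inside nyu_nyu_map would be folded in when filling the table):

import numpy as np

def remap_with_lut(arr, mapping, num_labels=256):
    # Identity mapping by default; entries of `mapping` override it.
    lut = np.arange(num_labels, dtype=np.int64)
    for src, dst in mapping.items():
        lut[src] = dst
    return lut[arr]
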
def main():
    """Create the model and start the training."""
    or_nyu_dict = {
        0: 255,
        1: 16,
        2: 40,
        3: 39,
        4: 7,
        5: 14,
        6: 39,
        7: 12,
        8: 38,
        9: 40,
        10: 10,
        11: 6,
        12: 40,
        13: 39,
        14: 39,
        15: 40,
        16: 18,
        17: 40,
        18: 4,
        19: 40,
        20: 40,
        21: 5,
        22: 40,
        23: 40,
        24: 30,
        25: 36,
        26: 38,
        27: 40,
        28: 3,
        29: 40,
        30: 40,
        31: 9,
        32: 38,
        33: 40,
        34: 40,
        35: 40,
        36: 34,
        37: 37,
        38: 40,
        39: 40,
        40: 39,
        41: 8,
        42: 3,
        43: 1,
        44: 2,
        45: 22
    }
    or_nyu_map = lambda x: or_nyu_dict.get(x, x) - 1
    or_nyu_map = np.vectorize(or_nyu_map)

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    args.or_nyu_map = or_nyu_map
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from == "":
            saved_state_dict = None
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        if saved_state_dict is not None:  # restore_from == "" means train from scratch
            for i in saved_state_dict:
                # e.g. Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if not args.num_classes == 40 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True
    if args.mode != "baseline" and args.mode != "baseline_tar":
        # init D
        model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

        model_D1.train()
        model_D1.to(device)

        model_D2.train()
        model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    scale_min = 0.5
    scale_max = 2.0
    rotate_min = -10
    rotate_max = 10
    ignore_label = 255
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    args.width = w
    args.height = h
    train_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.RandScale([scale_min, scale_max]),
        transforms.RandRotate([rotate_min, rotate_max],
                              padding=IMG_MEAN_RGB,
                              ignore_label=ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='rand',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        #et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
        #et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        #et.ExtRandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])

    val_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    if args.mode != "baseline_tar":
        src_train_dst = OpenRoomsSegmentation(root=args.data_dir,
                                              opt=args,
                                              split='train',
                                              transform=train_transform,
                                              imWidth=args.width,
                                              imHeight=args.height,
                                              remap_labels=args.or_nyu_map)
    else:
        src_train_dst = NYU_Labelled(root=args.data_dir_target,
                                     opt=args,
                                     split='train',
                                     transform=train_transform,
                                     imWidth=args.width,
                                     imHeight=args.height,
                                     phase="TRAIN",
                                     randomize=True)
    tar_train_dst = NYU(root=args.data_dir_target,
                        opt=args,
                        split='train',
                        transform=train_transform,
                        imWidth=args.width,
                        imHeight=args.height,
                        phase="TRAIN",
                        randomize=True,
                        mode=args.mode)
    tar_val_dst = NYU(root=args.data_dir,
                      opt=args,
                      split='val',
                      transform=val_transform,
                      imWidth=args.width,
                      imHeight=args.height,
                      phase="TRAIN",
                      randomize=False)
    trainloader = data.DataLoader(src_train_dst,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(tar_train_dst,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    if args.mode != "baseline" and args.mode != "baseline_tar":
        optimizer_D1 = optim.Adam(model_D1.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D1.zero_grad()

        optimizer_D2 = optim.Adam(model_D2.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1] + 1, input_size[0] + 1),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1] + 1,
                                      input_size_target[0] + 1),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_seg_value1_tar = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_seg_value2_tar = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            adjust_learning_rate_D(optimizer_D1, i_iter)
            adjust_learning_rate_D(optimizer_D2, i_iter)
        sample_src = None
        sample_tar = None
        sample_res_src = None
        sample_res_tar = None
        sample_gt_src = None
        sample_gt_tar = None
        for sub_i in range(args.iter_size):

            # train G
            if args.mode != "baseline" and args.mode != "baseline_tar":
                # don't accumulate grads in D
                for param in model_D1.parameters():
                    param.requires_grad = False

                for param in model_D2.parameters():
                    param.requires_grad = False

            # train with source
            try:
                _, batch = trainloader_iter.__next__()
            except StopIteration:
                # restart the epoch when the source loader is exhausted
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()
            images, labels, _ = batch
            sample_src = images.clone()
            sample_gt_src = labels.clone()

            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            sample_pred_src = pred2.detach().cpu()

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            try:
                _, batch = targetloader_iter.__next__()
            except StopIteration:
                # restart the epoch when the target loader is exhausted
                targetloader_iter = enumerate(targetloader)
                _, batch = targetloader_iter.__next__()
            images, tar_labels, _, labelled = batch
            n_labelled = labelled.sum().detach().item()
            batch_size = images.shape[0]
            sample_tar = images.clone()
            sample_gt_tar = tar_labels.clone()
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            #print("N_labelled {}".format(n_labelled))
            if args.mode == "sda" and n_labelled != 0:
                labelled = labelled.to(device) == 1
                tar_labels = tar_labels.to(device)
                loss_seg1_tar = seg_loss(pred_target1[labelled],
                                         tar_labels[labelled])
                loss_seg2_tar = seg_loss(pred_target2[labelled],
                                         tar_labels[labelled])
                loss_tar_labelled = loss_seg2_tar + args.lambda_seg * loss_seg1_tar
                loss_tar_labelled = loss_tar_labelled / args.iter_size
                loss_seg_value1_tar += loss_seg1_tar.item() / args.iter_size
                loss_seg_value2_tar += loss_seg2_tar.item() / args.iter_size
            else:
                # no labelled target frames in this batch: use a zero placeholder loss
                loss_tar_labelled = torch.zeros(
                    1, requires_grad=True).float().to(device)
            sample_pred_tar = pred_target2.detach().cpu()
            if args.mode != "baseline" and args.mode != "baseline_tar":
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_adv_target1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_adv_target2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size + loss_tar_labelled
                loss.backward()
                loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
                loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size
                # train D

                # bring back requires_grad
                for param in model_D1.parameters():
                    param.requires_grad = True

                for param in model_D2.parameters():
                    param.requires_grad = True

                # train with source
                pred1 = pred1.detach()
                pred2 = pred2.detach()

                D_out1 = model_D1(F.softmax(pred1, dim=1))
                D_out2 = model_D2(F.softmax(pred2, dim=1))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()

                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.step()
            optimizer_D2.step()
        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
                'loss_seg1_tar': loss_seg_value1_tar,
                'loss_seg2_tar': loss_seg_value2_tar,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)
            if i_iter % 1000 == 0:
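                # undo the BGR channel order and mean subtraction before logging images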
                img = sample_src.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                img = img.type(torch.uint8)
                writer.add_images("Src/Images", img, i_iter)
                label = tar_train_dst.decode_target(sample_gt_src).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Labels", label, i_iter)
                preds = sample_pred_src.permute(0, 2, 3, 1).cpu().numpy()
                preds = np.asarray(np.argmax(preds, axis=3), dtype=np.uint8)
                preds = tar_train_dst.decode_target(preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Preds", preds, i_iter)

                tar_img = sample_tar.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                tar_img = tar_img.type(torch.uint8)
                writer.add_images("Tar/Images", tar_img, i_iter)
                tar_label = tar_train_dst.decode_target(
                    sample_gt_tar).transpose(0, 3, 1, 2)
                writer.add_images("Tar/Labels", tar_label, i_iter)
                tar_preds = sample_pred_tar.permute(0, 2, 3, 1).cpu().numpy()
                tar_preds = np.asarray(np.argmax(tar_preds, axis=3),
                                       dtype=np.uint8)
                tar_preds = tar_train_dst.decode_target(tar_preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Tar/Preds", tar_preds, i_iter)
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_seg1_tar={8:.3f} loss_seg2_tar={9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_seg_value1_tar,
                    loss_seg_value2_tar))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'OR_' + str(args.num_steps_stop) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'OR_' + str(i_iter) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
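Note: the training loops above call adjust_learning_rate and adjust_learning_rate_D without defining them. In AdaptSegNet-style scripts these are typically a polynomial decay of the base rate; the sketch below is a reconstruction under that assumption (args.power and the 10x rate on the second parameter group are conventions of such repos, not taken from the snippets here).

def lr_poly(base_lr, i_iter, max_iter, power):
    # polynomial decay: base_lr * (1 - iter / max_iter) ** power
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)


def adjust_learning_rate(optimizer, i_iter):
    # sketch only: poly schedule for the segmentation network
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        # common DeepLab convention: 10x learning rate for the classifier head
        optimizer.param_groups[1]['lr'] = lr * 10


def adjust_learning_rate_D(optimizer, i_iter):
    # sketch only: poly schedule for the discriminators
    lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr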
Beispiel #15
0
def main():
    """Create the model and start the training."""

    # device setup
    device = torch.device("cuda" if not args.cpu else "cpu")

    # resize both source and target inputs to 1280 x 720
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # create the model
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':                         # download pretrained weights
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:                                                       # load a local .pth checkpoint
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # create the discriminators
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D2.train()


    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)


    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _, domainess = batch
            # adversarial loss weight derived from domainess; no gradient flows through it
            adw = torch.sqrt(1 - domainess).float()
            adw.requires_grad = False
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))


            loss_adv_target1 = adw*bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = adw*bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_label).to(device))

            loss_D2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(target_label).to(device))

            loss_D2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(target_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
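The loaders in these loops are advanced with iterator.__next__() and either a try/except restart or a max_iters large enough that they never run out. A hypothetical helper (not part of any of these repos) that makes the restart pattern explicit:

def infinite_batches(loader):
    # yield batches forever, re-creating the iterator when the loader is exhausted
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)  # start a new epoch

# usage sketch:
#   source_batches = infinite_batches(trainloader)
#   images, labels, _ = next(source_batches)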
def main():
    setup_seed(666)
    device = torch.device("cuda")
    save_path = args.save
    save_pseudo_label_path = osp.join(
        save_path,
        'pseudo_label')  # in 'save_path'. Save labelIDs, not trainIDs.
    save_stats_path = osp.join(save_path, 'stats')  # in 'save_path'
    save_lst_path = osp.join(save_path, 'list')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if not os.path.exists(save_pseudo_label_path):
        os.makedirs(save_pseudo_label_path)
    if not os.path.exists(save_stats_path):
        os.makedirs(save_stats_path)
    if not os.path.exists(save_lst_path):
        os.makedirs(save_lst_path)

    cudnn.enabled = True
    cudnn.benchmark = True

    logger = util.set_logger(args.save, args.log_file, args.debug)
    logger.info('start with arguments %s', args)

    model = DeeplabMulti(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)
    model.train()
    model.to(device)

    # init D
    # two discriminators: one on the 2048-channel feature map, one on the 19-class output space
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([
        FCDiscriminator(num_classes=num_class_list[0]).train().to(device),
        OutspaceDiscriminator(num_classes=num_class_list[1]).train().to(device),
    ])
    saved_state_dict_D = torch.load(args.restore_from_D)
    model_D.load_state_dict(saved_state_dict_D)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    image_src_list, _, src_num = parse_split_list(args.data_src_list)
    image_tgt_list, image_name_tgt_list, tgt_num = parse_split_list(
        args.data_tgt_train_list)
    # portions
    tgt_portion = args.init_tgt_port

    # training crop size
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    bce_loss1 = torch.nn.MSELoss()  # LS-GAN objectives (MSE despite the bce_ prefix)
    bce_loss2 = torch.nn.MSELoss(reduction='none')  # per-element (unreduced) loss
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    round_idx = 3
    save_round_eval_path = osp.join(args.save, str(round_idx))
    save_pseudo_label_color_path = osp.join(save_round_eval_path,
                                            'pseudo_label_color')
    if not os.path.exists(save_round_eval_path):
        os.makedirs(save_round_eval_path)
    if not os.path.exists(save_pseudo_label_color_path):
        os.makedirs(save_pseudo_label_color_path)
    ########## pseudo-label generation
    # evaluation & save confidence vectors
    test(model, model_D, device, save_round_eval_path, round_idx, 500, args,
         logger)
    conf_dict, pred_cls_num, save_prob_path, save_pred_path = val(
        model, model_D, device, save_round_eval_path, round_idx, tgt_num, args,
        logger)
    # class-balanced thresholds
    cls_thresh = kc_parameters(conf_dict, pred_cls_num, tgt_portion, round_idx,
                               save_stats_path, args, logger)
    # pseudo-label maps generation
    label_selection(cls_thresh, tgt_num, image_name_tgt_list, round_idx,
                    save_prob_path, save_pred_path, save_pseudo_label_path,
                    save_pseudo_label_color_path, save_round_eval_path, args,
                    logger)
    src_train_lst, tgt_train_lst, src_num_sel = savelst_SrcTgt(
        image_tgt_list, image_name_tgt_list, image_src_list, save_lst_path,
        save_pseudo_label_path, src_num, tgt_num, args)
    ########### model retraining
    # dataset
    srctrainset = SrcSTDataSet(args.data_src_dir,
                               src_train_lst,
                               max_iters=args.num_steps * args.batch_size,
                               crop_size=input_size,
                               scale=False,
                               mirror=False,
                               mean=IMG_MEAN)
    tgttrainset = TgtSTDataSet(args.data_tgt_dir,
                               tgt_train_lst,
                               pseudo_root=save_pseudo_label_path,
                               max_iters=args.num_steps * args.batch_size,
                               crop_size=input_size_target,
                               scale=False,
                               mirror=False,
                               mean=IMG_MEAN,
                               set='train')
    trainloader = torch.utils.data.DataLoader(srctrainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=8,
                                              pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = torch.utils.data.DataLoader(tgttrainset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=8,
                                               pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    logger.info(
        '###### Start model retraining dataset in round {}! ######'.format(
            round_idx))

    start = timeit.default_timer()
    # start training
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        lamb = 1  # weight on the target pseudo-label segmentation loss
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feat_source, pred_source = model(images, model_D, 'source')
        pred_source = interp(pred_source)

        loss_seg = seg_loss(pred_source, labels)
        loss_seg.backward()

        # train with target
        _, batch = targetloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        feat_target, pred_target = model(images, model_D, 'target')
        pred_target = interp_target(pred_target)
        # atten_target = F.interpolate(atten_target, size=(16, 32), mode='bilinear', align_corners=True)

        loss_seg_tgt = seg_loss(pred_target, labels) * lamb

        D_out1 = model_D[0](feat_target)
        loss_adv1 = bce_loss1(
            D_out1,
            torch.FloatTensor(
                D_out1.data.size()).fill_(source_label).to(device))
        D_out2 = model_D[1](F.softmax(pred_target, dim=1))
        loss_adv2 = bce_loss2(
            D_out2,
            torch.FloatTensor(
                D_out2.data.size()).fill_(source_label).to(device))
        loss_adv = loss_adv1 * 0.01 + loss_adv2.mean() * 0.01
        loss = loss_seg_tgt + loss_adv
        loss.backward()

        optimizer.step()

        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source
        D_out_source1 = model_D[0](feat_source.detach())
        loss_D_source1 = bce_loss1(
            D_out_source1,
            torch.FloatTensor(
                D_out_source1.data.size()).fill_(source_label).to(device))
        D_out_source2 = model_D[1](F.softmax(pred_source.detach(), dim=1))
        loss_D_source2 = bce_loss1(
            D_out_source2,
            torch.FloatTensor(
                D_out_source2.data.size()).fill_(source_label).to(device))
        loss_D_source = loss_D_source1 + loss_D_source2
        loss_D_source.backward()

        # train with target
        D_out_target1 = model_D[0](feat_target.detach())
        loss_D_target1 = bce_loss1(
            D_out_target1,
            torch.FloatTensor(
                D_out_target1.data.size()).fill_(target_label).to(device))
        D_out_target2 = model_D[1](F.softmax(pred_target.detach(), dim=1))
        weight_target = bce_loss2(
            D_out_target2,
            torch.FloatTensor(
                D_out_target2.data.size()).fill_(target_label).to(device))
        loss_D_target2 = weight_target.mean()
        loss_D_target = loss_D_target1 + loss_D_target2
        loss_D_target.backward()

        optimizer_D.step()

        if i_iter % 10 == 0:
            print(
                'iter={0:8d}/{1:8d}, seg={2:.3f} seg_tgt={3:.3f} adv={4:.3f} adv1={5:.3f} adv2={6:.3f} src1={7:.3f} src2={8:.3f} tgt1={9:.3f} tgt2={10:.3f} D_src={11:.3f} D_tgt={12:.3f}'
                .format(i_iter, args.num_steps, loss_seg.item(),
                        loss_seg_tgt.item(), loss_adv.item(), loss_adv1.item(),
                        loss_adv2.mean().item(), loss_D_source1.item(),
                        loss_D_source2.item(), loss_D_target1.item(),
                        loss_D_target2.item(), loss_D_source.item(),
                        loss_D_target.item()))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            test(model, model_D, device, save_round_eval_path, round_idx, 500,
                 args, logger)
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    logger.info(
        '###### Finish model retraining dataset in round {}! Time cost: {:.2f} seconds. ######'
        .format(round_idx, end - start))
    # test self-trained model in target domain test set
    test(model, model_D, device, save_round_eval_path, round_idx, 500, args,
         logger)
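The pseudo-label step above delegates to kc_parameters and label_selection, which are not shown. The sketch below illustrates the class-balanced thresholding idea (keep, per class, only the most confident tgt_portion fraction of predictions); the function name and data layout are illustrative assumptions, not the repo's API.

import numpy as np

def class_balanced_thresholds(conf_dict, tgt_portion):
    # conf_dict: class id -> list of confidence scores of pixels predicted as that class
    cls_thresh = {}
    for cls, confs in conf_dict.items():
        if len(confs) == 0:
            cls_thresh[cls] = 1.0  # no predictions for this class: select nothing
            continue
        confs = np.sort(np.asarray(confs))[::-1]   # sort high to low
        k = max(1, int(len(confs) * tgt_portion))  # keep the top fraction
        cls_thresh[cls] = confs[k - 1]             # per-class confidence cutoff
    return cls_thresh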
def main():
    """Create the model and start the evaluation process."""
    seed = 1338
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    # if args.restore_from[:4] == 'http' :
    #     saved_state_dict = model_zoo.load_url(args.restore_from)
    # else:
    #     saved_state_dict = torch.load(args.restore_from)
    for files in range(int(args.num_steps_stop / args.save_pred_every)):
        step = (files + 1) * args.save_pred_every
        print('Step: ', step)
        if SOURCE_ONLY:
            saved_state_dict = torch.load(
                './snapshots/source_only/GTA5_' + str(step) + '.pth')
        elif args.level == 'single-level':
            saved_state_dict = torch.load(
                './snapshots/single_level/GTA5_' + str(step) + '.pth')
        elif args.level == 'multi-level':
            saved_state_dict = torch.load(
                './snapshots/multi_level/GTA5_' + str(step) + '.pth')
        else:
            raise NotImplementedError(
                'level choice {} is not implemented'.format(args.level))
        ### keep only matching keys so checkpoints load across PyTorch versions
        model_dict = model.state_dict()
        saved_state_dict = {
            k: v
            for k, v in saved_state_dict.items() if k in model_dict
        }
        model_dict.update(saved_state_dict)
        ###
        model.load_state_dict(saved_state_dict)

        device = torch.device("cuda" if not args.cpu else "cpu")
        model = model.to(device)
        if args.multi_gpu:
            model = nn.DataParallel(model)

        model.eval()

        testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                       args.data_list,
                                                       crop_size=(1024, 512),
                                                       mean=IMG_MEAN,
                                                       scale=False,
                                                       mirror=False,
                                                       set=args.set),
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)

        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)

        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processed' % index)
            image, _, name = batch
            image = image.to(device)

            if args.model == 'DeeplabMulti':
                output1, output2 = model(image)
                output = interp(output2).cpu().data[0].numpy()
            elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                output = model(image)
                output = interp(output).cpu().data[0].numpy()

            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            if SOURCE_ONLY:
                subdir = 'source_only'
            elif args.level == 'single-level':
                subdir = 'single_level'
            elif args.level == 'multi-level':
                subdir = 'multi_level'
            else:
                raise NotImplementedError(
                    'level choice {} is not implemented'.format(args.level))
            step_dir = os.path.join(args.save, subdir, 'step' + str(step))
            if not os.path.exists(step_dir):
                os.makedirs(step_dir)
            output.save(os.path.join(step_dir, name))
            output_col.save(
                os.path.join(step_dir, name.split('.')[0] + '_color.png'))
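All of the evaluation snippets rely on a colorize_mask helper that is not shown. A minimal sketch, assuming the standard 19-class Cityscapes trainId palette that these repos conventionally use:

from PIL import Image

# Cityscapes trainId palette (road, sidewalk, building, ..., bicycle)
palette = [
    128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
    153, 153, 153, 250, 170, 30, 220, 220, 0, 107, 142, 35, 152, 251, 152,
    70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100,
    0, 80, 100, 0, 0, 230, 119, 11, 32,
]
palette += [0] * (256 * 3 - len(palette))  # pad to a full 256-entry palette


def colorize_mask(mask):
    # mask: 2-D uint8 array of trainIds; returns a palettized PIL image
    new_mask = Image.fromarray(mask.astype('uint8')).convert('P')
    new_mask.putpalette(palette)
    return new_mask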