コード例 #1
0
def main():
    """Create the model and start the training.

    Builds the segmentation network and three ``FCDiscriminator`` networks,
    then alternates, inside every iteration, between

    * a generator step: supervised segmentation loss on a source batch plus
      adversarial losses on a target batch (discriminators frozen), and
    * a discriminator step: each discriminator is trained to output
      ``source_label`` on detached source predictions and ``target_label``
      on detached target predictions.

    ``model_D_a`` receives the first output of the segmentation network
    unchanged; ``model_D1`` / ``model_D2`` receive the softmax of the two
    upsampled segmentation outputs.

    Configuration is read from names defined outside this function:
    ``args``, ``CONTINUE_FLAG``, ``D1_RESTORE_FROM``, ``D2_RESTORE_FROM``,
    ``continue_start_iter``, ``LAMBDA_ADV_TARGET_A`` and ``IMG_MEAN``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` (the only
            architecture this function knows how to build).
    """

    # The two numbers are consumed below as size=(input_size[1], input_size[0]),
    # i.e. the first one is the width and the second one the height.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    source_hw = (input_size[1], input_size[0])
    target_hw = (input_size_target[1], input_size_target[0])

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model != 'DeepLab':
        # Previously this fell through and failed later with a NameError on
        # `model`; fail early with a message that names the actual problem.
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    model = Res_Deeplab(num_classes=args.num_classes)
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # Keys look like "Scale.layer5.conv2d_list.3.weight": drop the first
        # component, and skip layer5 entries when training with 19 classes.
        key_parts = key.split('.')
        if not args.num_classes == 19 or not key_parts[1] == 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    # NOTE(review): `new_params` is built but never loaded (the
    # `model.load_state_dict(new_params)` call was commented out), so unless
    # CONTINUE_FLAG == 1 the network keeps its initial weights. Confirm this
    # is intended before re-enabling the call.

    if CONTINUE_FLAG == 1:
        model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(gpu)

    cudnn.benchmark = True

    # init D
    model_D_a = FCDiscriminator(num_classes=256)  # need to check
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    discriminators = (model_D_a, model_D1, model_D2)

    if CONTINUE_FLAG == 1:
        # Only D1 and D2 are restored; model_D_a always starts fresh.
        model_D1.load_state_dict(torch.load(D1_RESTORE_FROM))
        model_D2.load_state_dict(torch.load(D2_RESTORE_FROM))

    for net in discriminators:
        net.train()
        net.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        synthiaDataSet(args.data_dir, args.data_list, max_iters=max_iters,
                       crop_size=input_size,
                       scale=args.random_scale, mirror=args.random_mirror,
                       mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=max_iters,
                          crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D_a = optim.Adam(model_D_a.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D_a.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    d_optimizers = (optimizer_D_a, optimizer_D1, optimizer_D2)

    bce_loss = torch.nn.BCEWithLogitsLoss()

    def domain_loss(d_out, domain_label):
        """BCE-with-logits of `d_out` against a constant `domain_label` map."""
        # full_like allocates the label map with d_out's shape, dtype and
        # device directly, replacing Variable(FloatTensor(...).fill_()).cuda().
        return bce_loss(d_out, torch.full_like(d_out, float(domain_label)))

    def set_discriminators_trainable(flag):
        """Set requires_grad on every parameter of all three discriminators."""
        for net in discriminators:
            for param in net.parameters():
                param.requires_grad = flag

    # labels for adversarial training
    source_label = 0
    target_label = 1
    mIoUs = []

    # model_D_a is only used while (feature_adv_steps - i_iter) is positive.
    feature_adv_steps = 80000

    for i_iter in range(continue_start_iter + 1, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for opt_d in d_optimizers:
            opt_d.zero_grad()
            adjust_learning_rate_D(opt_d, i_iter)

        # The value is used purely as an on/off gate (it is never multiplied
        # into a loss) and depends only on i_iter, so compute it once here.
        use_feature_adv = (feature_adv_steps - i_iter) / feature_adv_steps > 0

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            set_discriminators_trainable(False)

            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = images.cuda(gpu)

            pred_a, pred1, pred2 = model(images)
            pred1 = nn.functional.interpolate(pred1, size=source_hw,
                                              mode='bilinear', align_corners=True)
            pred2 = nn.functional.interpolate(pred2, size=source_hw,
                                              mode='bilinear', align_corners=True)

            loss_seg1 = loss_calc(pred1, labels, gpu)
            loss_seg2 = loss_calc(pred2, labels, gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.cuda(gpu)

            if use_feature_adv:
                # NOTE(review): the target batch is forwarded a second time
                # just below; the two passes are kept separate to preserve the
                # original training behaviour.
                pred_target_a, _, _ = model(images)
                D_out_a = model_D_a(pred_target_a)
                loss_adv_target_a = domain_loss(D_out_a, source_label)
                loss_adv_target_a = LAMBDA_ADV_TARGET_A * loss_adv_target_a
                loss_adv_target_a = loss_adv_target_a / args.iter_size
                loss_adv_target_a.backward()

            _, pred_target1, pred_target2 = model(images)
            pred_target1 = nn.functional.interpolate(pred_target1, size=target_hw,
                                                     mode='bilinear', align_corners=True)
            pred_target2 = nn.functional.interpolate(pred_target2, size=target_hw,
                                                     mode='bilinear', align_corners=True)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is pushed to make target predictions look like
            # source ones, hence source_label here.
            loss_adv_target1 = domain_loss(D_out1, source_label)
            loss_adv_target2 = domain_loss(D_out2, source_label)

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            set_discriminators_trainable(True)

            # train with source (detached so no gradient reaches the generator)
            if use_feature_adv:
                pred_a = pred_a.detach()
                D_out_a = model_D_a(pred_a)
                loss_D_a = domain_loss(D_out_a, source_label)
                loss_D_a = loss_D_a / args.iter_size / 2
                loss_D_a.backward()

            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = domain_loss(D_out1, source_label)
            loss_D2 = domain_loss(D_out2, source_label)

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            if use_feature_adv:
                pred_target_a = pred_target_a.detach()
                D_out_a = model_D_a(pred_target_a)
                loss_D_a = domain_loss(D_out_a, target_label)
                loss_D_a = loss_D_a / args.iter_size / 2
                loss_D_a.backward()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = domain_loss(D_out1, target_label)
            loss_D2 = domain_loss(D_out2, target_label)

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        for opt_d in d_optimizers:
            opt_d.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                loss_adv_target_value1, loss_adv_target_value2,
                loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Final checkpoint is named after num_steps, then training stops.
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps) + '_D2.pth'))
            show_val(model.state_dict(), LAMBDA_ADV_TARGET_A, i_iter)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
            mIoU = show_val(model.state_dict(), LAMBDA_ADV_TARGET_A, i_iter)
            mIoUs.append(str(round(np.nanmean(mIoU) * 100, 2)))
            # Re-print the whole history so the latest log lines show the trend.
            for miou in mIoUs:
                print(miou)
コード例 #2
0
ファイル: SEAN_Synthia.py プロジェクト: xiaoyan-Lu/SEANet
def main():
    """Create the model and start the training.

    Trains a student ``SEANet`` on labelled source batches and on noisy
    copies of target batches, while a teacher ``SEANet`` (gradients
    disabled) is updated through ``WeightEMA`` after every student step.
    The consistency (``MSELoss``) term between student and teacher softmax
    outputs is restricted to pixels where the teacher's first output exceeds
    ``args.attention_threshold``.

    Side effects:
        * creates ``args.snapshot_dir`` if needed;
        * writes ``Synthia2Cityscapes_log.txt`` inside that directory;
        * saves the teacher weights whenever ``test_mIoU16`` returns a value
          above the best one seen so far;
        * after the last epoch, saves ``Synthia_TeacherNet.pth`` and
          ``Synthia_loss.npz`` inside that directory.
    """
    args = get_arguments()
    os.makedirs(args.snapshot_dir, exist_ok=True)
    # Join paths explicitly: plain string concatenation put the log next to
    # (not inside) the snapshot directory when it had no trailing separator.
    log_path = os.path.join(args.snapshot_dir, 'Synthia2Cityscapes_log.txt')

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    # Create network
    student_net = SEANet(num_classes=args.num_classes)
    teacher_net = SEANet(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)

    # Checkpoint entries are matched to model entries by position, not by
    # name; entries whose checkpoint key starts with 'f', 's' or 'u' are left
    # at their freshly initialised values.
    new_params = student_net.state_dict().copy()
    for src_key, dst_key in zip(saved_state_dict, new_params):
        if src_key[0] not in ('f', 's', 'u'):
            new_params[dst_key] = saved_state_dict[src_key]

    student_net.load_state_dict(new_params)
    teacher_net.load_state_dict(new_params)

    # The teacher is never trained by back-propagation.
    for param in teacher_net.parameters():
        param.requires_grad = False

    teacher_net = teacher_net.cuda()
    student_net = student_net.cuda()

    src_loader = data.DataLoader(synthiaDataSet(args.data_dir_source,
                                                args.data_list_source,
                                                crop_size=input_size,
                                                scale=False,
                                                mirror=False,
                                                mean=IMG_MEAN),
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_loader = data.DataLoader(cityscapes16DataSet(args.data_dir_target,
                                                     args.data_list_target,
                                                     max_iters=9400,
                                                     crop_size=input_size,
                                                     scale=False,
                                                     mirror=False,
                                                     mean=IMG_MEAN,
                                                     set='val'),
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    val_loader = data.DataLoader(cityscapes16DataSet(args.data_dir_target,
                                                     args.data_list_target,
                                                     max_iters=None,
                                                     crop_size=input_size,
                                                     scale=False,
                                                     mirror=False,
                                                     mean=IMG_MEAN,
                                                     set='val'),
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # zip() below stops at the shorter loader.
    num_batches = min(len(src_loader), len(tgt_loader))

    optimizer = optim.Adam(student_net.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    optimizer.zero_grad()

    teacher_optimizer = WeightEMA(
        list(teacher_net.parameters()),
        list(student_net.parameters()),
        alpha=args.teacher_alpha,
    )

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    n_class = args.num_classes
    num_steps = args.num_epoch * num_batches
    # Columns written below: 0 total loss, 1 segmentation loss,
    # 2 consistency loss, 3 mean IoU on the source batch (column 4 is unused).
    loss_hist = np.zeros((num_steps, 5))
    index_i = -1
    OA_hist = 0.2  # a model is only saved once it beats this score
    aug_loss = torch.nn.MSELoss()

    log_format = ('epoch %d/%d:  %d/%d time: %.2f miu = %.1f '
                  'cls_loss = %.3f st_loss = %.3f \n')

    # The context manager closes the log on normal exit, on the early
    # return below and on any exception.
    with open(log_path, 'w') as f:
        for epoch in range(args.num_epoch):
            if epoch == 6:
                # NOTE(review): this hard stop returns before the final
                # teacher checkpoint and loss history are written at the end
                # of the function. Kept as in the original; confirm intent.
                return
            for batch_index, (src_data,
                              tgt_data) in enumerate(zip(src_loader, tgt_loader)):
                index_i += 1

                tem_time = time.time()
                student_net.train()
                optimizer.zero_grad()

                # train with source
                images, src_label, _, _ = src_data
                images = images.cuda()
                src_label = src_label.cuda()
                _, src_output = student_net(images)
                src_output = interp(src_output)
                # Segmentation Loss
                cls_loss_value = loss_calc(src_output, src_label)
                _, predict_labels = torch.max(src_output, 1)
                lbl_pred = predict_labels.detach().cpu().numpy()
                lbl_true = src_label.detach().cpu().numpy()
                metrics_batch = []
                for lt, lp in zip(lbl_true, lbl_pred):
                    _, _, mean_iu, _ = label_accuracy_score(
                        lt, lp, n_class=args.num_classes)
                    metrics_batch.append(mean_iu)
                miu = np.mean(metrics_batch, axis=0)

                # train with target (labels in the batch are not used)
                images, _, _, _ = tgt_data
                images = images.cuda()
                # Independent Gaussian perturbations for teacher and student.
                tgt_t_input = images + torch.randn(
                    images.size()).cuda() * args.noise
                tgt_s_input = images + torch.randn(
                    images.size()).cuda() * args.noise

                _, tgt_s_output = student_net(tgt_s_input)
                t_confidence, tgt_t_output = teacher_net(tgt_t_input)

                t_confidence = t_confidence.squeeze()

                # self-ensembling Loss: move the class axis last so each row
                # of the flattened tensors is one pixel's class distribution.
                tgt_t_predicts = F.softmax(tgt_t_output,
                                           dim=1).transpose(1, 2).transpose(2, 3)
                tgt_s_predicts = F.softmax(tgt_s_output,
                                           dim=1).transpose(1, 2).transpose(2, 3)

                mask = t_confidence > args.attention_threshold
                mask = mask.view(-1)
                num_pixel = mask.shape[0]

                mask_rate = torch.sum(mask).float() / num_pixel

                tgt_s_predicts = tgt_s_predicts.contiguous().view(-1, n_class)
                tgt_s_predicts = tgt_s_predicts[mask]
                tgt_t_predicts = tgt_t_predicts.contiguous().view(-1, n_class)
                tgt_t_predicts = tgt_t_predicts[mask]
                aug_loss_value = aug_loss(tgt_s_predicts, tgt_t_predicts)
                aug_loss_value = args.st_weight * aug_loss_value

                # TOTAL LOSS
                if mask_rate == 0.:
                    # No pixel passed the threshold: drop the consistency
                    # term instead of using a loss over an empty selection.
                    aug_loss_value = torch.tensor(0.).cuda()

                total_loss = cls_loss_value + aug_loss_value

                total_loss.backward()
                loss_hist[index_i, 0] = total_loss.item()
                loss_hist[index_i, 1] = cls_loss_value.item()
                loss_hist[index_i, 2] = aug_loss_value.item()
                loss_hist[index_i, 3] = miu

                optimizer.step()
                teacher_optimizer.step()
                batch_time = time.time() - tem_time

                if (batch_index + 1) % 10 == 0:
                    # Averages over the last ten recorded steps.
                    recent = loss_hist[index_i - 9:index_i + 1]
                    message = log_format % (
                        epoch + 1, args.num_epoch, batch_index + 1,
                        num_batches, batch_time,
                        np.mean(recent[:, 3]) * 100,
                        np.mean(recent[:, 1]),
                        np.mean(recent[:, 2]))
                    print(message)
                    f.write(message)
                    f.flush()

                if (batch_index + 1) % 500 == 0:
                    OA_new = test_mIoU16(f,
                                         teacher_net,
                                         val_loader,
                                         epoch + 1,
                                         input_size_target,
                                         print_per_batches=10)

                    # Saving the models
                    if OA_new > OA_hist:
                        f.write('Save Model\n')
                        print('Save Model')
                        model_name = (
                            'Synthia2Cityscapes_epoch' + repr(epoch + 1) +
                            'batch' + repr(batch_index + 1) +
                            'tgt_miu_' + repr(int(OA_new * 1000)) + '.pth')
                        torch.save(teacher_net.state_dict(),
                                   os.path.join(args.snapshot_dir, model_name))
                        OA_hist = OA_new

    torch.save(teacher_net.state_dict(),
               os.path.join(args.snapshot_dir, 'Synthia_TeacherNet.pth'))
    np.savez(os.path.join(args.snapshot_dir, 'Synthia_loss.npz'),
             loss_hist=loss_hist)
コード例 #3
0
def main():
    """Create the model and start the training.

    Builds the segmentation network and two ``FCDiscriminator`` networks and
    alternates, inside every iteration, between

    * a generator step: supervised segmentation loss on a source batch plus
      adversarial losses on a target batch (discriminators frozen), and
    * a discriminator step: each discriminator is trained to output
      ``source_label`` on detached source predictions and ``target_label``
      on detached target predictions.

    Iterations up to and including ``iter_start`` are skipped. Every
    ``args.save_pred_every`` iterations the networks and all three optimizer
    states are written to ``args.snapshot_dir`` and ``show_val`` is run.

    Configuration is read from names defined outside this function:
    ``args``, ``iter_start``, ``pre_sv_dir``, ``IMG_MEAN`` and
    ``CITY_IMG_MEAN``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` (the only
            architecture this function knows how to build).
    """

    # The two numbers are consumed below as size=(input_size[1], input_size[0]),
    # i.e. the first one is the width and the second one the height.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model != 'DeepLab':
        # Previously this fell through and failed later with a NameError on
        # `model`; fail early with a message that names the actual problem.
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    model = Res_Deeplab(num_classes=args.num_classes)
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # Keys look like "Scale.layer5.conv2d_list.3.weight": drop the first
        # component, and skip layer5 entries when training with 19 classes.
        key_parts = key.split('.')
        if not args.num_classes == 19 or not key_parts[1] == 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.cuda(gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    discriminators = (model_D1, model_D2)

    for net in discriminators:
        net.train()
        net.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(synthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=max_iters,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=CITY_IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.7, 0.99))
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.7, 0.99))

    # NOTE(review): snapshots below store the discriminator weights and all
    # three optimizer states, but nothing here loads them back; resuming via
    # `iter_start` therefore restarts the discriminators and optimizers.

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    def domain_loss(d_out, domain_label):
        """BCE-with-logits of `d_out` against a constant `domain_label` map."""
        # full_like allocates the label map with d_out's shape, dtype and
        # device directly, replacing Variable(FloatTensor(...).fill_()).cuda().
        return bce_loss(d_out, torch.full_like(d_out, float(domain_label)))

    def set_discriminators_trainable(flag):
        """Set requires_grad on every parameter of both discriminators."""
        for net in discriminators:
            for param in net.parameters():
                param.requires_grad = flag

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 1
    target_label = 0
    mIoUs = []
    i_iters = []

    for i_iter in range(args.num_steps):
        # Skip iterations already covered by a previous run.
        if i_iter <= iter_start:
            continue

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            set_discriminators_trainable(False)

            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = images.cuda(gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, gpu)
            loss_seg2 = loss_calc(pred2, labels, gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.cuda(gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is pushed to make target predictions look like
            # source ones, hence source_label here.
            loss_adv_target1 = domain_loss(D_out1, source_label)
            loss_adv_target2 = domain_loss(D_out2, source_label)

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            set_discriminators_trainable(True)

            # train with source (detached so no gradient reaches the generator)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            # Mean raw D2 output on the source batch, printed for monitoring.
            weight_s = D_out2.mean().item()

            loss_D1 = domain_loss(D_out1, source_label)
            loss_D2 = domain_loss(D_out2, source_label)

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Mean raw D2 output on the target batch, printed for monitoring.
            weight_t = D_out2.mean().item()
            print(weight_s, weight_t)

            loss_D1 = domain_loss(D_out1, target_label)
            loss_D2 = domain_loss(D_out2, target_label)

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Final checkpoint is named after num_steps, then training stops.
            print('save model ...')
            final_prefix = 'GTA5_' + str(args.num_steps)
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, final_prefix + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, final_prefix + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, final_prefix + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            snap_prefix = 'GTA5_' + str(i_iter)
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snap_prefix + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, snap_prefix + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, snap_prefix + '_D2.pth'))
            torch.save(optimizer.state_dict(),
                       osp.join(args.snapshot_dir,
                                snap_prefix + '_optimizer.pth'))
            torch.save(optimizer_D1.state_dict(),
                       osp.join(args.snapshot_dir,
                                snap_prefix + '_optimizer_D1.pth'))
            torch.save(optimizer_D2.state_dict(),
                       osp.join(args.snapshot_dir,
                                snap_prefix + '_optimizer_D2.pth'))
            show_pred_sv_dir = pre_sv_dir.format(i_iter)
            mIoU = show_val(model.state_dict(), show_pred_sv_dir, gpu)
            mIoUs.append(str(round(np.nanmean(mIoU) * 100, 2)))
            i_iters.append(i_iter)
            # Re-print the whole history so the latest log lines show the trend.
            for snap_iter, miou in zip(i_iters, mIoUs):
                print('i{0}: {1}'.format(snap_iter, miou))
コード例 #4
0
def main():
    """Create the model and start the training.

    Reads all settings from the module-level ``args`` namespace and runs an
    adversarial domain-adaptation training loop:

    1. Build the segmentation network (generator) and one discriminator,
       optionally resuming both from ``args.restore_from``.
    2. Build the source loader (GTA5 or SYNTHIA, chosen from the last path
       component of ``args.data_dir``) and the Cityscapes target loader.
    3. Wrap models and optimizers with ``amp.initialize`` (opt level "O2").
    4. For every iteration: supervised segmentation loss on source images,
       update of the source feature banks, adversarial + feature-matching
       losses on target images, then the discriminator losses.
    5. Every ``args.save_pred_every`` iterations, predict the target
       validation list, compute mIoU and write snapshots; stop after
       ``args.num_steps_stop`` iterations.

    Raises:
        ValueError: if ``args.model`` or ``args.data_dir`` names a
            configuration this function does not know how to build.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")
    cudnn.benchmark = True
    cudnn.enabled = True

    # Sizes are given on the command line as "W,H".
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    Iter = 0  # first iteration to run; overwritten when resuming
    bestIoU = 0

    # Create network
    # init G
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        if args.continue_train:
            # Checkpoints whose first key starts with "module." get the
            # leading component stripped from every key.
            if list(saved_state_dict.keys())[0].split('.')[0] == 'module':
                # Iterate over a snapshot of the keys: renaming entries while
                # iterating the live key view raises RuntimeError.
                for key in list(saved_state_dict.keys()):
                    saved_state_dict['.'.join(
                        key.split('.')[1:])] = saved_state_dict.pop(key)
            model.load_state_dict(saved_state_dict)
        else:
            # Start from the model's own initial parameters and overwrite
            # them with the pretrained ones. The first key component is
            # dropped; 'layer5' weights are skipped when training with
            # 19 classes.
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                i_parts = i.split('.')
                if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)
    else:
        # Without this branch `model` would be unbound and fail later with
        # an unrelated NameError.
        raise ValueError('unsupported model: {}'.format(args.model))

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes).to(device)

    if args.continue_train:
        # Discriminator weights live next to the generator checkpoint:
        # "<stem>.<ext>" -> "<stem>_D.<ext>".
        model_weights_path = args.restore_from
        temp = model_weights_path.split('.')
        temp[-2] = temp[-2] + '_D'
        model_D_weights_path = '.'.join(temp)
        model_D.load_state_dict(torch.load(model_D_weights_path))
        # Resume iteration: take the last 9 characters of the stem, split on
        # '_' and use the second piece as the saved iteration number.
        # NOTE(review): presumably expects names like "GTA5_<iter>.pth";
        # other names (e.g. "BestGTA5.pth") will not parse - confirm.
        temp = model_weights_path.split('.')
        temp = temp[-2][-9:]
        Iter = int(temp.split('_')[1]) + 1

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # init data loader
    # The source dataset class is selected by the last component of data_dir.
    if args.data_dir.split('/')[-1] == 'gta5_deeplab':
        trainset = GTA5DataSet(args.data_dir,
                               args.data_list,
                               max_iters=args.num_steps * args.iter_size *
                               args.batch_size,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    elif args.data_dir.split('/')[-1] == 'syn_deeplab':
        trainset = synthiaDataSet(args.data_dir,
                                  args.data_list,
                                  max_iters=args.num_steps * args.iter_size *
                                  args.batch_size,
                                  crop_size=input_size,
                                  scale=args.random_scale,
                                  mirror=args.random_mirror,
                                  mean=IMG_MEAN)
    else:
        # Without this branch `trainset` would be unbound below.
        raise ValueError(
            "data_dir must end with 'gta5_deeplab' or 'syn_deeplab', "
            "got: {}".format(args.data_dir))

    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # init optimizer
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # Mixed-precision wrapping of both model/optimizer pairs.
    # NOTE(review): `amp` is presumably NVIDIA apex.amp - confirm import.
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level="O2",
                                      keep_batchnorm_fp32=True,
                                      loss_scale="dynamic")

    model_D, optimizer_D = amp.initialize(model_D,
                                          optimizer_D,
                                          opt_level="O2",
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")

    # init loss
    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    L1_loss = torch.nn.L1Loss(reduction='none')

    # Upsamplers take (H, W), hence the swapped indices of the (W, H) tuples.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    test_interp = nn.Upsample(size=(1024, 2048),
                              mode='bilinear',
                              align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # init prototype
    # Ring buffers of 2048-d source features: `num_prototype` slots per
    # BG_LABEL class and `num_ins` slots per FG_LABEL class. The *_ptr
    # arrays count how many features were written so far (slot = ptr % size).
    num_prototype = args.num_prototype
    num_ins = args.num_prototype * 10
    src_cls_features = torch.zeros([len(BG_LABEL), num_prototype, 2048],
                                   dtype=torch.float32).to(device)
    src_cls_ptr = np.zeros(len(BG_LABEL), dtype=np.uint64)
    src_ins_features = torch.zeros([len(FG_LABEL), num_ins, 2048],
                                   dtype=torch.float32).to(device)
    src_ins_ptr = np.zeros(len(FG_LABEL), dtype=np.uint64)

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    # start training
    for i_iter in range(Iter, args.num_steps):

        # Per-iteration running values used only for logging.
        loss_seg_value = 0
        loss_adv_target_value = 0
        loss_D_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Gradients are accumulated over `iter_size` sub-steps before the
        # optimizers step once.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D

            for param in model_D.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            src_feature, pred = model(images)
            pred_softmax = F.softmax(pred, dim=1)
            pred_idx = torch.argmax(pred_softmax, dim=1)

            # Resize the ground truth to the prediction resolution and keep
            # only pixels the network classified correctly; all others are
            # set to 255.
            # NOTE(review): unsqueeze(0)/squeeze(0) put the batch axis in the
            # channel position, and the later `.squeeze()` calls look like
            # they assume batch size 1 - confirm.
            right_label = F.interpolate(labels.unsqueeze(0).float(),
                                        (pred_idx.size(1), pred_idx.size(2)),
                                        mode='nearest').squeeze(0).long()
            right_label[right_label != pred_idx] = 255

            # Update the per-class source feature bank (BG_LABEL classes).
            for ii in range(len(BG_LABEL)):
                cls_idx = BG_LABEL[ii]
                mask = right_label == cls_idx
                if torch.sum(mask) == 0:
                    continue
                # presumably a mask-weighted average of the feature map;
                # helper is defined elsewhere - verify.
                feature = global_avg_pool(src_feature, mask.float())
                # Store the pooled feature only if model.layer6 assigns it to
                # this class.
                if cls_idx != torch.argmax(
                        torch.squeeze(model.layer6(
                            feature.half()).float())).item():
                    continue
                src_cls_features[ii,
                                 int(src_cls_ptr[ii] %
                                     num_prototype), :] = torch.squeeze(
                                         feature).clone().detach()
                src_cls_ptr[ii] += 1

            # Update the per-instance source feature bank (FG_LABEL classes).
            # seg_label yields one (segmask, pixelnum) pair per FG class;
            # presumably segmask holds 1-based segment ids and pixelnum their
            # sizes - verify against the helper.
            seg_ins = seg_label(right_label.squeeze())
            for ii in range(len(FG_LABEL)):
                cls_idx = FG_LABEL[ii]
                segmask, pixelnum = seg_ins[ii]
                if len(pixelnum) == 0:
                    continue
                # Visit at most the 10 entries with the largest pixelnum.
                sortmax = np.argsort(pixelnum)[::-1]
                for i in range(min(10, len(sortmax))):
                    mask = segmask == (sortmax[i] + 1)
                    feature = global_avg_pool(src_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    src_ins_features[ii, int(src_ins_ptr[ii] %
                                             num_ins), :] = torch.squeeze(
                                                 feature).clone().detach()
                    src_ins_ptr[ii] += 1

            pred = interp(pred)
            loss_seg = seg_loss(pred, labels)
            loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            trg_feature, pred_target = model(images)

            pred_target_softmax = F.softmax(pred_target, dim=1)
            pred_target_idx = torch.argmax(pred_target_softmax, dim=1)

            loss_cls = torch.zeros(1).to(device)
            loss_ins = torch.zeros(1).to(device)
            if i_iter > 0:
                # Class-level matching: per-element mean L1 distance from the
                # target class feature to its nearest stored source feature.
                for ii in range(len(BG_LABEL)):
                    cls_idx = BG_LABEL[ii]
                    # Skip classes with at most `num_prototype` writes so far.
                    if src_cls_ptr[ii] / num_prototype <= 1:
                        continue
                    mask = pred_target_idx == cls_idx
                    feature = global_avg_pool(trg_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    ext_feature = feature.squeeze().expand(num_prototype, 2048)
                    loss_cls += torch.min(
                        torch.sum(L1_loss(ext_feature,
                                          src_cls_features[ii, :, :]),
                                  dim=1) / 2048.)

                # Instance-level matching, averaged over the visited segments
                # of each FG class.
                seg_ins = seg_label(pred_target_idx.squeeze())
                for ii in range(len(FG_LABEL)):
                    cls_idx = FG_LABEL[ii]
                    if src_ins_ptr[ii] / num_ins <= 1:
                        continue
                    segmask, pixelnum = seg_ins[ii]
                    if len(pixelnum) == 0:
                        continue
                    sortmax = np.argsort(pixelnum)[::-1]
                    for i in range(min(10, len(sortmax))):
                        mask = segmask == (sortmax[i] + 1)
                        feature = global_avg_pool(trg_feature, mask.float())
                        feature = feature.squeeze().expand(num_ins, 2048)
                        loss_ins += torch.min(
                            torch.sum(L1_loss(feature,
                                              src_ins_features[ii, :, :]),
                                      dim=1) / 2048.) / min(10, len(sortmax))

            pred_target = interp_target(pred_target)

            # Generator adversarial loss: target predictions are scored
            # against the *source* label.
            D_out = model_D(F.softmax(pred_target, dim=1))
            loss_adv_target = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))

            loss = (args.lambda_adv_target * loss_adv_target +
                    args.lambda_adv_cls * loss_cls +
                    args.lambda_adv_ins * loss_ins)
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value += loss_adv_target.item() / args.iter_size

            # train D

            # bring back requires_grad

            for param in model_D.parameters():
                param.requires_grad = True

            # train with source
            # Detach so the discriminator loss does not reach the generator.
            pred = pred.detach()
            D_out = model_D(F.softmax(pred, dim=1))

            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

            # train with target
            pred_target = pred_target.detach()
            D_out = model_D(F.softmax(pred_target, dim=1))

            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(target_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        # loss_cls / loss_ins printed here are those of the last sub-step.
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv = {3:.3f} loss_D = {4:.3f} loss_cls = {5:.3f} loss_ins = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value, loss_cls.item(),
                    loss_ins.item()))

        # Early stop: save the final weights and leave the training loop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        # Periodic evaluation on the target validation list plus snapshot.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            if not os.path.exists(args.save):
                os.makedirs(args.save)
            testloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target_test,
                crop_size=(1024, 512),
                mean=IMG_MEAN,
                scale=False,
                mirror=False,
                set='val'),
                                         batch_size=1,
                                         shuffle=False,
                                         pin_memory=True)
            model.eval()
            for index, batch in enumerate(testloader):
                image, _, name = batch
                with torch.no_grad():
                    output1, output2 = model(Variable(image).to(device))
                # Upsample to 1024x2048, take the arg-max class per pixel and
                # write the label map as an image named after the input file.
                output = test_interp(output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output = Image.fromarray(output)
                name = name[0].split('/')[-1]
                output.save('%s/%s' % (args.save, name))
            mIoUs = compute_mIoU(osp.join(args.data_dir_target, 'gtFine/val'),
                                 args.save, 'dataset/cityscapes_list')
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            # Keep a separate copy of the best-scoring weights so far.
            if mIoU > bestIoU:
                bestIoU = mIoU
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5_D.pth'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
            # Back to training mode after evaluation.
            model.train()

    if args.tensorboard:
        writer.close()