コード例 #1
0
def main():
    """Run adversarial, optionally semi-supervised, segmentation training.

    All settings are read from the module-level ``args`` object. The function
    builds ``Res_Deeplab`` and ``FCDiscriminator``, wraps both in
    ``nn.DataParallel`` on CUDA, and alternates segmentation-network and
    discriminator updates for ``args.num_steps`` iterations. Gradients are
    accumulated over ``args.iter_size`` sub-batches before each optimizer
    step. State dicts of both networks are written to ``args.snapshot_dir``
    every ``args.save_pred_every`` iterations and on the last iteration.

    Raises:
        RuntimeError: if a semi-supervised loss is enabled
            (``args.lambda_semi`` or ``args.lambda_semi_adv`` > 0) while
            ``args.partial_data`` is None, because no loader for the
            remaining images exists in that configuration.
    """

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in the current model with a matching
    # size (caffe-like); all other entries keep the model's own values
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))

    model.load_state_dict(new_params)

    model.train()

    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))

    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                       scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                        batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)

        # the whole dataset is labeled, so there is no "remaining" split
        trainloader_remain = None
        trainloader_remain_iter = None
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE(review): unpickling runs arbitrary code; only point
            # args.partial_id at files you trust.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # persist the split next to the snapshots so the run can be reproduced
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, sampler=train_sampler, num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                        batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                        batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=3, pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    # align_corners is only passed on torch >= 0.4.0
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # file-name stem of this script, used as part of every snapshot name
    script_name = os.path.abspath(__file__).split('/')[-1].split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # only set when the semi-supervised branch runs in this sub-step
            pred_remain = None
            ignore_mask_remain = None

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                if trainloader_remain is None:
                    raise RuntimeError(
                        'semi-supervised losses need the remaining data split; '
                        'set args.partial_data')

                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # loader exhausted: start a new pass over the data
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda()

                pred = interp(model(images))
                pred_remain = pred.detach()

                pred_prob = F.softmax(pred, dim=1)

                D_out = interp(model_D(pred_prob))
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # nothing is ignored for the adversarial loss on these images
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape, dtype=bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv/args.iter_size

                # NOTE(review): the ``[0]`` indexing here and below assumes the
                # loss converts to a 1-element 1-D array (pre-0.4 torch) - confirm
                # against the torch version in use.
                if args.lambda_semi_adv != 0:
                    # log the unweighted value; skipped for a zero weight to
                    # avoid a 0/0 NaN in the log
                    loss_semi_adv_value += loss_semi_adv.data.cpu().numpy()[0]/args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask: per-pixel confidence is the highest
                    # class probability; pixels below the threshold get the
                    # ignore label 255.
                    # NOTE(review): the original first built this mask from
                    # ``D_out_sigmoid < args.mask_T`` and then overwrote it with
                    # the softmax-confidence mask below; only the latter ever
                    # took effect, so only it is kept.
                    confidence = pred_prob.data.cpu().numpy().max(axis=1).astype(np.float64)
                    semi_ignore_mask = (confidence < 0.999999)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum())/semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        # NOTE(review): loss_semi_adv is not back-propagated on
                        # this path (unchanged from the original) - confirm
                        # that this is intended.
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt)
                        loss_semi = loss_semi/args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy()[0]/args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # generator side: predictions are scored against the "gt" label
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss/args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0]/args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy()[0]/args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            # also show D the predictions on the remaining images, when the
            # semi-supervised branch produced them in this sub-step
            if args.D_remain and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps-1:
            print( 'save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_'+script_name+'_'+str(args.num_steps)+'.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_'+script_name+'_'+str(args.num_steps)+'_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print ('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_'+script_name+'_'+str(i_iter)+'.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_'+script_name+'_'+str(i_iter)+'_D.pth'))

    # ``start`` is not set in this function; presumably a module-level timer
    end = timeit.default_timer()
    print(end-start,'seconds')
コード例 #2
0
def _one_hot_labels(labels, num_classes):
    """Turn a batch of label maps into a one-hot CUDA ``Variable``.

    ``labels`` is indexed as a 4-D tensor (dims 0, 2 and 3 are kept, dim 1 is
    replaced by ``num_classes``) and its values are used as channel indices
    for ``scatter_``. The result is moved to CUDA device 0.
    """
    size = labels.size()
    one_hot_size = (size[0], num_classes, size[2], size[3])
    one_hot = torch.FloatTensor(torch.Size(one_hot_size)).zero_()
    one_hot = one_hot.scatter_(1, labels.long(), 1.0)
    return Variable(one_hot).cuda(0)


def main():
    """Train a discriminator to tell GTA5 label maps from Cityscapes ones.

    All settings come from the module-level ``args``. Each iteration feeds one
    batch of one-hot labels from each dataset through ``FCDiscriminator`` and
    minimises a BCE-with-logits loss (GTA5 -> 0, Cityscapes -> 1). Every
    ``args.save_pred_every`` iterations the weights are saved and the
    discriminator is evaluated on both validation loaders; training stops and
    saves once ``args.num_steps_stop - 1`` is reached.
    """

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    # init D
    # NOTE(review): the device index is hard-coded to 0 throughout this
    # function even though ``args.gpu`` exists - confirm that is intended.
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(0)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    gta_trainloader = data.DataLoader(GTA5DataSet(
        args.data_dir,
        args.data_list,
        args.num_classes,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

    gta_trainloader_iter = enumerate(gta_trainloader)

    gta_valloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                                args.valdata_list,
                                                args.num_classes,
                                                max_iters=None,
                                                crop_size=input_size,
                                                scale=args.random_scale,
                                                mirror=args.random_mirror,
                                                mean=IMG_MEAN),
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers,
                                    pin_memory=True)

    # NOTE(review): this iterator is never consumed (validation iterates the
    # loader directly). It is kept because creating it may have side effects
    # (worker start-up, RNG use) - verify before removing.
    gta_valloader_iter = enumerate(gta_valloader)

    cityscapes_targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        args.num_classes,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set='train'),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    cityscapes_targetloader_iter = enumerate(cityscapes_targetloader)

    cityscapes_valtargetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.valdata_list_target,
        args.num_classes,
        max_iters=None,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set='val'),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.num_workers,
                                                 pin_memory=True)

    # NOTE(review): never consumed either; kept for the same reason as
    # gta_valloader_iter above.
    cityscapes_valtargetloader_iter = enumerate(cityscapes_valtargetloader)

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_D_value = 0

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # source (GTA5) labels
        _, batch = next(gta_trainloader_iter)
        _, labels, _, _ = batch
        labels1 = _one_hot_labels(labels, args.num_classes)

        # target (Cityscapes) labels
        _, batch = next(cityscapes_targetloader_iter)
        _, labels, _, _ = batch
        labels2 = _one_hot_labels(labels, args.num_classes)

        # one forward pass over both domains: source samples first
        labels = torch.cat((labels1, labels2), 0)
        D_out = model_D(labels)

        # Build the per-sample targets in the same order as the concatenated
        # input. Each half takes its batch dimension from its own input, so
        # the sizes stay integers and remain correct even if the two loaders
        # ever yield batches of different size.
        target_size = D_out.data.size()
        target_labels1 = torch.FloatTensor(
            torch.Size((labels1.size(0), target_size[1], target_size[2],
                        target_size[3]))).fill_(source_label)
        target_labels2 = torch.FloatTensor(
            torch.Size((labels2.size(0), target_size[1], target_size[2],
                        target_size[3]))).fill_(target_label)
        target_labels = torch.cat((target_labels1, target_labels2), 0)
        target_labels = Variable(target_labels).cuda(0)
        loss_out = bce_loss(D_out, target_labels)

        loss = loss_out / args.iter_size
        loss.backward()

        # NOTE(review): ``[0]`` assumes the loss converts to a 1-element 1-D
        # array (pre-0.4 torch) - confirm against the torch version in use.
        loss_D_value += loss_out.data.cpu().numpy()[0] / args.iter_size

        optimizer_D.step()

        if i_iter % 100 == 0:
            print('iter = {0:8d}/{1:8d}, loss_D = {2:.3f}'.format(
                i_iter, args.num_steps, loss_D_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'Classify_' + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'Classify_' + str(i_iter) + '.pth'))

            # validation: a logit < 0 counts as "source", >= 0 as "target"
            model_D.eval()
            loss_valD_value = 0
            correct = 0
            wrong = 0
            for _, labels, _, _ in gta_valloader:
                labels = _one_hot_labels(labels, args.num_classes)
                D_out1 = model_D(labels)
                loss_out1 = bce_loss(
                    D_out1,
                    Variable(
                        torch.FloatTensor(
                            D_out1.data.size()).fill_(source_label)).cuda(0))
                loss_valD_value += loss_out1.data.cpu().numpy()[0]
                # NOTE(review): the source counts are divided by 100 while the
                # target counts below are not; presumably a deliberate
                # re-weighting of the two validation sets - confirm.
                correct = correct + (D_out1.data.cpu() < 0).sum() / 100
                wrong = wrong + (D_out1.data.cpu() >= 0).sum() / 100

            for _, labels, _, _ in cityscapes_valtargetloader:
                labels = _one_hot_labels(labels, args.num_classes)
                D_out2 = model_D(labels)
                loss_out2 = bce_loss(
                    D_out2,
                    Variable(
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(target_label)).cuda(0))
                loss_valD_value += loss_out2.data.cpu().numpy()[0]
                wrong = wrong + (D_out2.data.cpu() < 0).sum()
                correct = correct + (D_out2.data.cpu() >= 0).sum()
            accuracy = 1.0 * correct / (wrong + correct)
            print('accuracy:%f' % accuracy)

            print('iter = {0:8d}/{1:8d}, loss_valD = {2:.3f}'.format(
                i_iter, args.num_steps, loss_valD_value))

            model_D.train()
コード例 #3
0
def _label_like(d_out, label, gpu):
    """Build the BCE target for a discriminator output.

    Returns a float ``Variable`` with the same size as ``d_out``, filled with
    ``label`` and moved to CUDA device ``gpu``.
    """
    return Variable(torch.FloatTensor(d_out.data.size()).fill_(label)).cuda(gpu)


def main():
    """Create the model and start the training.

    All settings come from the module-level ``args``. A two-output
    ``Res_Deeplab`` is trained on the source loader with a segmentation loss
    and on the target loader with an adversarial loss from two
    ``FCDiscriminator`` instances, one per output. Gradients are accumulated
    over ``args.iter_size`` sub-batches per optimizer step. Model,
    discriminator and optimizer state are saved every
    ``args.save_pred_every`` iterations, followed by a ``show_val``
    evaluation whose mIoU history is printed.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``, the only
            architecture this function knows how to build.
    """

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # checkpoint keys look like Scale.layer5.conv2d_list.3.weight;
            # the first component is dropped to match this model's keys
            i_parts = i.split('.')
            # with 19 classes the checkpoint's layer5 weights are skipped
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    else:
        # fail fast instead of hitting an unbound ``model`` further down
        raise ValueError('unsupported model: {!r}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(synthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=CITY_IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.7, 0.99))
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.7, 0.99))

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 1
    target_label = 0
    mIoUs = []
    i_iters = []

    for i_iter in range(args.num_steps):
        # ``iter_start`` is not set in this function (presumably a
        # module-level resume point); iterations up to it are skipped
        if i_iter <= iter_start:
            continue

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, name = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # generator side: target predictions are scored against the
            # *source* label
            loss_adv_target1 = bce_loss(
                D_out1, _label_like(D_out1, source_label, args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2, _label_like(D_out2, source_label, args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            # mean D2 logit on the source batch; only printed below
            weight_s = float(D_out2.mean().data.cpu().numpy())

            loss_D1 = bce_loss(
                D_out1, _label_like(D_out1, source_label, args.gpu))

            loss_D2 = bce_loss(
                D_out2, _label_like(D_out2, source_label, args.gpu))

            # each D loss is halved because source and target both contribute
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # mean D2 logit on the target batch; only printed below
            weight_t = float(D_out2.mean().data.cpu().numpy())
            print(weight_s, weight_t)

            loss_D1 = bce_loss(
                D_out1, _label_like(D_out1, target_label, args.gpu))

            loss_D2 = bce_loss(
                D_out2, _label_like(D_out2, target_label, args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
            torch.save(
                optimizer.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_optimizer.pth'))
            torch.save(
                optimizer_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_optimizer_D1.pth'))
            torch.save(
                optimizer_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_optimizer_D2.pth'))

            # evaluate the snapshot and print the full mIoU history so far
            show_pred_sv_dir = pre_sv_dir.format(i_iter)
            mIoU = show_val(model.state_dict(), show_pred_sv_dir, gpu)
            mIoUs.append(str(round(np.nanmean(mIoU) * 100, 2)))
            i_iters.append(i_iter)
            for logged_iter, miou in zip(i_iters, mIoUs):
                print('i{0}: {1}'.format(logged_iter, miou))
コード例 #4
0
def main():
    """Create the model and start the training.

    Trains the segmentation network ``model`` adversarially against three
    discriminators: ``model_D_a`` on the network's first output (``pred_a``)
    and ``model_D1`` / ``model_D2`` on the softmax of its two upsampled
    prediction maps.  All settings are read from the module-level ``args``
    and from the ``CONTINUE_FLAG``, ``D1_RESTORE_FROM``, ``D2_RESTORE_FROM``,
    ``LAMBDA_ADV_TARGET_A`` and ``continue_start_iter`` globals.

    Each outer iteration accumulates gradients over ``args.iter_size``
    source/target batch pairs and then steps all four optimizers once.
    Snapshots of ``model``, ``model_D1`` and ``model_D2`` are written to
    ``args.snapshot_dir`` every ``args.save_pred_every`` iterations and when
    ``args.num_steps_stop`` is reached.

    NOTE(review): ``model_D_a`` and the optimizer states are neither saved
    nor restored here, so a resumed run restarts them from scratch.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """
    # Both sizes are two comma-separated integers.  They are consumed below
    # as size=(input_size[1], input_size[0]), i.e. the second number is the
    # output height and the first one the output width.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model != 'DeepLab':
        # Previously this fell through and crashed later with a NameError.
        raise ValueError(
            "unsupported args.model {!r}; only 'DeepLab' is implemented"
            .format(args.model))

    model = Res_Deeplab(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Remap checkpoint keys by dropping their first dotted component
    # (e.g. "Scale.layer5.conv2d_list.3.weight" -> "layer5.conv2d_list.3.weight");
    # layer5 entries are left out when training with 19 classes.
    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        key_parts = key.split('.')
        if args.num_classes != 19 or key_parts[1] != 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    # NOTE(review): the remapped ``new_params`` are never loaded because the
    # call below is disabled, so unless CONTINUE_FLAG == 1 the network starts
    # from its default initialisation.  Confirm whether that is intended
    # before re-enabling it.
    # model.load_state_dict(new_params)

    if CONTINUE_FLAG == 1:
        # Resuming: the checkpoint is loaded as-is, without key remapping.
        model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D_a = FCDiscriminator(num_classes=256)  # need to check
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    if CONTINUE_FLAG == 1:
        d1_saved_state_dict = torch.load(D1_RESTORE_FROM)
        d2_saved_state_dict = torch.load(D2_RESTORE_FROM)
        model_D1.load_state_dict(d1_saved_state_dict)
        model_D2.load_state_dict(d2_saved_state_dict)

    discriminators = (model_D_a, model_D1, model_D2)
    for net in discriminators:
        net.train()
        net.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Both datasets are asked for enough samples to cover the whole run.
    samples_needed = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        synthiaDataSet(args.data_dir, args.data_list,
                       max_iters=samples_needed,
                       crop_size=input_size,
                       scale=args.random_scale, mirror=args.random_mirror,
                       mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=samples_needed,
                          crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D_a = optim.Adam(model_D_a.parameters(),
                               lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D_a.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_loss(d_out, label):
        """Return the BCE-with-logits loss of ``d_out`` against a constant label map."""
        # full_like yields a target with d_out's shape, dtype and device.
        return bce_loss(d_out, torch.full_like(d_out, label))

    def upsample(logits, size):
        """Bilinearly resize ``logits`` to ``(size[1], size[0])``."""
        return nn.functional.interpolate(
            logits, size=(size[1], size[0]), mode='bilinear',
            align_corners=True)

    def set_discriminators_trainable(flag):
        """Switch ``requires_grad`` of every discriminator parameter."""
        for net in discriminators:
            for param in net.parameters():
                param.requires_grad = flag

    def save_snapshot(tag):
        """Write model, D1 and D2 weights as ``GTA5_<tag>[...].pth``."""
        prefix = osp.join(args.snapshot_dir, 'GTA5_' + str(tag))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D1.state_dict(), prefix + '_D1.pth')
        torch.save(model_D2.state_dict(), prefix + '_D2.pth')

    mIoUs = []
    for i_iter in range(continue_start_iter + 1, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D_a.zero_grad()
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D_a, i_iter)
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Linearly decaying weight that is only ever used as an on/off gate:
        # the D_a branch runs while it is positive.  It does not depend on the
        # sub-iteration, so it is computed once per outer iteration.
        lambda_weight = (80000 - i_iter) / 80000
        train_d_a = lambda_weight > 0

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            set_discriminators_trainable(False)

            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = images.cuda(args.gpu)

            pred_a, pred1, pred2 = model(images)
            pred1 = upsample(pred1, input_size)
            pred2 = upsample(pred2, input_size)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.cuda(args.gpu)

            if train_d_a:
                # Push the target's first output towards the source label.
                pred_target_a, _, _ = model(images)
                D_out_a = model_D_a(pred_target_a)
                loss_adv_target_a = domain_loss(D_out_a, source_label)
                loss_adv_target_a = LAMBDA_ADV_TARGET_A * loss_adv_target_a
                loss_adv_target_a = loss_adv_target_a / args.iter_size
                loss_adv_target_a.backward()

            # NOTE(review): the target batch is forwarded a second time here;
            # kept as-is because merging the two passes could change results
            # (e.g. any statistics the network updates in train mode).
            _, pred_target1, pred_target2 = model(images)
            pred_target1 = upsample(pred_target1, input_size_target)
            pred_target2 = upsample(pred_target2, input_size_target)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when D labels target output as source.
            loss_adv_target1 = domain_loss(D_out1, source_label)
            loss_adv_target2 = domain_loss(D_out2, source_label)

            loss = (args.lambda_adv_target1 * loss_adv_target1
                    + args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            set_discriminators_trainable(True)

            # train with source (predictions are detached so that only the
            # discriminators receive gradients)
            if train_d_a:
                pred_a = pred_a.detach()
                D_out_a = model_D_a(pred_a)
                loss_D_a = domain_loss(D_out_a, source_label)
                loss_D_a = loss_D_a / args.iter_size / 2
                loss_D_a.backward()

            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = domain_loss(D_out1, source_label)
            loss_D2 = domain_loss(D_out2, source_label)

            # Halved because each discriminator sees a source and a target
            # batch per sub-iteration.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            if train_d_a:
                pred_target_a = pred_target_a.detach()
                D_out_a = model_D_a(pred_target_a)
                loss_D_a = domain_loss(D_out_a, target_label)
                loss_D_a = loss_D_a / args.iter_size / 2
                loss_D_a.backward()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = domain_loss(D_out1, target_label)
            loss_D2 = domain_loss(D_out2, target_label)

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D_a.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                loss_adv_target_value1, loss_adv_target_value2,
                loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            # The final snapshot is tagged with num_steps, not the iteration.
            save_snapshot(args.num_steps)
            show_val(model.state_dict(), LAMBDA_ADV_TARGET_A, i_iter)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_snapshot(i_iter)
            mIoU = show_val(model.state_dict(), LAMBDA_ADV_TARGET_A, i_iter)
            mIoUs.append(str(round(np.nanmean(mIoU) * 100, 2)))
            # Re-print the whole history so progress is visible at a glance.
            for miou in mIoUs:
                print(miou)
コード例 #5
0
def main():
    """Create the model and start the training.

    Trains the segmentation network ``model`` adversarially against two
    discriminators (``model_D1`` / ``model_D2``), one per prediction map the
    network returns.  All settings are read from the module-level ``args``.

    Each outer iteration accumulates gradients over ``args.iter_size``
    source/target batch pairs and then steps the three optimizers once.
    Losses are printed and written to a ``SummaryWriter`` rooted at
    ``args.snapshot_dir``; ``validation`` is called at iteration 1 and every
    ``args.val_every`` iterations.  Periodic snapshots rotate through
    ``args.num_models_keep`` slots named ``model_<slot>``; the final snapshot
    is named ``model_<args.num_steps>``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` or
            ``args.training_option`` is neither 1 nor 2.
    """
    model_num = 0  # next slot of the rotating periodic snapshot

    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    random.seed(args.random_seed)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True

    # Create network
    if args.model != 'DeepLab':
        # Previously this fell through and crashed later with a NameError.
        raise ValueError(
            "unsupported args.model {!r}; only 'DeepLab' is implemented"
            .format(args.model))
    if args.training_option == 1:
        model = Res_Deeplab(num_classes=args.num_classes,
                            num_layers=args.num_layers)
    elif args.training_option == 2:
        model = Res_Deeplab2(num_classes=args.num_classes)
    else:
        raise ValueError(
            'unsupported args.training_option {!r}; expected 1 or 2'
            .format(args.training_option))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()

    # Dump both key sets so mismatches can be inspected in the log.
    for key in saved_state_dict:
        print(key)

    for key in new_params:
        print(key)

    # Copy checkpoint tensors whose key, after dropping the first
    # ``args.i_parts_index`` dotted components, exists in the model.
    for key in saved_state_dict:
        key_parts = key.split('.')
        target_key = '.'.join(key_parts[args.i_parts_index:])
        if target_key not in new_params:
            continue
        # ``== True`` is kept on purpose: the flag's type is not visible
        # here, and plain truthiness could change which keys are skipped.
        if (args.not_restore_last == True
                and key_parts[args.i_parts_index] in ('layer5', 'layer6')):
            continue
        # Reported only for keys that are actually copied.
        print("Restored...")
        new_params[target_key] = saved_state_dict[key]

    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    writer = SummaryWriter(log_dir=args.snapshot_dir)

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Both training datasets are asked for enough samples to cover the run.
    samples_needed = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(sourceDataSet(
        args.data_dir,
        args.data_list,
        max_iters=samples_needed,
        crop_size=input_size,
        random_rotate=args.augment_1,
        random_flip=args.augment_1,
        random_lighting=args.augment_1,
        mean=IMG_MEAN_SOURCE,
        ignore_label=args.ignore_label),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(isprsDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=samples_needed,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN_TARGET,
        ignore_label=args.ignore_label),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    valloader = data.DataLoader(valDataSet(args.data_dir_val,
                                           args.data_list_val,
                                           crop_size=input_size_target,
                                           mean=IMG_MEAN_TARGET,
                                           scale=False,
                                           mirror=False),
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # Upsamplers that bring the network outputs to the configured sizes.
    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[0],
                                      input_size_target[1]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_loss(d_out, label):
        """Return the BCE-with-logits loss of ``d_out`` against a constant label map."""
        # full_like yields a target with d_out's shape, dtype and device.
        return bce_loss(d_out, torch.full_like(d_out, label))

    def class_probs(logits):
        """Softmax over dim 1, the dim the implicit rule picks for 4-D input."""
        return F.softmax(logits, dim=1)

    def set_discriminators_trainable(flag):
        """Switch ``requires_grad`` of every discriminator parameter."""
        for net in (model_D1, model_D2):
            for param in net.parameters():
                param.requires_grad = flag

    def save_snapshot(tag):
        """Write model, D1 and D2 weights as ``model_<tag>[...].pth``."""
        prefix = osp.join(args.snapshot_dir, 'model_' + str(tag))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D1.state_dict(), prefix + '_D1.pth')
        torch.save(model_D2.state_dict(), prefix + '_D2.pth')

    # Which layers to freeze
    non_trainable(args.dont_train, model)

    try:
        for i_iter in range(args.num_steps):

            loss_seg_value1 = 0
            loss_adv_target_value1 = 0
            loss_D_value1 = 0

            loss_seg_value2 = 0
            loss_adv_target_value2 = 0
            loss_D_value2 = 0

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)

            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            adjust_learning_rate_D(optimizer_D1, i_iter)
            adjust_learning_rate_D(optimizer_D2, i_iter)

            for _ in range(args.iter_size):

                # train G

                # don't accumulate grads in D
                set_discriminators_trainable(False)

                # train with source; a batch that fails with one of the
                # listed errors is skipped and the next one is tried.
                while True:
                    try:
                        _, batch = next(trainloader_iter)
                        images, labels, _, train_name = batch
                        images = images.cuda(args.gpu)

                        pred1, pred2 = model(images)
                        pred1 = interp(pred1)
                        pred2 = interp(pred2)

                        loss_seg1 = loss_calc(pred1, labels, args.gpu,
                                              args.ignore_label, train_name)
                        loss_seg2 = loss_calc(pred2, labels, args.gpu,
                                              args.ignore_label, train_name)

                        loss = loss_seg2 + args.lambda_seg * loss_seg1

                        # proper normalization
                        loss = loss / args.iter_size
                        loss.backward()

                        loss_seg_value1 += loss_seg1.item() / args.iter_size
                        loss_seg_value2 += loss_seg2.item() / args.iter_size
                        break
                    except (RuntimeError, AssertionError,
                            AttributeError) as err:
                        # Still best-effort, but no longer silent.
                        print('source batch failed, trying the next one: '
                              '{!r}'.format(err))

                if args.experiment == 1:
                    # Which layers to freeze
                    non_trainable('0', model)

                # train with target
                _, batch = next(targetloader_iter)
                images, _, _ = batch
                images = images.cuda(args.gpu)

                pred_target1, pred_target2 = model(images)
                pred_target1 = interp_target(pred_target1)
                pred_target2 = interp_target(pred_target2)

                D_out1 = model_D1(class_probs(pred_target1))
                D_out2 = model_D2(class_probs(pred_target2))

                # The generator is rewarded when D labels target output as
                # source.
                loss_adv_target1 = domain_loss(D_out1, source_label)
                loss_adv_target2 = domain_loss(D_out2, source_label)

                loss = (args.lambda_adv_target1 * loss_adv_target1
                        + args.lambda_adv_target2 * loss_adv_target2)
                loss = loss / args.iter_size
                loss.backward()

                loss_adv_target_value1 += (loss_adv_target1.item()
                                           / args.iter_size)
                loss_adv_target_value2 += (loss_adv_target2.item()
                                           / args.iter_size)

                if args.experiment == 1:
                    # Which layers to freeze
                    non_trainable(args.dont_train, model)

                # train D

                # bring back requires_grad
                set_discriminators_trainable(True)

                # train with source (predictions are detached so that only
                # the discriminators receive gradients)
                pred1 = pred1.detach()
                pred2 = pred2.detach()

                D_out1 = model_D1(class_probs(pred1))
                D_out2 = model_D2(class_probs(pred2))

                loss_D1 = domain_loss(D_out1, source_label)
                loss_D2 = domain_loss(D_out2, source_label)

                # Halved because each discriminator sees a source and a
                # target batch per sub-iteration.
                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()

                D_out1 = model_D1(class_probs(pred_target1))
                D_out2 = model_D2(class_probs(pred_target2))

                loss_D1 = domain_loss(D_out1, target_label)
                loss_D2 = domain_loss(D_out2, target_label)

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

            optimizer.step()
            optimizer_D1.step()
            optimizer_D2.step()

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1,
                        loss_D_value2))

            if i_iter >= args.num_steps_stop - 1:
                # The final snapshot is tagged with num_steps, not the
                # iteration at which training stopped.
                save_snapshot(args.num_steps)
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                # Periodic snapshots reuse a fixed set of slots, overwriting
                # the oldest once num_models_keep files exist.
                if model_num != args.num_models_keep:
                    save_snapshot(model_num)
                    model_num = model_num + 1
                if model_num == args.num_models_keep:
                    model_num = 0

            # Validation
            if (i_iter % args.val_every == 0 and i_iter != 0) or i_iter == 1:
                validation(valloader, model, interp_target, writer, i_iter,
                           [37, 41, 10])

            # Save for tensorboardx
            writer.add_scalar('loss_seg_value1', loss_seg_value1, i_iter)
            writer.add_scalar('loss_seg_value2', loss_seg_value2, i_iter)
            writer.add_scalar('loss_adv_target_value1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar('loss_adv_target_value2',
                              loss_adv_target_value2, i_iter)
            writer.add_scalar('loss_D_value1', loss_D_value1, i_iter)
            writer.add_scalar('loss_D_value2', loss_D_value2, i_iter)
    finally:
        # Close the event file even when training aborts with an error.
        writer.close()
コード例 #6
0
def main():
    """Train a Res_Deeplab segmentation network together with a discriminator.

    All settings are read from the module-level ``args`` namespace. The
    function builds the segmentation network and an ``FCDiscriminator``,
    restores their weights, creates the VOC data loaders (optionally
    restricted to a subset of the data through ``args.partial_data``) and
    runs ``args.num_steps`` optimisation steps, each accumulating gradients
    over ``args.iter_size`` batches. From ``args.adv_start`` onwards a
    feature-matching loss (``fm_loss``) is back-propagated into the
    segmentation network and the discriminator is trained on predictions
    versus one-hot ground truth. Checkpoints go to ``args.snapshot_dir``.

    Besides ``args`` this relies on names defined elsewhere in the module:
    ``criterion``, ``loss_calc``, ``one_hot``, ``adjust_learning_rate``,
    ``adjust_learning_rate_D``, ``start`` and ``IMG_MEAN``.
    """

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=16,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=16,
                                         pin_memory=True)

        # Without a partial split the training loader already covers the
        # whole dataset, so it doubles as the "all data" loader used by the
        # feature-matching step (previously this name was undefined here).
        trainloader_all = trainloader
    else:
        #sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE: pickle can execute arbitrary code on load; only point
            # args.partial_id at files you trust.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the split so the same subset can be reused via partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler_all = data.sampler.SubsetRandomSampler(train_ids)
        train_gt_sampler_all = data.sampler.SubsetRandomSampler(train_ids)
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader_all = data.DataLoader(train_dataset,
                                          batch_size=args.batch_size,
                                          sampler=train_sampler_all,
                                          num_workers=16,
                                          pin_memory=True)
        trainloader_gt_all = data.DataLoader(train_gt_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_gt_sampler_all,
                                             num_workers=16,
                                             pin_memory=True)
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=16,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=16,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=16,
                                         pin_memory=True)

        # NOTE(review): this iterator and trainloader_gt_all are never
        # consumed below; kept so loader construction is unchanged.
        trainloader_remain_iter = iter(trainloader_remain)

    trainloader_all_iter = iter(trainloader_all)
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    # NOTE(review): bce_loss is not used below; the discriminator losses go
    # through the module-level `criterion` instead.
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): input_size is (h, w), so this requests size=(w, h).
    # Harmless for square crops; confirm against VOCDataSet's crop_size order.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        # Per-step running values, used only for logging.
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source

            try:
                batch = next(trainloader_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the data.
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_seg.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size

            if i_iter >= args.adv_start:

                #fm loss calc
                try:
                    batch = next(trainloader_all_iter)
                except StopIteration:
                    # Restart the "all data" iterator itself. The original
                    # rebound trainloader_iter here, so the retry below
                    # raised StopIteration again.
                    trainloader_all_iter = iter(trainloader_all)
                    batch = next(trainloader_all_iter)

                images, labels, _, _ = batch
                images = Variable(images).cuda(args.gpu)
                pred = interp(model(images))

                # model_D returns a pair; the second element feeds the
                # feature-matching loss, the first the real/fake loss below.
                _, D_out_y_pred = model_D(F.softmax(pred))

                # NOTE(review): a fresh iterator is built on every
                # sub-iteration, so only the first batch of each new pass is
                # ever used. Left as is; confirm this is intended.
                trainloader_gt_iter = iter(trainloader_gt)
                batch = next(trainloader_gt_iter)

                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

                _, D_out_y_gt = model_D(D_gt_v)

                # Mean absolute difference between the batch-averaged
                # discriminator outputs for ground truth and prediction.
                fm_loss = torch.mean(
                    torch.abs(
                        torch.mean(D_out_y_gt, 0) -
                        torch.mean(D_out_y_pred, 0)))

                # Combined value is only logged. NOTE(review): the backward
                # pass below uses fm_loss without args.lambda_fm.
                loss = loss_seg + args.lambda_fm * fm_loss

                # proper normalization
                fm_loss.backward()
                loss_fm_value += fm_loss.data.cpu().numpy()[0] / args.iter_size
                loss_value += loss.data.cpu().numpy()[0] / args.iter_size

                # train D

                # bring back requires_grad
                for param in model_D.parameters():
                    param.requires_grad = True

                # train with pred
                pred = pred.detach()

                D_out_z, _ = model_D(F.softmax(pred))
                y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1).cuda())
                loss_D_fake = criterion(D_out_z, y_fake_)

                # train with gt
                # get gt labels (same ground-truth batch as above)
                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

                D_out_z_gt, _ = model_D(D_gt_v)

                y_real_ = Variable(torch.ones(D_out_z_gt.size(0), 1).cuda())

                loss_D_real = criterion(D_out_z_gt, y_real_)
                loss_D = loss_D_fake + loss_D_real
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_D = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value, loss_D_value))
        print('fm_loss: ', loss_fm_value, ' g_loss: ', loss_value)

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    # `start` is expected to be set at module level before main() runs.
    end = timeit.default_timer()
    print(end - start, 'seconds')
コード例 #7
0
def main():
    """Create the model and start the training.

    All settings come from the module-level ``args`` namespace. A
    ``DeeplabMultiFeature`` segmentation network is trained on GTA5 (source)
    batches with a cross-entropy loss and adapted to Cityscapes (target)
    batches with one global discriminator (``model_D2``) plus one
    ``FCDiscriminatorCLS`` per class. The per-class discriminators only take
    part from iteration ``cls_begin_iter`` onwards. Models and optimizers
    are wrapped with ``amp.initialize`` and gradients are produced through
    the module-level ``amp_backward`` helper. Checkpoints are written to
    ``args.snapshot_dir``.

    Besides ``args`` this relies on names defined elsewhere in the module:
    ``amp_backward``, ``adjust_learning_rate``, ``adjust_learning_rate_D``,
    ``LAMBDA_CLS_ADV`` and ``IMG_MEAN``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    # NOTE(review): `model` is only bound when args.model == 'DeepLab'; any
    # other value fails with NameError at model.train() below.
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # Drop the leading name component; with 19 classes the layer5
            # weights are skipped and keep their fresh initialisation.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    cityset = cityscapesDataSet(args.data_dir_target,
                                args.data_list_target,
                                max_iters=args.num_steps * args.iter_size *
                                args.batch_size,
                                crop_size=input_size_target,
                                scale=False,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN,
                                set=args.set)
    targetloader = data.DataLoader(cityset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # init cls D: one discriminator (and optimizer) per class
    model_clsD = []
    optimizer_clsD = []
    for i in range(args.num_classes):
        model_temp = FCDiscriminatorCLS(
            num_classes=args.num_classes).to(device).train()
        optimizer_temp = optim.Adam(model_temp.parameters(),
                                    lr=args.learning_rate_D,
                                    betas=(0.9, 0.99))
        optimizer_temp.zero_grad()
        model_temp, optimizer_temp = amp.initialize(model_temp,
                                                    optimizer_temp,
                                                    opt_level="O1",
                                                    keep_batchnorm_fp32=None,
                                                    loss_scale="dynamic")
        model_clsD.append(model_temp)
        optimizer_clsD.append(optimizer_temp)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level="O1",
                                      keep_batchnorm_fp32=None,
                                      loss_scale="dynamic")

    model_D2, optimizer_D2 = amp.initialize(model_D2,
                                            optimizer_D2,
                                            opt_level="O1",
                                            keep_batchnorm_fp32=None,
                                            loss_scale="dynamic")

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # input sizes are stored as (w, h); Upsample expects (h, w)
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # iteration from which the per-class discriminators are trained
    cls_begin_iter = 10000

    # Most recent prediction (class index and its probability) for every
    # target image, indexed through name2idxmap.
    # NOTE(review): 2975 is hard-coded; assumes cityset.files has at least
    # that many entries and that batches hold a single image.
    num_target_imgs = 2975
    predicted_label = np.zeros(
        (num_target_imgs, 1, input_size_target[1], input_size_target[0]),
        dtype=np.uint8)
    predicted_prob = np.zeros(
        (num_target_imgs, 1, input_size_target[1], input_size_target[0]),
        dtype=np.float16)
    name2idxmap = {}
    for i in range(num_target_imgs):
        name2idxmap[cityset.files[i]['name']] = i
    # per-class probability thresholds, refreshed every 5000 iterations
    thres = []

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        # Per-step running values, used only for logging.
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_cls_adv_value = 0
        loss_cls_D = 0
        loss_cls_D_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        if i_iter >= cls_begin_iter:
            for i in range(args.num_classes):
                optimizer_clsD[i].zero_grad()
                adjust_learning_rate_D(optimizer_clsD[i], i_iter)

        for sub_i in range(args.iter_size):

            # Reset per sub-iteration: the graph behind this loss is
            # consumed by the backward pass below and must not be reused.
            loss_cls_adv = 0

            # train G

            # don't accumulate grads in D
            for param in model_D2.parameters():
                param.requires_grad = False

            if i_iter >= cls_begin_iter:
                for i in range(args.num_classes):
                    for param in model_clsD[i].parameters():
                        param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            # the model returns two outputs; only the second is used here
            _, pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, name = batch
            images = images.to(device)
            # only the first name of the batch is used to index the caches
            name = name[0]
            img_idx = name2idxmap[name]

            _, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            pred_target_score = F.softmax(pred_target2, dim=1)
            D_out2 = model_D2(pred_target_score)
            # Target predictions are labelled as source to fool D.
            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            target_pred_prob, target_pred_cls = torch.max(pred_target_score,
                                                          dim=1)
            predicted_label[img_idx,
                            ...] = target_pred_cls.cpu().data.numpy().astype(
                                np.uint8)
            predicted_prob[img_idx,
                           ...] = target_pred_prob.cpu().data.numpy().astype(
                               np.float16)

            if i_iter >= cls_begin_iter and i_iter % 5000 == 0:
                # Refresh the per-class thresholds: the median cached
                # probability of each class, capped at 0.9.
                thres = []
                for i in range(args.num_classes):
                    x = predicted_prob[predicted_label == i]
                    if len(x) == 0:
                        thres.append(0)
                        continue
                    x = np.sort(x)
                    # builtin int: the np.int alias was removed from NumPy
                    thres.append(x[int(np.round(len(x) * 0.5))])
                print(thres)
                thres = np.array(thres)
                thres[thres > 0.9] = 0.9
                np.save(osp.join(args.snapshot_dir, 'predicted_label'),
                        predicted_label)
                np.save(osp.join(args.snapshot_dir, 'predicted_prob'),
                        predicted_prob)

            if i_iter >= cls_begin_iter:
                target_pred_cls = target_pred_cls.long().detach()
                for i in range(args.num_classes):
                    # Pixels predicted as class i whose probability reaches
                    # the class threshold. The original compared the class
                    # index against the probability threshold.
                    cls_mask = (target_pred_cls == i) * (
                        target_pred_prob >= float(thres[i]))
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = target_pred_cls.clone().detach().long().to(
                        device)
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = source_label
                    cls_out = model_clsD[i](pred_target_score)
                    loss_cls_adv += seg_loss(cls_out, cls_gt)
                # float() handles both a tensor and the initial int 0 (no
                # class passed its mask), where .item() would fail.
                loss_cls_adv_value += float(loss_cls_adv) / args.iter_size

            loss = args.lambda_adv_target2 * loss_adv_target2 + LAMBDA_CLS_ADV * loss_cls_adv
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target2 = pred_target2.detach()
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            if i_iter >= cls_begin_iter:
                for i in range(args.num_classes):
                    for param in model_clsD[i].parameters():
                        param.requires_grad = True

                # class discriminators on source: pixels where prediction
                # and ground truth agree on class i
                pred_source_score = F.softmax(pred2, dim=1)
                source_pred_prob, source_pred_cls = torch.max(
                    pred_source_score, dim=1)
                source_pred_cls = source_pred_cls.long().detach()
                for i in range(args.num_classes):
                    cls_mask = (source_pred_cls == i) * (labels == i)
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = source_pred_cls.clone().detach().long().to(
                        device)
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = source_label
                    cls_out = model_clsD[i](pred_source_score)
                    loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                    amp_backward(loss_cls_D, optimizer_clsD[i])
                    loss_cls_D_value += loss_cls_D.item()

                # class discriminators on target: confident pixels only
                pred_target_score = F.softmax(pred_target2, dim=1)
                target_pred_prob, target_pred_cls = torch.max(
                    pred_target_score, dim=1)
                target_pred_cls = target_pred_cls.long().detach()
                for i in range(args.num_classes):
                    cls_mask = (target_pred_cls == i) * (
                        target_pred_prob >= float(thres[i]))
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = target_pred_cls.clone().detach().long().to(
                        device)
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = target_label
                    cls_out = model_clsD[i](pred_target_score)
                    # The discriminator loss is no longer added to
                    # loss_cls_adv here; that belongs to the G phase only.
                    loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                    amp_backward(loss_cls_D, optimizer_clsD[i])
                    loss_cls_D_value += loss_cls_D.item()

        optimizer.step()
        optimizer_D2.step()
        if i_iter >= cls_begin_iter:
            for i in range(args.num_classes):
                optimizer_clsD[i].step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg2': loss_seg_value2,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D2': loss_D_value2,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg2 = {2:.3f}, loss_adv2 = {3:.3f} loss_D2 = {4:.3f} loss_cls_adv = {5:.3f} loss_cls_D = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value2,
                    loss_adv_target_value2, loss_D_value2, loss_cls_adv_value,
                    loss_cls_D_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            for i in range(args.num_classes):
                # One file per class discriminator. The original wrote all
                # of them to the same path, keeping only the last one.
                torch.save(
                    model_clsD[i].state_dict(),
                    osp.join(
                        args.snapshot_dir, 'GTA5_' +
                        str(args.num_steps_stop) + '_clsD' + str(i) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
            # Class-discriminator snapshots carry no iteration number, so
            # each snapshot overwrites the previous one.
            for i in range(args.num_classes):
                torch.save(
                    model_clsD[i].state_dict(),
                    osp.join(args.snapshot_dir, 'GTA5_clsD' + str(i) + '.pth'))

    if args.tensorboard:
        writer.close()
コード例 #8
0
def main():
    """Create the model and start the training.

    The segmentation network is placed on GPU ``gpu_id_2``; both
    discriminators, the labels and every loss live on GPU ``gpu_id_1``.

    Each outer iteration accumulates gradients over ``args.iter_size``
    sub-iterations and then steps the segmentation optimizer and both
    discriminator optimizers once. Inside a sub-iteration the generator is
    trained first (segmentation loss on the source batch and on the target
    batch, followed by two adversarial losses against frozen
    discriminators); afterwards the discriminators are trained on detached
    predictions. The discriminators always receive the channel-wise
    concatenation of two softmax maps, one of which may come from the batch
    used in the previous sub-iteration (``batch_last`` /
    ``batch_last_target``).

    Configuration is read from the module-level ``args``; checkpoints are
    written into ``args.snapshot_dir``.
    """

    # NOTE(review): device ids are hard-coded; this needs at least 4 GPUs.
    gpu_id_2 = 3
    gpu_id_1 = 2

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            print("from restore")
            saved_state_dict = torch.load(args.restore_from)
            # NOTE(review): the checkpoint loaded from args.restore_from is
            # immediately replaced by this hard-coded snapshot, which matches
            # the hard-coded start iteration (10000) of the loop below.
            saved_state_dict = torch.load(
                'snapshots/GTA2Cityscapes_multi_54/GTA5_10000.pth')
            model.load_state_dict(saved_state_dict)

        # Strip the leading name component of every key (e.g.
        # "Scale.layer5.conv2d_list.3.weight") and, for 19 classes, skip the
        # "layer5" entries.
        # NOTE(review): new_params is never loaded into the model, so in the
        # 'http' branch no pretrained weights are applied at all.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]

    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # The first batch of each loader seeds the "previous batch" slot that is
    # paired with the current batch inside the loop. The built-in next() is
    # used because iterators have no .next() method on Python 3.
    trainloader_iter = enumerate(trainloader)
    _, batch_last = next(trainloader_iter)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = next(targetloader_iter)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Despite the name this is a mean-squared-error criterion: discriminator
    # outputs are regressed onto the scalar labels defined below.
    bce_loss = torch.nn.MSELoss()

    def upsample_(input_):
        """Bilinearly resize ``input_`` to the source crop size (h, w)."""
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        """Bilinearly resize ``input_`` to the target crop size (h, w)."""
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    def result_model(batch, interp_):
        """Run the segmentation net on ``batch`` and move results to gpu_id_1.

        Returns both upsampled prediction heads and the labels, all on
        ``gpu_id_1`` so that losses and discriminators can consume them.
        """
        images, labels, _, _ = batch
        images = Variable(images).cuda(gpu_id_2)
        labels = Variable(labels.long()).cuda(gpu_id_1)
        pred1, pred2 = model(images)
        pred1 = interp_(pred1)
        pred2 = interp_(pred2)
        return pred1.cuda(gpu_id_1), pred2.cuda(gpu_id_1), labels

    def adv_loss(d_out, label_value):
        """Criterion between ``d_out`` and a constant map of ``label_value``."""
        target = Variable(
            torch.FloatTensor(d_out.data.size()).fill_(label_value)).cuda(
                gpu_id_1)
        return bce_loss(d_out, target)

    # labels for adversarial training
    source_label = 1
    target_label = -1
    mix_label = 0

    # NOTE(review): the start iteration is hard-coded to match the snapshot
    # restored above.
    for i_iter in range(10000, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---------------- train G ----------------

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            _, batch = next(trainloader_iter)
            _, batch_target = next(targetloader_iter)

            # segmentation loss on the source batch
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()

            # segmentation loss on the target batch (the target loader yields
            # labels here as well); these are the values that get logged
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # adversarial loss: (target, previous target) pushed towards the
            # source label
            pred1_last_target, pred2_last_target, _ = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            fake1_D = torch.cat((pred1_target_D, pred1_last_target_D), dim=1)
            fake2_D = torch.cat((pred2_target_D, pred2_last_target_D), dim=1)
            D_out_fake_1 = model_D1(fake1_D)
            # Each discriminator judges its own prediction head; the second
            # one previously received fake1_D, leaving fake2_D unused.
            D_out_fake_2 = model_D2(fake2_D)

            loss_adv_target1 = adv_loss(D_out_fake_1, source_label)
            loss_adv_target2 = adv_loss(D_out_fake_2, source_label)
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()

            # adversarial loss: (target, source) pushed towards the source
            # label; fresh forward passes because the graphs above were
            # released by backward()
            pred1, pred2, labels = result_model(batch, interp)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            mix1_D = torch.cat((pred1_target_D, pred1_D), dim=1)
            mix2_D = torch.cat((pred2_target_D, pred2_D), dim=1)

            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D2(mix2_D)

            loss_adv_target1 = adv_loss(D_out_mix_1, source_label) * 2
            loss_adv_target2 = adv_loss(D_out_mix_2, source_label) * 2

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # ---------------- train D ----------------

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            pred1_last, pred2_last, _ = result_model(batch_last, interp)

            # Cut every prediction off the generator graph once; the detached
            # softmax maps are shared by both discriminator passes below.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last = pred1_last.detach().cuda(gpu_id_1)
            pred2_last = pred2_last.detach().cuda(gpu_id_1)
            pred1_last_target = pred1_last_target.detach().cuda(gpu_id_1)
            pred2_last_target = pred2_last_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_D = F.softmax((pred1_last), dim=1)
            pred2_last_D = F.softmax((pred2_last), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)

            # train with source: (source, previous source) -> source label,
            # (previous source, target) -> mix label
            real1_D = torch.cat((pred1_D, pred1_last_D), dim=1)
            real2_D = torch.cat((pred2_D, pred2_last_D), dim=1)
            mix1_D_ = torch.cat((pred1_last_D, pred1_target_D), dim=1)
            mix2_D_ = torch.cat((pred2_last_D, pred2_target_D), dim=1)

            D_out1_real = model_D1(real1_D)
            D_out2_real = model_D2(real2_D)
            D_out1_mix = model_D1(mix1_D_)
            D_out2_mix = model_D2(mix2_D_)

            loss_D1 = adv_loss(D_out1_real, source_label)
            loss_D2 = adv_loss(D_out2_real, source_label)
            loss_D3 = adv_loss(D_out1_mix, mix_label)
            loss_D4 = adv_loss(D_out2_mix, mix_label)

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target: (target, target) -> target label,
            # (source, previous target) -> mix label
            # NOTE(review): the target map is concatenated with itself here,
            # whereas the generator step pairs it with the previous target
            # batch - confirm this is intended.
            fake1_D_ = torch.cat((pred1_target_D, pred1_target_D), dim=1)
            fake2_D_ = torch.cat((pred2_target_D, pred2_target_D), dim=1)
            mix1_D__ = torch.cat((pred1_D, pred1_last_target_D), dim=1)
            mix2_D__ = torch.cat((pred2_D, pred2_last_target_D), dim=1)

            D_out1 = model_D1(fake1_D_)
            D_out2 = model_D2(fake2_D_)
            D_out3 = model_D1(mix1_D__)
            D_out4 = model_D2(mix2_D__)

            loss_D1 = adv_loss(D_out1, target_label)
            loss_D2 = adv_loss(D_out2, target_label)
            loss_D3 = adv_loss(D_out3, mix_label)
            loss_D4 = adv_loss(D_out4, mix_label)

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # the current batches become the "previous" ones next time round
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
コード例 #9
0
ファイル: train_gta2cityscapes.py プロジェクト: Turlan/LTIR
def main():
    """Create the model and start the training.

    Builds the segmentation network selected by ``args.model`` together with
    a single discriminator, restores model, discriminator and optimizer state
    through ``load_checkpoint``, and then trains from the restored iteration
    up to ``args.num_steps``.

    The module-level ``STAGE`` switches the training scheme:

    * ``STAGE == 1``: segmentation loss on source images (alternating between
      the translated and the stylized loader), an adversarial loss on target
      predictions, followed by a discriminator update.
    * ``STAGE == 2``: segmentation loss on source images plus a segmentation
      loss on target images using the labels yielded by the target loader;
      the discriminator is not trained.

    Every ``args.save_pred_every`` iterations a checkpoint is written, the
    saved model is evaluated, and the scores are appended to
    ``args.output_file``.
    """
    w, h = map(int, args.input_size_source.split(','))
    input_size_source = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Strip the leading name component of every key (e.g.
        # "Scale.layer5.conv2d_list.3.weight") and, for 19 classes, keep the
        # freshly initialised "layer5" weights instead of the saved ones.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           pretrained=True,
                           vgg16_caffe_path=args.restore_from)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    model.train()
    model.cuda(args.gpu)
    cudnn.benchmark = True

    # Discriminator setting
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(args.gpu)

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_adv_label = 0
    target_adv_label = 1

    # Dataloader
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.translated_data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size_source,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    style_trainloader = data.DataLoader(GTA5DataSet(
        args.stylized_data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_source,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)

    style_trainloader_iter = enumerate(style_trainloader)

    if STAGE == 1:
        targetloader = data.DataLoader(cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    else:
        # Dataloader for self-training
        # NOTE(review): label_folder is a placeholder string and has to be
        # replaced with the real pseudo-label directory before use.
        targetloader = data.DataLoader(cityscapesDataSetLabel(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set,
            label_folder='Path to generated pseudo labels'),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # load checkpoint; the path is joined the same way the checkpoints are
    # written below, so it also resolves when snapshot_dir has no trailing
    # separator
    model, model_D, optimizer, start_iter = load_checkpoint(
        model,
        model_D,
        optimizer,
        filename=osp.join(args.snapshot_dir,
                          'checkpoint_' + CHECKPOINT + '.pth.tar'))

    for i_iter in range(start_iter, args.num_steps):
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train segmentation network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        if STAGE == 1:
            # alternate between translated and stylized source images
            if i_iter % 2 == 0:
                _, batch = next(trainloader_iter)
            else:
                _, batch = next(style_trainloader_iter)

        else:
            _, batch = next(trainloader_iter)

        image_source, label, _, _ = batch
        image_source = Variable(image_source).cuda(args.gpu)

        pred_source = model(image_source)
        pred_source = interp(pred_source)

        loss_seg_source = loss_calc(pred_source, label, args.gpu)
        loss_seg_source_value = loss_seg_source.item()
        loss_seg_source.backward()

        if STAGE == 2:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, target_label, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # target segmentation loss
            loss_seg_target = loss_calc(pred_target,
                                        target_label,
                                        gpu=args.gpu)
            loss_seg_target.backward()

        if STAGE == 1:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # output-level adversarial training: push target predictions
            # towards the source label while the discriminator is frozen
            D_output_target = model_D(F.softmax(pred_target, dim=1))
            loss_adv = bce_loss(
                D_output_target,
                Variable(
                    torch.FloatTensor(D_output_target.data.size()).fill_(
                        source_adv_label)).cuda(args.gpu))
            loss_adv = loss_adv * args.lambda_adv
            loss_adv.backward()

        # optimize. The step has to come after the adversarial backward pass:
        # stepping before it meant the adversarial gradient was wiped by the
        # next iteration's zero_grad() without ever being applied.
        optimizer.step()

        if STAGE == 1:
            # train discriminator
            for param in model_D.parameters():
                param.requires_grad = True

            # detach so that no gradient reaches the segmentation network
            pred_source = pred_source.detach()
            pred_target = pred_target.detach()

            D_output_source = model_D(F.softmax(pred_source, dim=1))
            D_output_target = model_D(F.softmax(pred_target, dim=1))

            loss_D_source = bce_loss(
                D_output_source,
                Variable(
                    torch.FloatTensor(D_output_source.data.size()).fill_(
                        source_adv_label)).cuda(args.gpu))
            loss_D_target = bce_loss(
                D_output_target,
                Variable(
                    torch.FloatTensor(D_output_target.data.size()).fill_(
                        target_adv_label)).cuda(args.gpu))

            loss_D_source = loss_D_source / 2
            loss_D_target = loss_D_target / 2

            loss_D_source.backward()
            loss_D_target.backward()

            # optimize
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg_source = {2:.5f}'.format(
            i_iter, args.num_steps, loss_seg_source_value))

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            state = {
                'iter': i_iter,
                'model': model.state_dict(),
                'model_D': model_D.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                state,
                osp.join(args.snapshot_dir,
                         'checkpoint_' + str(i_iter) + '.pth.tar'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_D_' + str(i_iter) + '.pth'))

            cityscapes_eval_dir = osp.join(args.cityscapes_eval_dir,
                                           str(i_iter))
            if not os.path.exists(cityscapes_eval_dir):
                os.makedirs(cityscapes_eval_dir)

            # `eval` is called here with (weights path, output dir, iter), so
            # it must be a project-level function shadowing the builtin.
            eval(osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'),
                 cityscapes_eval_dir, i_iter)

            iou19, iou13, iou = compute_mIoU(cityscapes_eval_dir, i_iter)
            # context manager closes the file even if the write fails
            with open(args.output_file, 'a') as outputfile:
                outputfile.write(
                    str(i_iter) + '\t' + str(iou19) + '\t' +
                    str(iou.replace('\n', ' ')) + '\n')
コード例 #10
0
def train(log_file, arch, dataset, batch_size, iter_size, num_workers,
          partial_data, partial_data_size, partial_id, ignore_label, crop_size,
          eval_crop_size, is_training, learning_rate, learning_rate_d,
          supervised, lambda_adv_pred, lambda_semi, lambda_semi_adv, mask_t,
          semi_start, semi_start_adv, d_remain, momentum, not_restore_last,
          num_steps, power, random_mirror, random_scale, random_seed,
          restore_from, restore_from_d, eval_every, save_snapshot_every,
          snapshot_dir, weight_decay, device):
    """Train a segmentation network, optionally against a discriminator.

    The function builds the model selected by ``arch``, copies matching
    weights from ``restore_from``, and runs ``num_steps + 1`` iterations.
    Each iteration accumulates gradients over ``iter_size`` sub-batches and
    then steps an SGD optimizer (segmentation network) and an Adam optimizer
    (discriminator). Every ``eval_every`` iterations the model is scored on
    the validation split with per-class IoU, and every
    ``save_snapshot_every`` iterations weights are written to
    ``snapshot_dir`` (when it is not ``None``).

    Notes on arguments, as used by the code below:
        log_file: path handed to ``logger.LogFile``; ``'none'`` disables it.
            If the path already exists the function returns immediately.
        arch: one of ``'deeplab2'``, ``'unet_resnet50'``,
            ``'resnet101_deeplabv3'``; anything else returns early.
        dataset: ``'pascal_aug'`` or ``'pascal'``; anything else returns early.
        crop_size, eval_crop_size: strings of the form ``"h,w"``.
        partial_data, partial_data_size, partial_id, random_seed: control the
            labelled/unlabelled split. ``partial_data_size == -1`` means
            "derive the size from the ``partial_data`` fraction".
        supervised: when true, the discriminator and semi-supervised terms
            are skipped entirely.
        d_remain: when true, unlabelled predictions are appended to the
            discriminator's "fake" batch whenever such predictions exist.
        num_workers, is_training, not_restore_last: accepted and echoed in
            the settings dump but not otherwise read by this function.

    Returns:
        None.
    """
    settings = locals().copy()

    import cv2
    import torch
    import torch.nn as nn
    from torch.utils import data, model_zoo
    import numpy as np
    import pickle
    import torch.optim as optim
    import torch.nn.functional as F
    import scipy.misc
    import sys
    import os
    import os.path as osp

    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from model.discriminator import FCDiscriminator
    from utils.loss import CrossEntropy2d, BCEWithLogitsLoss2d
    from utils.evaluation import EvaluatorIoU
    from dataset.voc_dataset import VOCDataSet
    import logger

    torch_device = torch.device(device)

    import time

    # Refuse to overwrite an existing log file.
    if log_file != '' and log_file != 'none':
        if os.path.exists(log_file):
            print('Log file {} already exists; exiting...'.format(log_file))
            return

    with logger.LogFile(log_file if log_file != 'none' else None):
        if dataset == 'pascal_aug':
            ds = VOCDataSet(augmented_pascal=True)
        elif dataset == 'pascal':
            ds = VOCDataSet(augmented_pascal=False)
        else:
            print('Dataset {} not yet supported'.format(dataset))
            return

        print('Command: {}'.format(sys.argv[0]))
        print('Arguments: {}'.format(' '.join(sys.argv[1:])))
        print('Settings: {}'.format(', '.join([
            '{}={}'.format(k, settings[k])
            for k in sorted(list(settings.keys()))
        ])))

        print('Loaded data')

        def loss_calc(pred, label):
            """
            This function returns cross entropy loss for semantic segmentation
            """
            # out shape batch_size x channels x h x w -> batch_size x channels x h x w
            # label shape h x w x 1 x batch_size  -> batch_size x 1 x h x w
            label = label.long().to(torch_device)
            criterion = CrossEntropy2d()

            return criterion(pred, label)

        def lr_poly(base_lr, cur_iter, max_iter, power):
            """Polynomial learning-rate decay."""
            return base_lr * ((1 - float(cur_iter) / max_iter)**(power))

        def adjust_learning_rate(optimizer, i_iter):
            """Apply the decayed LR; a second param group gets 10x."""
            lr = lr_poly(learning_rate, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def adjust_learning_rate_D(optimizer, i_iter):
            """Same schedule as above, based on the discriminator LR."""
            lr = lr_poly(learning_rate_d, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def one_hot(label):
            """Expand an integer label batch into one channel per class."""
            label = label.numpy()
            one_hot = np.zeros((label.shape[0], ds.num_classes, label.shape[1],
                                label.shape[2]),
                               dtype=label.dtype)
            for i in range(ds.num_classes):
                one_hot[:, i, ...] = (label == i)
            #handle ignore labels
            return torch.tensor(one_hot,
                                dtype=torch.float,
                                device=torch_device)

        def make_D_label(label, ignore_mask):
            """Build a discriminator target filled with `label`, with masked
            positions set to `ignore_label`."""
            ignore_mask = np.expand_dims(ignore_mask, axis=1)
            D_label = np.ones(ignore_mask.shape) * label
            D_label[ignore_mask] = ignore_label
            D_label = torch.tensor(D_label,
                                   dtype=torch.float,
                                   device=torch_device)

            return D_label

        h, w = map(int, eval_crop_size.split(','))
        eval_crop_size = (h, w)

        h, w = map(int, crop_size.split(','))
        crop_size = (h, w)

        # create network
        if arch == 'deeplab2':
            model = Res_Deeplab(num_classes=ds.num_classes)
        elif arch == 'unet_resnet50':
            model = unet_resnet50(num_classes=ds.num_classes)
        elif arch == 'resnet101_deeplabv3':
            model = resnet101_deeplabv3(num_classes=ds.num_classes)
        else:
            print('Architecture {} not supported'.format(arch))
            return

        # load pretrained parameters
        if restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            if name in saved_state_dict and param.size(
            ) == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
        model.load_state_dict(new_params)

        model.train()
        model = model.to(torch_device)

        # init D
        model_D = FCDiscriminator(num_classes=ds.num_classes)
        if restore_from_d is not None:
            model_D.load_state_dict(torch.load(restore_from_d))
        model_D.train()
        model_D = model_D.to(torch_device)

        print('Built model')

        if snapshot_dir is not None:
            if not os.path.exists(snapshot_dir):
                os.makedirs(snapshot_dir)

        ds_train_xy = ds.train_xy(crop_size=crop_size,
                                  scale=random_scale,
                                  mirror=random_mirror,
                                  range01=model.RANGE01,
                                  mean=model.MEAN,
                                  std=model.STD)
        ds_train_y = ds.train_y(crop_size=crop_size,
                                scale=random_scale,
                                mirror=random_mirror,
                                range01=model.RANGE01,
                                mean=model.MEAN,
                                std=model.STD)
        ds_val_xy = ds.val_xy(crop_size=eval_crop_size,
                              scale=False,
                              mirror=False,
                              range01=model.RANGE01,
                              mean=model.MEAN,
                              std=model.STD)

        train_dataset_size = len(ds_train_xy)

        # The requested labelled subset cannot be larger than the training set.
        # (The previous check compared partial_data_size with itself and so
        # could never trigger.)
        if partial_data_size != -1:
            if partial_data_size > train_dataset_size:
                print('partial-data-size > |train|: exiting')
                return

        if partial_data == 1.0 and (partial_data_size == -1 or
                                    partial_data_size == train_dataset_size):
            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=5,
                                          pin_memory=True)

            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=5,
                                             pin_memory=True)

            trainloader_remain = None
            print('|train|={}'.format(train_dataset_size))
            print('|val|={}'.format(len(ds_val_xy)))
        else:
            #sample partial data
            if partial_data_size != -1:
                partial_size = partial_data_size
            else:
                partial_size = int(partial_data * train_dataset_size)

            if partial_id is not None:
                # Pickle streams are binary; text mode fails on Python 3.
                with open(partial_id, 'rb') as id_file:
                    train_ids = pickle.load(id_file)
                print('loading train ids from {}'.format(partial_id))
            else:
                rng = np.random.RandomState(random_seed)
                train_ids = list(rng.permutation(train_dataset_size))

            if snapshot_dir is not None:
                with open(osp.join(snapshot_dir, 'train_id.pkl'),
                          'wb') as id_file:
                    pickle.dump(train_ids, id_file)

            print('|train supervised|={}'.format(partial_size))
            print('|train unsupervised|={}'.format(train_dataset_size -
                                                   partial_size))
            print('|val|={}'.format(len(ds_val_xy)))

            print('supervised={}'.format(list(train_ids[:partial_size])))

            train_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])
            train_remain_sampler = data.sampler.SubsetRandomSampler(
                train_ids[partial_size:])
            train_gt_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])

            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          sampler=train_sampler,
                                          num_workers=3,
                                          pin_memory=True)
            trainloader_remain = data.DataLoader(ds_train_xy,
                                                 batch_size=batch_size,
                                                 sampler=train_remain_sampler,
                                                 num_workers=3,
                                                 pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             sampler=train_gt_sampler,
                                             num_workers=3,
                                             pin_memory=True)

            trainloader_remain_iter = enumerate(trainloader_remain)

        testloader = data.DataLoader(ds_val_xy,
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)

        print('Data loaders ready')

        trainloader_iter = enumerate(trainloader)
        trainloader_gt_iter = enumerate(trainloader_gt)

        # implement model.optim_parameters(args) to handle different models' lr setting

        # optimizer for segmentation network
        optimizer = optim.SGD(model.optim_parameters(learning_rate),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        optimizer.zero_grad()

        # optimizer for discriminator network
        optimizer_D = optim.Adam(model_D.parameters(),
                                 lr=learning_rate_d,
                                 betas=(0.9, 0.99))
        optimizer_D.zero_grad()

        # loss/ bilinear upsampling
        bce_loss = BCEWithLogitsLoss2d()

        print('Built optimizer')

        # labels for adversarial training
        pred_label = 0
        gt_label = 1

        # Running sums, averaged and reset at every evaluation point.
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_mask_accum = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        t1 = time.time()

        print('Training for {} steps...'.format(num_steps))
        for i_iter in range(num_steps + 1):

            model.train()
            model.freeze_batchnorm()

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            for sub_i in range(iter_size):

                # train G

                if not supervised:
                    # don't accumulate grads in D
                    for param in model_D.parameters():
                        param.requires_grad = False

                # do semi first
                if not supervised and (lambda_semi > 0 or lambda_semi_adv > 0) and \
                        i_iter >= semi_start_adv and \
                        trainloader_remain is not None:
                    try:
                        _, batch = next(trainloader_remain_iter)
                    except StopIteration:
                        # Loader exhausted: start another pass over it.
                        trainloader_remain_iter = enumerate(trainloader_remain)
                        _, batch = next(trainloader_remain_iter)

                    # only access to img
                    images, _, _, _ = batch
                    images = images.float().to(torch_device)

                    pred = model(images)
                    pred_remain = pred.detach()

                    D_out = model_D(F.softmax(pred, dim=1))
                    D_out_sigmoid = torch.sigmoid(
                        D_out).data.cpu().numpy().squeeze(axis=1)

                    # Nothing is ignored for unlabelled predictions.
                    ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                        bool)

                    # Keep the unweighted loss for logging so that a zero
                    # lambda_semi_adv does not cause a division by zero.
                    semi_adv_bce = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask_remain))
                    loss_semi_adv = lambda_semi_adv * semi_adv_bce / iter_size

                    #loss_semi_adv.backward()
                    loss_semi_adv_value += float(semi_adv_bce) / iter_size

                    if lambda_semi <= 0 or i_iter < semi_start:
                        loss_semi_adv.backward()
                        loss_semi_value = 0
                    else:
                        # produce ignore mask
                        semi_ignore_mask = (D_out_sigmoid < mask_t)

                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = ignore_label

                        semi_ratio = 1.0 - float(
                            semi_ignore_mask.sum()) / semi_ignore_mask.size

                        loss_semi_mask_accum += float(semi_ratio)

                        if semi_ratio == 0.0:
                            # NOTE(review): in this case neither loss_semi nor
                            # loss_semi_adv is back-propagated - confirm that
                            # dropping the adversarial term here is intended.
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)

                            loss_semi = lambda_semi * loss_calc(pred, semi_gt)
                            loss_semi = loss_semi / iter_size
                            loss_semi_value += float(loss_semi) / lambda_semi
                            loss_semi += loss_semi_adv
                            loss_semi.backward()

                else:
                    loss_semi = None
                    loss_semi_adv = None
                    # No unlabelled batch was processed, so there is nothing
                    # for the d_remain branch below to append.
                    pred_remain = None
                    ignore_mask_remain = None

                # train with source

                try:
                    _, batch = next(trainloader_iter)
                except StopIteration:
                    trainloader_iter = enumerate(trainloader)
                    _, batch = next(trainloader_iter)

                images, labels, _, _ = batch
                images = images.float().to(torch_device)
                ignore_mask = (labels.numpy() == ignore_label)
                pred = model(images)

                loss_seg = loss_calc(pred, labels)

                if supervised:
                    loss = loss_seg
                else:
                    D_out = model_D(F.softmax(pred, dim=1))

                    loss_adv_pred = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask))

                    loss = loss_seg + lambda_adv_pred * loss_adv_pred
                    loss_adv_pred_value += float(loss_adv_pred) / iter_size

                # proper normalization
                loss = loss / iter_size
                loss.backward()
                loss_seg_value += float(loss_seg) / iter_size

                if not supervised:
                    # train D

                    # bring back requires_grad
                    for param in model_D.parameters():
                        param.requires_grad = True

                    # train with pred
                    pred = pred.detach()

                    # Only extend the batch when the semi branch actually ran;
                    # otherwise pred_remain would be undefined.
                    if d_remain and pred_remain is not None:
                        pred = torch.cat((pred, pred_remain), 0)
                        ignore_mask = np.concatenate(
                            (ignore_mask, ignore_mask_remain), axis=0)

                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_D = bce_loss(D_out,
                                      make_D_label(pred_label, ignore_mask))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

                    # train with gt
                    # get gt labels
                    try:
                        _, batch = next(trainloader_gt_iter)
                    except StopIteration:
                        trainloader_gt_iter = enumerate(trainloader_gt)
                        _, batch = next(trainloader_gt_iter)

                    _, labels_gt, _, _ = batch
                    D_gt_v = one_hot(labels_gt)
                    ignore_mask_gt = (labels_gt.numpy() == ignore_label)

                    D_out = model_D(D_gt_v)
                    loss_D = bce_loss(D_out,
                                      make_D_label(gt_label, ignore_mask_gt))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

            optimizer.step()
            optimizer_D.step()

            # One dot per training iteration as a lightweight progress bar.
            sys.stdout.write('.')
            sys.stdout.flush()

            if i_iter % eval_every == 0 and i_iter != 0:
                model.eval()
                with torch.no_grad():
                    evaluator = EvaluatorIoU(ds.num_classes)
                    for index, batch in enumerate(testloader):
                        image, label, size, name = batch
                        size = size[0].numpy()
                        image = image.float().to(torch_device)
                        output = model(image)
                        output = output.cpu().data[0].numpy()

                        # Crop prediction and label to the size reported by
                        # the loader before scoring.
                        output = output[:, :size[0], :size[1]]
                        gt = np.asarray(label[0].numpy()[:size[0], :size[1]],
                                        dtype=int)

                        output = output.transpose(1, 2, 0)
                        output = np.asarray(np.argmax(output, axis=2),
                                            dtype=int)

                        evaluator.sample(gt, output, ignore_value=ignore_label)

                        sys.stdout.write('+')
                        sys.stdout.flush()

                per_class_iou = evaluator.score()
                mean_iou = per_class_iou.mean()

                loss_seg_value /= eval_every
                loss_adv_pred_value /= eval_every
                loss_D_value /= eval_every
                loss_semi_mask_accum /= eval_every
                loss_semi_value /= eval_every
                loss_semi_adv_value /= eval_every

                sys.stdout.write('\n')

                t2 = time.time()

                print(
                    'iter = {:8d}/{:8d}, took {:.3f}s, loss_seg = {:.6f}, loss_adv_p = {:.6f}, loss_D = {:.6f}, loss_semi_mask_rate = {:.3%} loss_semi = {:.6f}, loss_semi_adv = {:.3f}'
                    .format(i_iter, num_steps, t2 - t1, loss_seg_value,
                            loss_adv_pred_value, loss_D_value,
                            loss_semi_mask_accum, loss_semi_value,
                            loss_semi_adv_value))

                for i, (class_name,
                        iou) in enumerate(zip(ds.class_names, per_class_iou)):
                    print('class {:2d} {:12} IU {:.2f}'.format(
                        i, class_name, iou))

                print('meanIOU: ' + str(mean_iou) + '\n')

                loss_seg_value = 0
                loss_adv_pred_value = 0
                loss_D_value = 0
                loss_semi_value = 0
                loss_semi_mask_accum = 0
                loss_semi_adv_value = 0

                t1 = t2

            if snapshot_dir is not None and i_iter % save_snapshot_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

        if snapshot_dir is not None:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '_D.pth'))
def main():
    """Train DeeplabMulti against an FCDiscriminator on cityscapesDataSet.

    All configuration is read from the module-level ``args`` object. The
    function restores matching weights from ``args.restore_from``, builds the
    data loaders (optionally split into a labelled subset and a remainder via
    ``args.partial_data``), and runs ``args.num_steps`` iterations, each
    accumulating gradients over ``args.iter_size`` sub-batches before stepping
    the SGD (segmentation) and Adam (discriminator) optimizers. Weights are
    written to ``args.snapshot_dir`` every ``args.save_pred_every`` iterations
    and on the final iteration.

    The elapsed time printed at the end relies on a module-level ``start``
    timestamp defined outside this function.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = DeeplabMulti(num_classes=args.num_classes)
    #model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cuda:0')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)
    #summary(model,(3,7,7))

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)
    #summary(model_D, (21,321,321))
    #quit()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                      args.iter_size * args.batch_size,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                         args.iter_size * args.batch_size,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)

        # Without a labelled/unlabelled split the semi-supervised step draws
        # from the full training loader, which is what the code effectively
        # did before (through a swallowed NameError).
        trainloader_remain = trainloader
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle streams are binary; text mode fails on Python 3.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             sampler=train_remain_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         sampler=train_gt_sampler,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers,
                                         pin_memory=True)

    # Created on first use so no loader workers are started unless the
    # semi-supervised branch actually runs.
    trainloader_remain_iter = None

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    # align_corners only exists from torch 0.4.0 onwards.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0
        loss_laplacian = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                if trainloader_remain_iter is None:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # Loader exhausted: start another pass over it.
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                try:
                    pred = interp(model(images))
                except RuntimeError as exception:
                    if "out of memory" not in str(exception):
                        raise
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    # There is no valid prediction for this batch; skip the
                    # sub-iteration instead of continuing with an undefined
                    # or stale `pred`.
                    continue

                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # Nothing is ignored for unlabelled predictions.
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    bool)

                # Keep the unweighted loss for logging so that a zero
                # lambda_semi_adv does not cause a division by zero.
                semi_adv_bce = bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = (args.lambda_semi_adv * semi_adv_bce /
                                 args.iter_size)

                #loss_semi_adv.backward()
                loss_semi_adv_value += semi_adv_bce.data.cpu().numpy(
                ) / args.iter_size

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(
                        semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(
                            pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None
                # No unlabelled batch was processed, so there is nothing for
                # the D_remain branch below to append.
                pred_remain = None
                ignore_mask_remain = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            try:
                pred = interp(model(images))
            except RuntimeError as exception:
                if "out of memory" not in str(exception):
                    raise
                print("WARNING: out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
                # Same reasoning as above: nothing usable was produced.
                continue

            # Laplacian term computed on the first sample of the batch only.
            # NOTE(review): the 1280x720 shape and the 19-channel bound are
            # hard-coded, and the value is produced via NumPy/OpenCV on
            # detached data - confirm both are intended.
            imagess = torch.zeros(1280, 720).cuda(args.gpu)
            for j in range(min(19, pred.size(1))):
                imagess += pred[0, j, :, :].reshape(1280, 720)
            label = labels[0, :, :].reshape(1280, 720).cuda(args.gpu)
            imagess = torch.from_numpy(
                cv2.Laplacian(imagess.cpu().detach().numpy(),
                              -1)).cuda(args.gpu)
            labell = torch.from_numpy(
                cv2.Laplacian(label.cpu().detach().numpy(),
                              -1)).cuda(args.gpu)
            imagess = imagess.reshape(1, 1, 1280, 720)
            labell = labell.reshape(1, 1, 1280, 720)
            loss_laplacian = bce_loss(imagess, labell)

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred - loss_laplacian

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            # Only extend the batch when the semi branch actually ran;
            # otherwise pred_remain would be undefined.
            if args.D_remain and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}, loss_laplacian = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value, loss_laplacian))

        # Final iteration: write the end-of-training weights and stop.
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '_D.pth'))
            #torch.cuda.empty_cache()

    end = timeit.default_timer()
    print(end - start, 'seconds')
コード例 #12
0
def main(args):
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` segmentation network with three loss terms per
    sub-iteration: a segmentation loss on a labelled source batch, a
    segmentation loss on a labelled target batch, and an adversarial loss on
    an unlabelled target batch scored by a single ``FCDiscriminator``
    (``model_D2``). Only the second of the two predictions returned by the
    model is used. The discriminator is then updated on the detached source
    prediction and the detached unlabelled-target prediction.

    Scalars are logged to TensorBoard under ``<snapshot_dir>/tb-logs`` and
    checkpoints are written to ``args.snapshot_dir``.

    Args:
        args: Parsed command-line options. Attributes read here:
            ``snapshot_dir``, ``src_input_size``, ``tgt_input_size``, ``gpu``,
            ``model``, ``num_classes``, ``restore_from``, the ``src_*`` /
            ``tgt_*`` data locations, ``num_steps``, ``num_steps_stop``,
            ``iter_size``, ``batch_size``, ``num_workers``,
            ``learning_rate``, ``momentum``, ``weight_decay``,
            ``learning_rate_D``, ``lambda_adv_target2`` and
            ``save_pred_every``.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'`` (previously this
            surfaced later as a NameError on ``model``).
    """
    if args.model != 'DeepLab':
        raise ValueError('unsupported model: {!r}'.format(args.model))

    mkdir_check(args.snapshot_dir)
    writer = SummaryWriter(log_dir=os.path.join(args.snapshot_dir, 'tb-logs'))

    h, w = map(int, args.src_input_size.split(','))
    src_input_size = (h, w)

    h, w = map(int, args.tgt_input_size.split(','))
    tgt_input_size = (h, w)

    cudnn.enabled = True

    # Create network and initialise it from a pretrained checkpoint.
    model = DeeplabMulti(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
        # the first component is dropped, and "layer5" entries are skipped
        # when training with 19 classes.
        key_parts = key.split('.')
        if not args.num_classes == 19 or not key_parts[1] == 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    # NOTE(review): the discriminator is built for 19 input classes
    # regardless of args.num_classes - confirm this is intended.
    model_D2 = FCDiscriminator(num_classes=19)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Every dataset is given max_iters covering all sub-iterations, and the
    # loops below pull batches with __next__() without handling exhaustion.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        ListDataSet(args.src_data_dir, args.src_img_list, args.src_lbl_list,
                    max_iters=max_iters, crop_size=src_input_size,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        ListDataSet(args.tgt_data_dir, args.tgt_img_list, args.tgt_lbl_list,
                    max_iters=max_iters, crop_size=tgt_input_size,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    targetloader_nolabel = data.DataLoader(
        ListDataSet(args.tgt_data_dir, args.tgt_img_nolabel_list, None,
                    max_iters=max_iters, crop_size=tgt_input_size,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    targetloader_nolabel_iter = enumerate(targetloader_nolabel)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # NOTE(review): the parsed pair is passed to nn.Upsample in swapped order,
    # so the option strings are presumably "width,height" despite the local
    # names h, w above - confirm against the argument parser.
    interp = nn.Upsample(size=(src_input_size[1], src_input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(tgt_input_size[1], tgt_input_size[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value2 = 0
        loss_tgt_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(args, optimizer, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(args, optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            _, pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = loss_calc(pred2, labels, args.gpu)

            # proper normalization
            loss = loss_seg2 / args.iter_size
            loss.backward()
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target seg
            _, batch = targetloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            _, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            loss_tgt_seg2 = loss_calc(pred_target2, labels, args.gpu)

            # proper normalization
            loss = loss_tgt_seg2 / args.iter_size
            # This graph is never back-propagated again (pred_target2 is
            # recomputed below), so it does not need to be retained.
            loss.backward()
            loss_tgt_seg_value2 += loss_tgt_seg2.data.cpu().numpy() / args.iter_size

            # train with target_nolabel adv
            _, batch = targetloader_nolabel_iter.__next__()
            images, _, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            _, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            # Softmax over the class channel (dim=1); dim=-1 would normalise
            # along the last spatial axis instead.
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when D labels target output as source.
            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so no gradient reaches the generator)
            pred2 = pred2.detach()

            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D2 = bce_loss(
                D_out2,
                Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target (the unlabelled-target prediction from above)
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D2 = bce_loss(
                D_out2,
                Variable(torch.FloatTensor(D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D2.step()

        print(
            'iter = {:5d}/{:8d}, loss_seg2 = {:.3f} loss_tgt_seg2 = {:.3f} loss_adv2 = {:.3f} loss_D2 = {:.3f}'.format(
                i_iter, args.num_steps_stop, loss_seg_value2,
                loss_tgt_seg_value2, loss_adv_target_value2, loss_D_value2))

        writer.add_scalars('loss/seg', {
            'src2': loss_seg_value2,
            'tgt2': loss_tgt_seg_value2,
            }, i_iter)

        writer.add_scalar('loss/adv', loss_adv_target_value2, i_iter)
        writer.add_scalar('loss/d', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))

    # Close the event writer so pending scalars are written out.
    writer.close()
コード例 #13
0
def main():
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` segmentation network on a labelled source
    loader (``GTA5DataSet``) with an adversarial term on a target loader
    (``cityscapesDataSet``). A single ``FCDiscriminator`` with
    ``2 * args.num_classes`` input channels scores the pair of softmaxed
    predictions through ``CDAN(..., cdan_implement='concat')``.

    Reads the module-level ``args`` namespace and the ``RESTART``,
    ``RESTART_FROM`` and ``RESTART_ITER`` globals. When ``RESTART`` is set,
    the saved argument dictionary and both networks' weights are reloaded.
    At every snapshot the current arguments (including the current learning
    rates and iteration) are written next to the checkpoints as
    ``args_dict_<iter>.json`` so a later restart can find them.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'`` (previously this
            surfaced later as a NameError on ``model``).
    """
    import json

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # vars() returns the namespace's own dict, so writes below update args.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
    else:
        raise ValueError('unsupported model: {!r}'.format(args.model))

    model_D = FCDiscriminator(num_classes=2 * args.num_classes).to(device)

    if RESTART:
        # Restore both networks from the checkpoints of the restart iteration.
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D)
        model_D.load_state_dict(saved_state_dict)
    else:
        # The discriminator keeps its initial weights; the segmentation
        # network is initialised from the checkpoint at args.restore_from.
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # the first component is dropped, and "layer5" entries are
            # skipped when training with 19 classes.
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    # model_D was already moved to the device when it was constructed.
    model_D.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Both datasets are given max_iters covering all sub-iterations; the
    # loop below pulls batches with __next__() without handling exhaustion.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # NOTE(review): bce_loss is selected from args.gan but never used; the
    # adversarial losses below always use BCEWithLogitsLoss. Confirm whether
    # the 'LS' option was meant to take effect.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    adv_criterion = nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        adv_loss_value = 0
        d_loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        # NOTE(review): the discriminator optimizer goes through the same
        # adjust_learning_rate helper as the generator - confirm intended.
        adjust_learning_rate(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()

            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is what the deprecated implicit-dim softmax resolved to
            # for 4-D input; stating it keeps the result and drops the
            # warning.
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')
            # The generator is rewarded when D labels target output as source.
            dc_source = torch.FloatTensor(
                D_out_target.size()).fill_(source_label).to(device)

            adv_loss = adv_criterion(D_out_target, dc_source)
            adv_loss = adv_loss / args.iter_size
            adv_loss = args.lambda_adv * adv_loss
            adv_loss.backward()
            adv_loss_value += adv_loss.item()

            # ---- train D ----

            for params in model_D.parameters():
                params.requires_grad_(requires_grad=True)

            # Source predictions, detached so no gradient reaches the generator.
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out = CDAN([F.softmax(pred1, dim=1), F.softmax(pred2, dim=1)],
                         model_D,
                         cdan_implement='concat')

            dc_source = torch.FloatTensor(
                D_out.size()).fill_(source_label).to(device)
            d_loss = adv_criterion(D_out, dc_source)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()

            # Target predictions, detached likewise.
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')

            dc_target = torch.FloatTensor(
                D_out_target.size()).fill_(target_label).to(device)
            d_loss = adv_criterion(D_out_target, dc_target)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()

        optimizer.step()
        optimizer_D.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'generator_loss': adv_loss_value,
            'discriminator_loss': d_loss_value,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    adv_loss_value, d_loss_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

            save_path = args.snapshot_dir + '/eval_{}'.format(i_iter)
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            # Written inside snapshot_dir, the same location the restart
            # branch at the top of this function reads from.
            args_dict_file = osp.join(args.snapshot_dir,
                                      'args_dict_{}.json'.format(i_iter))
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

    writer.close()
コード例 #14
0
def main():
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` segmentation network on a labelled source
    loader (``GTA5DataSet``) with adversarial terms on a target loader
    (``cityscapesDataSet``). Two ``FCDiscriminator`` networks are used, one
    per prediction returned by the model (the model's third and fourth
    outputs are ignored).

    Reads the module-level ``args`` namespace. Every 1000 iterations the
    model is passed to ``evaluate_cityscapes.test_model`` with hard-coded
    dataset paths and ``compute_iou.mIoUforTest`` provides the mIoU that is
    printed (and logged when ``args.tensorboard`` is set) until the next
    evaluation.

    Raises:
        ValueError: If ``args.model`` is not ``'DeepLab'`` (previously this
            surfaced later as a NameError on ``model``).
    """
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # the first component is dropped, and "layer5" entries are
            # skipped when training with 19 classes.
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        raise ValueError('unsupported model: {!r}'.format(args.model))

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D (moved to the device once, at construction)
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D2.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Both datasets are given max_iters covering all sub-iterations; the
    # loop below pulls batches with __next__() without handling exhaustion.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(0, args.num_steps):

        loss_seg_value = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        # Never updated in this function; kept because it is still logged.
        loss_local_match_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred_s1, pred_s2, _, _ = model(images)
            pred_s1, pred_s2 = interp(pred_s1), interp(pred_s2)

            loss_seg = args.lambda_seg * seg_loss(pred_s1, labels) + seg_loss(
                pred_s2, labels)
            del labels

            # proper normalization
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_t1, pred_t2, _, _ = model(images)
            pred_t1, pred_t2 = interp_target(pred_t1), interp_target(pred_t2)
            del images

            D_out_1 = model_D1(F.softmax(pred_t1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_t2, dim=1))

            # The generator is rewarded when D labels target output as source.
            loss_adv_target1 = bce_loss(
                D_out_1,
                torch.FloatTensor(
                    D_out_1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = bce_loss(
                D_out_2,
                torch.FloatTensor(
                    D_out_2.data.size()).fill_(source_label).to(device))
            loss_adv_target_value += (
                args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target *
                loss_adv_target2).item() / args.iter_size

            # Segmentation and adversarial terms are back-propagated together.
            loss = loss_seg + args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target * loss_adv_target2

            del D_out_1, D_out_2

            loss /= args.iter_size
            loss.backward()

            # ---- train D ----

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so no gradient reaches the generator)
            # dim=1 matches the generator step above and is what the
            # deprecated implicit-dim softmax resolved to for 4-D input.
            pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
            D_out_1 = model_D1(F.softmax(pred_s1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_s2, dim=1))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

            # train with target
            pred_t1, pred_t2 = pred_t1.detach(), pred_t2.detach()
            D_out_1 = model_D1(F.softmax(pred_t1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_t2, dim=1))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        # Runs at i_iter == 0 as well, so mIoU is always defined below and
        # holds the latest evaluation result between evaluations.
        if i_iter % 1000 == 0:
            val_dir = '../dataset/Cityscapes/leftImg8bit_trainvaltest/'
            val_list = './dataset/cityscapes_list/val.txt'
            save_dir = './results/tmp'
            gt_dir = '../dataset/Cityscapes/gtFine_trainvaltest/gtFine/val'
            evaluate_cityscapes.test_model(model, device, val_dir, val_list,
                                           save_dir)
            mIoU = compute_iou.mIoUforTest(gt_dir, save_dir)
            # test_model is defined elsewhere; re-assert training mode in
            # case it switched the network to eval().
            model.train()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_local_match': loss_local_match_value,
                'loss_D': loss_D_value,
                'mIoU': mIoU
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f}, loss_local_match = {5:.3f}, mIoU = {6:.3f} '
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value,
                    loss_local_match_value, mIoU))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
コード例 #15
0
def main():
    """Train the segmentation network adversarially against a discriminator.

    Everything is configured through names defined outside this function:
    ``args`` (parsed options), the ``USED`` / ``USECALI`` switches
    (adversarial training on/off, temperature calibration on/off) and
    ``start`` (a ``timeit`` timestamp used for the final timing print).

    Steps performed:

    1. Build ``Res_Deeplab`` and ``FCDiscriminator``, restore weights and
       move both to ``args.gpu``.
    2. Seed the RNGs and create the labeled, unlabeled ("remain") and
       ground-truth data loaders.
    3. For ``args.num_steps`` iterations, accumulate gradients over
       ``args.iter_size`` sub-iterations: an optional semi-supervised pass
       on unlabeled data, a supervised (+ adversarial) pass on labeled data
       and, when ``USED`` is set, a discriminator update on predictions and
       on one-hot ground truth.
    4. Print the losses every iteration, periodically append them to a log
       file and save snapshots into ``args.snapshot_dir``.

    NOTE(review): ``trainloader_remain`` only exists when
    ``args.partial_data`` is set, and ``pred_remain`` / ``ignore_mask_remain``
    only exist once the unlabeled pass has run in the current call; the
    semi-supervised branch and ``args.D_remain`` rely on them.
    """

    def next_batch(loader, loader_iter):
        """Return ``(batch, iterator)``, restarting the iterator when exhausted."""
        try:
            _, fetched = next(loader_iter)
        except StopIteration:
            loader_iter = enumerate(loader)
            _, fetched = next(loader_iter)
        return fetched, loader_iter

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters, either from a URL or from a local file
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in the current model with the same
    # size (caffe-like); everything else keeps its initial value
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    # let cudnn auto-tune the convolution algorithms
    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if USECALI:
        model_cali = ModelWithTemperature(model, model_D)
        model_cali.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_remain = VOCDataSet(args.data_dir,
                                      args.data_list_remain,
                                      crop_size=input_size,
                                      scale=args.random_scale,
                                      mirror=args.random_mirror,
                                      mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_dataset_size_remain = len(train_dataset_remain)

    print(train_dataset_size)
    print(train_dataset_size_remain)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        # no partial split requested: plain shuffled loaders
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # The labeled and unlabeled images come from two separate lists
        # (args.data_list / args.data_list_remain), so each sampler covers
        # its whole dataset.
        if args.partial_id is not None:
            # NOTE: pickle can execute arbitrary code while loading; only
            # point args.partial_id at trusted files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
            # the id file only describes the labeled set
            train_ids_remain = list(range(train_dataset_size_remain))
        else:
            train_ids = list(range(train_dataset_size))
            train_ids_remain = list(range(train_dataset_size_remain))
            np.random.shuffle(train_ids)
            np.random.shuffle(train_ids_remain)

        # keep the (possibly shuffled) ids next to the snapshots
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids)
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids_remain)
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids)

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset_remain,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # optimizer for segmentation network; model.optim_parameters(args)
    # supplies the parameter groups (and their learning rates)
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    if USECALI:
        optimizer_cali = optim.LBFGS([model_cali.temperature],
                                     lr=0.01,
                                     max_iter=50)
        optimizer_cali.zero_grad()

        nll_criterion = BCEWithLogitsLoss().cuda()
        ece_criterion = ECELoss().cuda()

    # loss / bilinear upsampling back to the input resolution
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # running sums / averages that are written to the log file
    loss_seg_sum = 0
    loss_adv_sum = 0
    loss_vat_sum = 0
    l_seg_sum = 0
    l_vat_sum = 0
    l_adv_sum = 0

    # logits / labels collected for temperature calibration (USECALI)
    logits_list = []
    labels_list = []

    for i_iter in range(args.num_steps):

        loss_seg_value = 0  # L_seg
        loss_adv_pred_value = 0  # L_adv on labeled data
        loss_D_value = 0  # L_D
        loss_semi_value = 0  # L_semi
        loss_semi_adv_value = 0  # L_adv on unlabeled data
        loss_vat_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---------------------- train G ----------------------
            # freeze D so the generator pass does not accumulate grads in it
            for param in model_D.parameters():
                param.requires_grad = False

            # ---- unlabeled data: lambda_semi_adv * L_adv + lambda_semi * L_semi
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                batch, trainloader_remain_iter = next_batch(
                    trainloader_remain, trainloader_remain_iter)

                # only the images are used for unlabeled data
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))  # S(X)
                # detached copy, reused later as discriminator input
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))  # D(S(X))
                # confidence map with the singleton axis 1 removed
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # nothing is ignored for the unlabeled adversarial loss
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape,
                                              dtype=bool)
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # report the unweighted value
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # pseudo labels: arg-max class per pixel (not one-hot)
                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)

                    if not USECALI:
                        # pixels the discriminator is not confident about
                        # are marked 255 (ignored)
                        semi_ignore_mask = (D_out_sigmoid < args.mask_T)
                        semi_gt[semi_ignore_mask] = 255
                        semi_ratio = 1.0 - float(
                            semi_ignore_mask.sum()) / semi_ignore_mask.size
                        print('semi ratio: {:.4f}'.format(semi_ratio))

                        if semi_ratio != 0.0:
                            semi_gt = torch.FloatTensor(semi_gt)
                            confidence = torch.FloatTensor(D_out_sigmoid)
                            loss_semi = args.lambda_semi * weighted_loss_calc(
                                pred, semi_gt, args.gpu, confidence)
                    else:
                        semi_ratio = 1
                        semi_gt = torch.FloatTensor(semi_gt)
                        confidence = torch.FloatTensor(
                            F.sigmoid(
                                model_cali.temperature_scale(D_out.view(
                                    -1))).data.cpu().numpy())
                        # NOTE(review): accuracies / n_bin are only assigned
                        # by the calibration step further down; this raises
                        # NameError if reached before the first calibration.
                        loss_semi = args.lambda_semi * calibrated_loss_calc(
                            pred, semi_gt, args.gpu, confidence, accuracies,
                            n_bin)

                    if semi_ratio != 0:
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi

                        if args.method == 'vatent' or args.method == 'vat':
                            weighted_v_loss = weighted_vat_loss(
                                model,
                                images,
                                pred,
                                confidence,
                                eps=args.epsilon)

                            if args.method == 'vatent':
                                # add the conditional entropy term
                                weighted_v_loss += weighted_entropy_loss(
                                    pred, confidence)

                            v_loss = weighted_v_loss / args.iter_size
                            loss_vat_value += v_loss.data.cpu().numpy()
                            loss_semi_adv += args.alpha * v_loss

                            loss_vat_sum += loss_vat_value
                            # NOTE(review): sub_i == 4 is hard-coded (see the
                            # logging block at the end of the iteration)
                            if i_iter % 100 == 0 and sub_i == 4:
                                l_vat_sum = loss_vat_sum / 100
                                if i_iter == 0:
                                    l_vat_sum = l_vat_sum * 100
                                loss_vat_sum = 0

                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # ---- labeled data: L_ce + lambda_adv_pred * L_adv
            batch, trainloader_iter = next_batch(trainloader,
                                                 trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            # pixels labeled 255 are ignored
            ignore_mask = (labels.numpy() == 255)

            pred = interp(model(images))  # S(X)
            loss_seg = loss_calc(pred, labels, args.gpu)

            if USED:
                softsx = F.softmax(pred, dim=1)
                D_out = interp(model_D(softsx))  # D(S(X))

                loss_adv_pred = bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask))

                loss = loss_seg + args.lambda_adv_pred * loss_adv_pred
                if USECALI:
                    if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                        ) and i_iter >= args.semi_start_adv:
                        # collect discriminator logits together with a
                        # "prediction was correct" label for calibration
                        with torch.no_grad():
                            _, prediction = torch.max(softsx, 1)
                            labels_mask = (
                                (labels > 0) *
                                (labels != 255)) | (prediction.data.cpu() > 0)
                            labels = labels[labels_mask]
                            prediction = prediction[labels_mask]
                            fake_mask = (labels.data.cpu().numpy() !=
                                         prediction.data.cpu().numpy())
                            real_label = make_conf_label(1, fake_mask)

                            logits = D_out.squeeze(dim=1)
                            logits = logits[labels_mask]
                            logits_list.append(logits)
                            labels_list.append(real_label)

                        if (i_iter * args.iter_size * args.batch_size + sub_i +
                                1) % train_dataset_size == 0:
                            logits = torch.cat(logits_list).cuda()
                            labels = torch.cat(labels_list).cuda()
                            before_temperature_nll = nll_criterion(
                                logits, labels).item()
                            before_temperature_ece, _, _ = ece_criterion(
                                logits, labels)
                            before_temperature_ece = before_temperature_ece.item(
                            )
                            print('Before temperature - NLL: %.3f, ECE: %.3f' %
                                  (before_temperature_nll,
                                   before_temperature_ece))

                            def closure():
                                # LBFGS re-evaluates the closure several
                                # times per step, so clear stale gradients
                                optimizer_cali.zero_grad()
                                loss_cali = nll_criterion(
                                    model_cali.temperature_scale(logits),
                                    labels)
                                loss_cali.backward()
                                return loss_cali

                            optimizer_cali.step(closure)
                            after_temperature_nll = nll_criterion(
                                model_cali.temperature_scale(logits),
                                labels).item()
                            after_temperature_ece, accuracies, n_bin = ece_criterion(
                                model_cali.temperature_scale(logits), labels)
                            after_temperature_ece = after_temperature_ece.item(
                            )
                            print('Optimal temperature: %.3f' %
                                  model_cali.temperature.item())
                            print(
                                'After temperature - NLL: %.3f, ECE: %.3f' %
                                (after_temperature_nll, after_temperature_ece))

                            logits_list = []
                            labels_list = []

            else:
                loss = loss_seg

            # proper normalization over the accumulated sub-iterations
            loss = loss / args.iter_size
            loss.backward()

            # accumulate plain floats so no autograd history is kept alive
            loss_seg_sum += loss_seg.item() / args.iter_size
            if USED:
                loss_adv_sum += loss_adv_pred.item()
            # NOTE(review): sub_i == 4 is hard-coded (see the logging block
            # at the end of the iteration)
            if i_iter % 100 == 0 and sub_i == 4:
                l_seg_sum = loss_seg_sum / 100
                if USED:
                    l_adv_sum = loss_adv_sum / 100
                if i_iter == 0:
                    l_seg_sum = l_seg_sum * 100
                    l_adv_sum = l_adv_sum * 100
                loss_seg_sum = 0
                loss_adv_sum = 0

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            if USED:
                loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
                ) / args.iter_size

            # ---------------------- train D ----------------------
            if USED:
                # bring back requires_grad
                for param in model_D.parameters():
                    param.requires_grad = True

                # ---- train with predictions S(X); S is frozen via detach
                pred = pred.detach()

                if args.D_remain:
                    # also feed the (detached) unlabeled predictions
                    pred = torch.cat((pred, pred_remain), 0)
                    ignore_mask = np.concatenate(
                        (ignore_mask, ignore_mask_remain), axis=0)

                D_out = interp(model_D(F.softmax(pred, dim=1)))

                loss_D = bce_loss(D_out,
                                  make_D_label(pred_label, ignore_mask))
                # halved because D is trained on predictions and on gt
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

                # ---- train with ground truth (labeled data only)
                batch, trainloader_gt_iter = next_batch(
                    trainloader_gt, trainloader_gt_iter)

                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
                ignore_mask_gt = (labels_gt.numpy() == 255)

                D_out = interp(model_D(D_gt_v))  # D(Y)
                loss_D = bce_loss(D_out,
                                  make_D_label(gt_label, ignore_mask_gt))
                loss_D = loss_D / args.iter_size / 2

                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()

        if USED:
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.6f}, loss_semi_adv = {6:.3f}, loss_vat = {7: .5f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value, loss_vat_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        # NOTE(review): sub_i is the last sub-iteration index here, so this
        # only fires when args.iter_size == 5, and the log path is
        # hard-coded -- confirm both are intended.
        if i_iter % 100 == 0 and sub_i == 4:
            wdata = "iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.8f}, loss_semi_adv = {6:.3f}, l_vat_sum = {7: .5f}, loss_label = {8: .4}\n".format(
                i_iter, args.num_steps, l_seg_sum, l_adv_sum, loss_D_value,
                loss_semi_value, loss_semi_adv_value, l_vat_sum,
                l_seg_sum + 0.01 * l_adv_sum)
            # start a fresh log on the first iteration, append afterwards
            log_mode = 'w' if i_iter == 0 else 'a'
            with open("/home/eungyo/AdvSemiSeg/snapshots/log.txt",
                      log_mode) as log_file:
                log_file.write(wdata)

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
コード例 #16
0
def main():
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` segmentation network on the GTA5 source set
    while two ``FCDiscriminator`` networks (one per model output) are trained
    adversarially against its predictions on the Cityscapes target set.

    All configuration is read from the module-level ``args`` namespace.  The
    module-level ``RESTART``, ``RESTART_FROM`` and ``RESTART_ITER`` values
    control resuming: on restart the arguments saved next to the snapshot are
    loaded back into ``args`` and the three networks are restored from the
    ``GTA5_<RESTART_ITER>*.pth`` checkpoints under ``args.restart_from``.

    Side effects: creates ``args.snapshot_dir`` and ``args.log_dir`` when
    missing, writes TensorBoard scalars, prints per-iteration losses, and
    saves checkpoints plus an ``args_dict_<iter>.json`` file on every
    ``args.save_pred_every`` iterations.

    Raises:
        ValueError: if ``args.model`` or ``args.gan`` is not a supported value.
    """
    import json

    def set_requires_grad(module, flag):
        """Switch gradient tracking on or off for every parameter of *module*."""
        for param in module.parameters():
            param.requires_grad = flag

    def save_checkpoints(tag):
        """Save the segmentation net and both discriminators as GTA5_<tag>*.pth."""
        prefix = 'GTA5_' + str(tag)
        torch.save(model.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '.pth'))
        torch.save(model_D1.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '_D1.pth'))
        torch.save(model_D2.state_dict(),
                   osp.join(args.snapshot_dir, prefix + '_D2.pth'))

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # vars() exposes the namespace's own attribute dict, so assignments to
    # args_dict below are visible through args as well.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        # Joined with osp.join so the path matches the one used when the file
        # is written, whether or not snapshot_dir ends with a separator.
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
    else:
        # Fail here instead of with a NameError at the first use of `model`.
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    if RESTART:
        # Restore the segmentation network and both discriminators.
        restart_from_model = osp.join(args.restart_from,
                                      'GTA5_{}.pth'.format(RESTART_ITER))
        model.load_state_dict(torch.load(restart_from_model))

        restart_from_D1 = osp.join(args.restart_from,
                                   'GTA5_{}_D1.pth'.format(RESTART_ITER))
        model_D1.load_state_dict(torch.load(restart_from_D1))

        restart_from_D2 = osp.join(args.restart_from,
                                   'GTA5_{}_D2.pth'.format(RESTART_ITER))
        model_D2.load_state_dict(torch.load(restart_from_D2))
    else:
        # Discriminators keep their random initialisation; only the
        # segmentation network is initialised from pretrained weights.
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the leading
            # component is dropped to match this model's parameter names.
            key_parts = key.split('.')
            # With 19 classes the pretrained 'layer5' weights are skipped so
            # that layer keeps the model's own initialisation.
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # Fail here instead of with a NameError inside the training loop.
        raise ValueError('Unsupported gan type: {!r}'.format(args.gan))
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    def domain_loss(d_out, domain_label):
        """Return the GAN loss of *d_out* against a constant domain label map."""
        # full_like builds the target directly with d_out's shape, dtype and
        # device, avoiding a host-side tensor and a copy on every step.
        return bce_loss(d_out, torch.full_like(d_out, float(domain_label)))

    # input sizes are stored as (w, h); Upsample expects (h, w).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.start_steps, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-batches before the
        # optimizers step once.
        for _ in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            set_requires_grad(model_D1, False)
            set_requires_grad(model_D2, False)

            # train with source
            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is the dimension the deprecated implicit-dim softmax
            # selected for these 4-D tensors; passing it explicitly keeps the
            # result and removes the deprecation warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when target predictions are
            # classified as coming from the source domain.
            loss_adv_target1 = domain_loss(D_out1, source_label)
            loss_adv_target2 = domain_loss(D_out2, source_label)

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            set_requires_grad(model_D1, True)
            set_requires_grad(model_D2, True)

            # train with source; detach so no gradient reaches the generator
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            # Each discriminator loss is halved because it is the sum of a
            # source term and a target term.
            loss_D1 = domain_loss(D_out1, source_label) / args.iter_size / 2
            loss_D2 = domain_loss(D_out2, source_label) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = domain_loss(D_out1, target_label) / args.iter_size / 2
            loss_D2 = domain_loss(D_out2, target_label) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'loss_adv_target1': loss_adv_target_value1,
            'loss_adv_target2': loss_adv_target_value2,
            'loss_D1': loss_D_value1,
            'loss_D2': loss_D_value2,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Early stop: save the final weights and leave the loop.
            print('save model ...')
            save_checkpoints(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_checkpoints(i_iter)

            # Record the current learning rates and iteration so a restart
            # can resume from this snapshot.
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D1.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = osp.join(args.snapshot_dir,
                                      'args_dict_{}.json'.format(i_iter))
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

    writer.close()
コード例 #17
0
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    tau = torch.ones(1) * args.tau
    tau = tau.cuda(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params, False)
    elif args.model == 'DeepLabVGG':
        model = DeeplabVGG(pretrained=True, num_classes=args.num_classes)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    weak_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Resize(1024),
        transforms.ToTensor(),
        #         transforms.Normalize(mean, std),
        #         RandomCrop(768)
    ])

    target_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Normalize(mean, std)
        #         transforms.Resize(1024),
        #         transforms.ToTensor(),
        #         RandomCrop(768)
    ])

    label_set = GTA5(
        root=args.data_dir,
        num_cls=19,
        split='all',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size,
        #              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )
    unlabel_set = Cityscapes(
        root=args.data_dir_target,
        split=args.set,
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )

    test_set = Cityscapes(
        root=args.data_dir_target,
        split='val',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                       crop_transform=RandomCrop(768)
    )

    label_loader = data.DataLoader(label_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=False)

    unlabel_loader = data.DataLoader(unlabel_set,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=False)

    test_loader = data.DataLoader(test_set,
                                  batch_size=2,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=False)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    [model, model_D2,
     model_D2], [optimizer, optimizer_D1, optimizer_D2
                 ] = amp.initialize([model, model_D2, model_D2],
                                    [optimizer, optimizer_D1, optimizer_D2],
                                    opt_level="O1",
                                    num_losses=7)

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = Interpolate(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = Interpolate(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_test = Interpolate(size=(input_size_target[1],
                                    input_size_target[0]),
                              mode='bilinear',
                              align_corners=True)
    #     interp_test = Interpolate(size=(1024, 2048), mode='bilinear', align_corners=True)

    normalize_transform = transforms.Compose([
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    # labels for adversarial training
    source_label = 0
    target_label = 1

    max_mIoU = 0

    total_loss_seg_value1 = []
    total_loss_adv_target_value1 = []
    total_loss_D_value1 = []
    total_loss_con_value1 = []

    total_loss_seg_value2 = []
    total_loss_adv_target_value2 = []
    total_loss_D_value2 = []
    total_loss_con_value2 = []

    hist = np.zeros((num_cls, num_cls))

    #     for i_iter in range(args.num_steps):
    for i_iter, (batch, batch_un) in enumerate(
            zip(roundrobin_infinite(label_loader),
                roundrobin_infinite(unlabel_loader))):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_con_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_con_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D1.parameters():
            param.requires_grad = False

        for param in model_D2.parameters():
            param.requires_grad = False

        # train with source

        images, labels = batch
        images_orig = images
        images = transform_batch(images, normalize_transform)
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model(images)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization
        loss = loss / args.iter_size

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

#         loss.backward()
        loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
        loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

        # train with target

        images_tar, labels_tar = batch_un
        images_tar_orig = images_tar
        images_tar = transform_batch(images_tar, normalize_transform)
        images_tar = Variable(images_tar).cuda(args.gpu)

        pred_target1, pred_target2 = model(images_tar)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_adv_target1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_adv_target2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()
        loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
        ) / args.iter_size
        loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
        ) / args.iter_size

        # train with consistency loss
        # unsupervise phase
        policies = RandAugment().get_batch_policy(args.batch_size)
        rand_p1 = np.random.random(size=args.batch_size)
        rand_p2 = np.random.random(size=args.batch_size)
        random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2])

        images_aug = aug_batch_tensor(images_tar_orig, policies, rand_p1,
                                      rand_p2, random_dir)

        images_aug_orig = images_aug
        images_aug = transform_batch(images_aug, normalize_transform)
        images_aug = Variable(images_aug).cuda(args.gpu)

        pred_target_aug1, pred_target_aug2 = model(images_aug)
        pred_target_aug1 = interp_target(pred_target_aug1)
        pred_target_aug2 = interp_target(pred_target_aug2)

        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        max_pred1, psuedo_label1 = torch.max(F.softmax(pred_target1, dim=1), 1)
        max_pred2, psuedo_label2 = torch.max(F.softmax(pred_target2, dim=1), 1)

        psuedo_label1 = psuedo_label1.cpu().numpy().astype(np.float32)
        psuedo_label1_thre = psuedo_label1.copy()
        psuedo_label1_thre[(max_pred1 < tau).cpu().numpy().astype(
            np.bool)] = 255  # threshold to don't care
        psuedo_label1_thre = aug_batch_numpy(psuedo_label1_thre, policies,
                                             rand_p1, rand_p2, random_dir)
        psuedo_label2 = psuedo_label2.cpu().numpy().astype(np.float32)
        psuedo_label2_thre = psuedo_label2.copy()
        psuedo_label2_thre[(max_pred2 < tau).cpu().numpy().astype(
            np.bool)] = 255  # threshold to don't care
        psuedo_label2_thre = aug_batch_numpy(psuedo_label2_thre, policies,
                                             rand_p1, rand_p2, random_dir)

        psuedo_label1_thre = Variable(psuedo_label1_thre).cuda(args.gpu)
        psuedo_label2_thre = Variable(psuedo_label2_thre).cuda(args.gpu)

        if (psuedo_label1_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con1 = loss_calc(pred_target_aug1, psuedo_label1_thre,
                                  args.gpu)
            loss_con_value1 += loss_con1.data.cpu().numpy() / args.iter_size
        else:
            loss_con1 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        if (psuedo_label2_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con2 = loss_calc(pred_target_aug2, psuedo_label2_thre,
                                  args.gpu)
            loss_con_value2 += loss_con2.data.cpu().numpy() / args.iter_size
        else:
            loss_con2 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2
        # proper normalization
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()

# train D

# bring back requires_grad
        for param in model_D1.parameters():
            param.requires_grad = True

        for param in model_D2.parameters():
            param.requires_grad = True

        # train with source
        pred1 = pred1.detach()
        pred2 = pred2.detach()

        D_out1 = model_D1(F.softmax(pred1, dim=1))
        D_out2 = model_D2(F.softmax(pred2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=3) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=4) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        # train with target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=5) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=6) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_con_value1,
                    loss_con_value2))

        total_loss_seg_value1.append(loss_seg_value1)
        total_loss_adv_target_value1.append(loss_adv_target_value1)
        total_loss_D_value1.append(loss_D_value1)
        total_loss_con_value1.append(loss_con_value1)

        total_loss_seg_value2.append(loss_seg_value2)
        total_loss_adv_target_value2.append(loss_adv_target_value2)
        total_loss_D_value2.append(loss_D_value2)
        total_loss_con_value2.append(loss_con_value2)

        hist += fast_hist(
            labels.cpu().numpy().flatten().astype(int),
            torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int),
            num_cls)

        if i_iter % 10 == 0:
            print('({}/{})'.format(i_iter + 1, int(args.num_steps)))
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(np.mean(iu)))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_train_acc = acc_overall
            avg_train_loss_seg1 = np.mean(total_loss_seg_value1)
            avg_train_loss_adv1 = np.mean(total_loss_adv_target_value1)
            avg_train_loss_dis1 = np.mean(total_loss_D_value1)
            avg_train_loss_con1 = np.mean(total_loss_con_value1)
            avg_train_loss_seg2 = np.mean(total_loss_seg_value2)
            avg_train_loss_adv2 = np.mean(total_loss_adv_target_value2)
            avg_train_loss_dis2 = np.mean(total_loss_D_value2)
            avg_train_loss_con2 = np.mean(total_loss_con_value2)

            print('avg_train_acc      :', avg_train_acc)
            print('avg_train_loss_seg1 :', avg_train_loss_seg1)
            print('avg_train_loss_adv1 :', avg_train_loss_adv1)
            print('avg_train_loss_dis1 :', avg_train_loss_dis1)
            print('avg_train_loss_con1 :', avg_train_loss_con1)
            print('avg_train_loss_seg2 :', avg_train_loss_seg2)
            print('avg_train_loss_adv2 :', avg_train_loss_adv2)
            print('avg_train_loss_dis2 :', avg_train_loss_dis2)
            print('avg_train_loss_con2 :', avg_train_loss_con2)

            writer['train'].add_scalar('log/mIoU', mIoU, i_iter)
            writer['train'].add_scalar('log/acc', avg_train_acc, i_iter)
            writer['train'].add_scalar('log1/loss_seg', avg_train_loss_seg1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_adv', avg_train_loss_adv1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_dis', avg_train_loss_dis1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_con', avg_train_loss_con1,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_seg', avg_train_loss_seg2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_adv', avg_train_loss_adv2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_dis', avg_train_loss_dis2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_con', avg_train_loss_con2,
                                       i_iter)

            hist = np.zeros((num_cls, num_cls))
            total_loss_seg_value1 = []
            total_loss_adv_target_value1 = []
            total_loss_D_value1 = []
            total_loss_con_value1 = []
            total_loss_seg_value2 = []
            total_loss_adv_target_value2 = []
            total_loss_D_value2 = []
            total_loss_con_value2 = []

            fig = plt.figure(figsize=(15, 15))

            labels = labels[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(331)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(337)
            images = images_orig[0].cpu().numpy().transpose((1, 2, 0))
            #             images += IMG_MEAN
            ax.imshow(images)
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(334)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            labels_tar = labels_tar[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(332)
            ax.imshow(print_palette(Image.fromarray(labels_tar).convert('L')))
            ax.axis("off")
            ax.set_title('tar_labels')

            ax = fig.add_subplot(338)
            ax.imshow(images_tar_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('tar_datas')

            _, pred_target2 = torch.max(pred_target2, dim=1)
            pred_target2 = pred_target2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(335)
            ax.imshow(print_palette(
                Image.fromarray(pred_target2).convert('L')))
            ax.axis("off")
            ax.set_title('tar_predicts')

            print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0],
                  'random_dir', random_dir[0])

            psuedo_label2_thre = psuedo_label2_thre[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(333)
            ax.imshow(
                print_palette(
                    Image.fromarray(psuedo_label2_thre).convert('L')))
            ax.axis("off")
            ax.set_title('psuedo_labels')

            ax = fig.add_subplot(339)
            ax.imshow(images_aug_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('aug_datas')

            _, pred_target_aug2 = torch.max(pred_target_aug2, dim=1)
            pred_target_aug2 = pred_target_aug2[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(336)
            ax.imshow(
                print_palette(Image.fromarray(pred_target_aug2).convert('L')))
            ax.axis("off")
            ax.set_title('aug_predicts')

            #             plt.show()
            writer['train'].add_figure('image/',
                                       fig,
                                       global_step=i_iter,
                                       close=True)

        if i_iter % 500 == 0:
            loss1 = []
            loss2 = []
            for test_i, batch in enumerate(test_loader):

                images, labels = batch
                images_orig = images
                images = transform_batch(images, normalize_transform)
                images = Variable(images).cuda(args.gpu)

                pred1, pred2 = model(images)
                pred1 = interp_test(pred1)
                pred1 = pred1.detach()
                pred2 = interp_test(pred2)
                pred2 = pred2.detach()

                loss_seg1 = loss_calc(pred1, labels, args.gpu)
                loss_seg2 = loss_calc(pred2, labels, args.gpu)
                loss1.append(loss_seg1.item())
                loss2.append(loss_seg2.item())

                hist += fast_hist(
                    labels.cpu().numpy().flatten().astype(int),
                    torch.argmax(pred2,
                                 dim=1).cpu().numpy().flatten().astype(int),
                    num_cls)

            print('test')
            fig = plt.figure(figsize=(15, 15))
            labels = labels[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(311)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(313)
            ax.imshow(images_orig[-1].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(312)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            #             plt.show()

            writer['test'].add_figure('test_image/',
                                      fig,
                                      global_step=i_iter,
                                      close=True)

            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(mIoU))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_test_loss1 = np.mean(loss1)
            avg_test_loss2 = np.mean(loss2)
            avg_test_acc = acc_overall
            print('avg_test_loss2 :', avg_test_loss1)
            print('avg_test_loss1 :', avg_test_loss2)
            print('avg_test_acc   :', avg_test_acc)
            writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter)
            writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter)
            writer['test'].add_scalar('log/acc', avg_test_acc, i_iter)
            writer['test'].add_scalar('log/mIoU', mIoU, i_iter)

            hist = np.zeros((num_cls, num_cls))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if max_mIoU < mIoU:
            max_mIoU = mIoU
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D2.pth'))
コード例 #18
0
def main():
    """Create the model and start the training.

    Trains the segmentation network against two discriminators
    (``model_D1`` / ``model_D2``) using source batches from ``GTA5DataSet``
    and target batches from ``cityscapesDataSet``.  On target images an
    additional image-level classification loss (log-mean-exp pooling of the
    second prediction head followed by BCE) is added to the adversarial loss.

    Every ``args.save_pred_every`` iterations a snapshot is written to
    ``args.snapshot_dir`` and the mIoU over ``testloader`` is appended to the
    file at ``args.results_dir``.  Training stops (after saving) once
    ``i_iter >= args.num_steps_stop - 1``.

    All configuration is read from the module-level ``args``.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'DeepLab'``.
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    # NOTE(review): the locals are named (h, w) here, but com_size is consumed
    # below as (width, height): PIL's resize takes (width, height) and the
    # upsampler is given (com_size[1], com_size[0]).  Confirm the expected
    # order of the --com-size argument.
    h, w = map(int, args.com_size.split(','))
    com_size = (h, w)

    cudnn.enabled = True
    torch.cuda.set_device(args.gpu)

    ############################
    # validation data
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir_target,
                                                   args.data_list_target_val,
                                                   crop_size=input_size_target,
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set_val),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # ``np.int`` was only an alias of the builtin ``int`` and has been removed
    # from recent NumPy releases.
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list = './dataset/cityscapes_list/label.txt'
    with open(label_path_list, 'r') as fp:
        gt_imgs = fp.read().splitlines()
    gt_imgs = [
        osp.join('./data/Cityscapes/data/gtFine/val', x) for x in gt_imgs
    ]

    interp_val = nn.UpsamplingBilinear2d(size=(com_size[1], com_size[0]))

    ############################

    # Create network.  The checkpoint is loaded as-is, so its keys must match
    # the model's state dict.
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        saved_state_dict = torch.load(
            args.restore_from,
            map_location=lambda storage, loc: storage.cuda(args.gpu))
        model.load_state_dict(saved_state_dict)
    else:
        # Fail with a clear message instead of an unbound-name error below.
        raise NotImplementedError('unsupported model: {}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1
    # NOTE(review): the kernel is hard-coded to 512x1024; the reshape of the
    # class labels below only works if this pools each class map to a single
    # value, i.e. presumably the target predictions are 512x1024 -- confirm.
    AvePool = torch.nn.AvgPool2d(kernel_size=(512, 1024))

    # Upper bound applied to exp(logits) before pooling (presumably to keep
    # the log of the pooled value finite).  Loop-invariant, so built once.
    exp_cap = Variable(torch.exp(torch.tensor(40.0))).cuda(args.gpu)

    def domain_labels(d_out, value):
        """Return a GPU tensor shaped like ``d_out`` and filled with ``value``."""
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(value)).cuda(args.gpu)

    for i_iter in range(args.num_steps):
        # The validation pass at the end of an iteration switches to eval
        # mode, so training mode is restored at the top of every iteration.
        model.train()
        # Never updated below (the first-head LSE loss is disabled); kept so
        # the log line retains its layout.
        loss_lse_target_two_value = 0
        loss_lse_target_value = 0
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-batches; every loss is
        # divided by iter_size so that the accumulated gradient is an average.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.item() / args.iter_size
            loss_seg_value2 += loss_seg2.data.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, class_label, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # Image-level classification loss: log of the average-pooled,
            # capped exp of the second head's logits, compared with the
            # image-level class labels through BCE-with-logits.
            class_label_target_lse = class_label.type(torch.FloatTensor)
            exp_target = torch.min(torch.exp(1 * pred_target2), exp_cap)
            lse = (1.0 / 1) * torch.log(AvePool(exp_target))
            loss_lse_target = bce_loss(
                lse,
                Variable(class_label_target_lse.reshape(lse.size())).cuda(
                    args.gpu))

            # The predictions are 4-D outputs of the bilinear upsampler, so
            # the softmax over classes is along dim=1 (this is also what the
            # deprecated implicit-dim call resolved to for 4-D input).
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Generator step: target predictions are labelled as source so
            # that the segmentation network learns to fool the discriminators.
            loss_adv_target1 = bce_loss(D_out1,
                                        domain_labels(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2,
                                        domain_labels(D_out2, source_label))

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2 +
                    0.2 * loss_lse_target)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.item(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.item(
            ) / args.iter_size
            loss_lse_target_value += loss_lse_target.data.item(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached: no gradient flows back into G)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, source_label))

            # The extra /2 averages the source and target halves of the
            # discriminator update.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.item()
            loss_D_value2 += loss_D2.data.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.item()
            loss_D_value2 += loss_D2.data.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()
        # Drop references to the large per-iteration tensors before the
        # snapshot / validation work below.
        del D_out1, D_out2, pred1, pred2, pred_target1, pred_target2, images, labels

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_lse_target = {8:.3f} loss_lse_target2 = {9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_lse_target_value,
                    loss_lse_target_two_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            # These file names do not contain the iteration number, so each
            # snapshot overwrites the previous one.
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5' + '.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + '_D2.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + '_D1.pth'))

            # Validation: accumulate a 19x19 confusion matrix over the
            # validation loader and log the resulting mIoU.
            hist = np.zeros((19, 19))
            model.eval()
            # ``volatile=True`` no longer disables autograd; ``no_grad`` is
            # what actually prevents graph construction during inference.
            with torch.no_grad():
                for index, batch in enumerate(testloader):
                    print(index)
                    image, _, _, _ = batch
                    output1, output2 = model(image.cuda(args.gpu))
                    pred = interp_val(output2)
                    pred = pred[0].permute(1, 2, 0)
                    pred = torch.max(pred, 2)[1].byte()
                    pred_cpu = pred.data.cpu().numpy()
                    del pred, output1, output2
                    # Ground-truth files are matched to predictions by
                    # position, which relies on testloader not shuffling.
                    label = Image.open(gt_imgs[index])
                    label = np.array(label.resize(com_size, Image.NEAREST))
                    label = label_mapping(label, mapping)
                    hist += fast_hist(label.flatten(), pred_cpu.flatten(), 19)
            mIoUs = per_class_iu(hist)
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            print(mIoU)
            with open(args.results_dir, 'a') as f:
                f.write('i_iter:{:d},        miou:{:0.5f} \n'.format(
                    i_iter, mIoU))
コード例 #19
0
def main():
    """Create the model and start the training.

    Trains a ``DeepLab`` segmentation network together with a single
    discriminator ``model_D1``.  Each item of ``data_loader(args)`` provides
    an anchor image/mask, a positive image/mask and a negative image/mask;
    the network is run on the concatenated positive and anchor images
    (guided by the positive mask) and supervised with the anchor mask.  The
    discriminator is trained to tell the network's softmax output (label 0)
    from the one-hot ground-truth mask (label 1).

    Snapshots are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and when
    ``count >= args.num_steps_stop - 1``.  Scalars are logged to TensorBoard
    when ``args.tensorboard`` is set.

    All configuration is read from the module-level ``args``.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'DeepLab'`` or
            ``args.gan`` is neither ``'Vanilla'`` nor ``'LS'``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # NOTE(review): both sizes are parsed but not used below -- the upsampler
    # is hard-coded to 416x416.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        # NOTE: no pretrained weights are restored here; args.restore_from is
        # not read by this function.
        model = DeepLab(backbone='resnet', output_stride=16)
    else:
        raise NotImplementedError

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D (no checkpoint is restored for the discriminator either)
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D1.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_loader = data_loader(args)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # Fail early instead of hitting an unbound name inside the loop.
        raise NotImplementedError('unsupported GAN loss: {}'.format(args.gan))
    seg_loss = torch.nn.CrossEntropyLoss()

    interp = nn.Upsample(size=(416, 416), mode='bilinear', align_corners=True)

    # set up tensorboard
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    count = args.start_count  # iteration counter
    for dat in train_loader:
        if count > args.num_steps:
            break

        loss_seg_value1_anchor = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, count)

        optimizer_D1.zero_grad()

        adjust_learning_rate_D(optimizer_D1, count)

        # NOTE(review): every sub-iteration re-uses the same ``dat``; with
        # iter_size > 1 the same batch is accumulated several times.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            # (Translated from the original comment.)  E.g. for group=0 the
            # training samples cover 15 classes and the validation set 5.
            # Two training classes are picked at random; two images are taken
            # from one of them as the anchor and the positive sample (same
            # class), and one image from the other class as the negative
            # sample.  The anchor plays the role of the query image.
            #
            # ``dat`` holds the anchor image and mask, the positive image and
            # mask, and the negative image and mask; the negative pair is not
            # used in this function.
            anchor_img, anchor_mask, pos_img, pos_mask, _neg_img, _neg_mask = dat

            # Move inputs to the selected device (honours args.cpu instead of
            # unconditionally calling .cuda()).
            anchor_img = anchor_img.to(device)
            anchor_mask = anchor_mask.to(device)
            pos_img = pos_img.to(device)
            pos_mask = pos_mask.to(device)

            # Add a channel axis to both masks.
            anchor_mask = torch.unsqueeze(anchor_mask, dim=1)
            pos_mask = torch.unsqueeze(pos_mask, dim=1)
            samples = torch.cat([pos_img, anchor_img], 0)

            pred = model(samples, pos_mask)
            pred = interp(pred)

            loss_seg1_anchor = seg_loss(
                pred,
                anchor_mask.squeeze().unsqueeze(0).long())
            # ``pred`` is the 4-D output of the bilinear upsampler, so the
            # class softmax is along dim=1 (what the deprecated implicit-dim
            # call resolved to for 4-D input).
            D_out1 = model_D1(F.softmax(pred, dim=1))
            # Generator step: the prediction is labelled 1 (the label the
            # discriminator is trained to give ground truth) so the network
            # learns to fool the discriminator.
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(1).to(device))

            loss = loss_seg1_anchor + args.lambda_adv_target1 * loss_adv_target1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1_anchor += loss_seg1_anchor.item() / args.iter_size
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D: bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            # train with the (detached) prediction, labelled 0
            pred_target1 = pred.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(0).to(device))
            # The extra /2 averages the two halves of the discriminator update.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with GT, labelled 1
            anchor_gt = Variable(one_hot(anchor_mask)).to(device)
            D_out1 = model_D1(anchor_gt)
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(1).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()

        count = count + 1
        if args.tensorboard:
            scalar_info = {
                'loss_seg1_anchor': loss_seg_value1_anchor,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_D1': loss_D_value1,
            }

            # Only every 10th iteration is written to TensorBoard.
            if count % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, count)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
            .format(count, args.num_steps, loss_seg_value1_anchor,
                    loss_adv_target_value1, loss_D_value1))

        if count >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(args.num_steps_stop) + '_D1.pth'))
            break

        if count % args.save_pred_every == 0 and count != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'voc2012_' + str(count) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_' + str(count) + '_D1.pth'))

    if args.tensorboard:
        writer.close()
コード例 #20
0
def main():
    """Create the model and start the training.

    Parses the command-line arguments into the module-level ``args``,
    optionally initialises distributed training, builds the segmentation
    network selected by ``args.model`` together with two discriminators,
    and runs the adversarial training loop over a source loader
    (``GTA5BDDDataSet``) and a target loader (``BDDDataSet``).

    Only ``model`` and ``model_D1`` receive a loss and an optimizer step in
    this function; ``model_D2`` and ``optimizer_D2`` are created but never
    trained, so the ``*2`` values that are logged stay at zero.

    Snapshots, console output and TensorBoard scalars are produced on rank 0
    only.

    Raises:
        ValueError: if ``args.model`` is not one of ``'Deeplab'``,
            ``'DeeplabVGG'`` or ``'DeeplabVGGBN'``.
    """
    global args
    args = get_arguments()
    if args.dist:
        init_dist(args.launcher, backend=args.backend)
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # `strict` is a load_state_dict() option, not a torch.load() one.
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            # Drop the leading key component; with 19 classes the 'layer5'
            # entries of the checkpoint are not copied.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        # Apply the weights for both the URL and the local-file case.
        model.load_state_dict(saved_state_dict, strict=False)
        del saved_state_dict
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)

    if rank == 0:
        print(model)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D)
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict

    model_D2.train()
    if args.dist:
        broadcast_params(model_D2)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_data = GTA5BDDDataSet(args.data_dir,
                                args.data_list,
                                max_iters=args.num_steps * args.iter_size *
                                args.batch_size,
                                crop_size=input_size,
                                scale=args.random_scale,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)
    # DataLoader does not allow shuffle=True together with a sampler.
    trainloader = data.DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=train_sampler is None,
                                  num_workers=args.num_workers,
                                  pin_memory=False,
                                  sampler=train_sampler)

    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target,
                             args.data_list_target,
                             max_iters=args.num_steps * args.iter_size *
                             args.batch_size,
                             crop_size=input_size_target,
                             scale=False,
                             mirror=args.random_mirror,
                             mean=IMG_MEAN,
                             set=args.set)
    target_sampler = None
    if args.dist:
        target_sampler = DistributedSampler(target_data)
    targetloader = data.DataLoader(target_data,
                                   batch_size=args.batch_size,
                                   shuffle=target_sampler is None,
                                   num_workers=args.num_workers,
                                   pin_memory=False,
                                   sampler=target_sampler)

    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard and rank == 0:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()
    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        # The second branch is never trained here; these stay at zero and
        # are kept only so the log format is unchanged.
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # NOTE(review): `size` comes from the collated source batch and
            # is indexed as (size[1], size[0]) -- confirm its layout against
            # GTA5BDDDataSet.
            interp = nn.Upsample(size=(size[1], size[0]),
                                 mode='bilinear',
                                 align_corners=True)

            pred1 = model(images)
            pred1 = interp(pred1)

            loss_seg1 = seg_loss(pred1, labels)

            loss = loss_seg1

            # proper normalization
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            _, batch = next(targetloader_iter)
            # train with target
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = model(images)
            pred_target1 = interp_target(pred_target1)

            # dim=1 is what the implicit (deprecated) softmax picks for 4-D
            # input; spelling it out removes the warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            # Generator step: push target predictions towards the source label.
            loss_adv_target1 = bce_loss(
                D_out1, torch.full_like(D_out1, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size

            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(D_out1,
                               torch.full_like(D_out1, source_label))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(D_out1,
                               torch.full_like(D_out1, target_label))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

        # Reduce gradients across ranks once, after all sub-iterations have
        # been accumulated. Doing it inside the loop would re-reduce the
        # gradients of earlier sub-iterations whenever iter_size > 1.
        if args.dist:
            average_gradients(model)
            average_gradients(model_D1)
            average_gradients(model_D2)

        optimizer.step()
        optimizer_D1.step()

        if rank == 0:
            if args.tensorboard and i_iter % 10 == 0:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))

        # Every rank must leave the loop at the same step; otherwise the
        # remaining ranks keep issuing collective calls that rank 0 never
        # joins. Only rank 0 writes the final snapshot.
        if i_iter >= args.num_steps_stop - 1:
            if rank == 0:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            break

        if rank == 0 and i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_D1.pth'))
    print(args.snapshot_dir)
    if args.tensorboard and rank == 0:
        writer.close()
コード例 #21
0
def main():
    """Create the model and start the training.

    Builds the two-output segmentation network and two discriminators,
    creates two source loaders, one or two target loaders and a validation
    loader, then runs the adversarial training loop. Reads its configuration
    from the module-level ``args``.

    During training it
      * writes rotating snapshots ``model_<k>.pth`` (plus ``_D1``/``_D2``)
        every ``args.save_pred_every`` iterations, cycling ``k`` through
        ``args.num_models_keep`` slots;
      * runs ``validation`` every ``args.val_every`` iterations (and at
        iteration 1) and stores ``bestmodel_<i>.pth`` when the returned mIoU
        beats entry ``i`` of the running top-5 list;
      * logs the six loss values to the ``SummaryWriter`` every iteration.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` or
            ``args.training_option`` is not 1 or 2.
    """
    model_num = 0  # The number of model (for saving models)

    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    random.seed(args.random_seed)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    writer = SummaryWriter(log_dir=args.snapshot_dir)

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    cudnn.benchmark = True

    # init G
    if args.model != 'DeepLab':
        raise ValueError('Unsupported model: {}'.format(args.model))

    if args.training_option == 1:
        model = Res_Deeplab(num_classes=args.num_classes,
                            num_layers=args.num_layers,
                            dropout=args.dropout,
                            after_layer=args.after_layer)
    elif args.training_option == 2:
        model = Res_Deeplab2(num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported training_option: {}'.format(
            args.training_option))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()

    # Dump checkpoint keys and model keys to help pick args.i_parts_index.
    for k in saved_state_dict:
        print(k)

    for k in new_params:
        print(k)

    for i in saved_state_dict:
        i_parts = i.split('.')
        # Strip the first `i_parts_index` components so the checkpoint key
        # lines up with the model's own parameter names.
        target_key = '.'.join(i_parts[args.i_parts_index:])
        if target_key not in new_params:
            continue
        print("Restored...")
        if args.not_restore_last == True:
            # Leave the 'layer5' / 'layer6' parameters at their fresh
            # initialisation.
            if i_parts[args.i_parts_index] in ('layer5', 'layer6'):
                continue
        new_params[target_key] = saved_state_dict[i]

    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes,
                               extra_layers=args.extra_discriminator_layers)
    model_D2 = FCDiscriminator(num_classes=args.num_classes, extra_layers=0)

    model_D1.train()
    model_D1.cuda(args.gpu)
    model_D2.train()
    model_D2.cuda(args.gpu)

    trainloader = data.DataLoader(sourceDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        random_rotate=False,
        random_flip=args.augment_1,
        random_lighting=args.augment_1,
        random_blur=args.augment_1,
        random_scaling=args.augment_1,
        mean=IMG_MEAN_SOURCE,
        ignore_label=args.ignore_label),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    trainloader2 = data.DataLoader(sourceDataSet(
        args.data_dir2,
        args.data_list2,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        random_rotate=False,
        random_flip=args.augment_2,
        random_lighting=args.augment_2,
        random_blur=args.augment_2,
        random_scaling=args.augment_2,
        mean=IMG_MEAN_SOURCE2,
        ignore_label=args.ignore_label),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    trainloader_iter2 = enumerate(trainloader2)

    if args.num_of_targets > 1:

        IMG_MEAN_TARGET1 = np.array(
            (101.41694189393208, 89.68194541655483, 77.79408426901315),
            dtype=np.float32)  # crowdai all BGR

        targetloader1 = data.DataLoader(isprsDataSet(
            args.data_dir_target1,
            args.data_list_target1,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            scale=False,
            mean=IMG_MEAN_TARGET1,
            ignore_label=args.ignore_label),
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)

        targetloader_iter1 = enumerate(targetloader1)

    targetloader = data.DataLoader(isprsDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mean=IMG_MEAN_TARGET,
        ignore_label=args.ignore_label),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    valloader = data.DataLoader(valDataSet(args.data_dir_val,
                                           args.data_list_val,
                                           crop_size=input_size_target,
                                           mean=IMG_MEAN_TARGET,
                                           scale=args.val_scale,
                                           mirror=False),
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # The adversarial loss is the same unweighted BCE regardless of
    # args.weighted_loss (both branches of the former if/else were identical).
    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[0],
                                      input_size_target[1]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # Which layers to freeze
    non_trainable(args.dont_train, model)

    # List saving all best 5 mIoU's
    best_mIoUs = [0.0, 0.0, 0.0, 0.0, 0.0]

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            ################################## train G #################################

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            ################################## train with source #################################

            # Best-effort: a source batch that fails is dropped and the next
            # one is tried. StopIteration from an exhausted loader is not
            # caught and ends the run.
            while True:
                try:
                    _, batch = next(
                        trainloader_iter)  # Cityscapes, only discriminator1
                    images, labels, _, train_name = batch
                    images = images.cuda(args.gpu)

                    _, batch = next(
                        trainloader_iter2
                    )  # Main (airsim) discriminator2 and final output
                    images2, labels2, size, train_name2 = batch
                    images2 = images2.cuda(args.gpu)

                    pred1, _ = model(images)
                    pred1 = interp(pred1)

                    _, pred2 = model(images2)
                    pred2 = interp(pred2)

                    loss_seg1 = loss_calc(pred1, labels, args.gpu,
                                          args.ignore_label, train_name,
                                          weights1)
                    loss_seg2 = loss_calc(pred2, labels2, args.gpu,
                                          args.ignore_label, train_name2,
                                          weights2)

                    loss = loss_seg2 + args.lambda_seg * loss_seg1

                    # proper normalization
                    loss = loss / args.iter_size
                    loss.backward()

                    loss_seg_value1 += loss_seg1.item() / args.iter_size
                    loss_seg_value2 += loss_seg2.item() / args.iter_size

                    break
                except (RuntimeError, AssertionError, AttributeError) as err:
                    # Make the skipped batch visible instead of retrying
                    # silently.
                    print('Skipping source batch after error: {!r}'.format(
                        err))
                    continue

            ###################################################################################################

            # Keep the target batch under its own name: the source `images`
            # is needed again below for adv_option 2/3.
            _, batch = next(targetloader_iter)
            images_target, _, _ = batch
            images_target = images_target.cuda(args.gpu)

            pred_target1, pred_target2 = model(images_target)

            if args.num_of_targets > 1:
                # With a second target set, the first output head is
                # evaluated on that set instead.
                _, batch1 = next(targetloader_iter1)
                images1, _, _ = batch1
                images1 = images1.cuda(args.gpu)

                pred_target1, _ = model(images1)

            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            ################################## train with target #################################
            if args.adv_option == 1 or args.adv_option == 3:

                # dim=1 matches what the implicit softmax picks for 4-D input.
                D_out1 = model_D1(F.softmax(pred_target1, dim=1))
                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_adv_target1 = bce_loss(
                    D_out1, torch.full_like(D_out1, source_label))
                loss_adv_target2 = bce_loss(
                    D_out2, torch.full_like(D_out2, source_label))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size
                loss.backward()

                loss_adv_target_value1 += (loss_adv_target1.item() /
                                           args.iter_size)
                loss_adv_target_value2 += (loss_adv_target2.item() /
                                           args.iter_size)

            ###################################################################################################
            if args.adv_option == 2 or args.adv_option == 3:
                # Fresh forward pass on the SOURCE batches; the generator is
                # pushed to make source predictions carry the target label.
                pred1, _ = model(images)
                pred1 = interp(pred1)
                _, pred2 = model(images2)
                pred2 = interp(pred2)

                D_out1 = model_D1(F.softmax(pred1, dim=1))
                D_out2 = model_D2(F.softmax(pred2, dim=1))

                loss_adv_target1 = bce_loss(
                    D_out1, torch.full_like(D_out1, target_label))
                loss_adv_target2 = bce_loss(
                    D_out2, torch.full_like(D_out2, target_label))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size
                loss.backward()

                loss_adv_target_value1 += (loss_adv_target1.item() /
                                           args.iter_size)
                loss_adv_target_value2 += (loss_adv_target2.item() /
                                           args.iter_size)

            ###################################################################################################
            ################################## train D #################################

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            ################################## train with source #################################
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1,
                               torch.full_like(D_out1, source_label))
            loss_D2 = bce_loss(D_out2,
                               torch.full_like(D_out2, source_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            ################################# train with target #################################

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1,
                               torch.full_like(D_out1, target_label))
            loss_D2 = bce_loss(D_out2,
                               torch.full_like(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            # Rotating snapshot slots: model_0 .. model_<num_models_keep-1>.
            if model_num != args.num_models_keep:
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'model_' + str(model_num) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'model_' + str(model_num) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'model_' + str(model_num) + '_D2.pth'))
                model_num = model_num + 1
            if model_num == args.num_models_keep:
                model_num = 0

        # Validation
        if (i_iter % args.val_every == 0 and i_iter != 0) or i_iter == 1:
            mIoU = validation(valloader, model, interp_target, writer, i_iter,
                              [37, 41, 10])
            for i, best in enumerate(best_mIoUs):
                if best < mIoU:
                    # NOTE(review): slot i is overwritten without shifting
                    # the lower-ranked files, so bestmodel_<j>.pth for j > i
                    # may no longer correspond to best_mIoUs[j].
                    torch.save(
                        model.state_dict(),
                        osp.join(args.snapshot_dir,
                                 'bestmodel_' + str(i) + '.pth'))
                    torch.save(
                        model_D1.state_dict(),
                        osp.join(args.snapshot_dir,
                                 'bestmodel_' + str(i) + '_D1.pth'))
                    torch.save(
                        model_D2.state_dict(),
                        osp.join(args.snapshot_dir,
                                 'bestmodel_' + str(i) + '_D2.pth'))
                    best_mIoUs.append(mIoU)
                    print("Saved model at iteration %d as the best %d" %
                          (i_iter, i))
                    best_mIoUs.sort(reverse=True)
                    best_mIoUs = best_mIoUs[:5]
                    break

        # Save for tensorboardx
        writer.add_scalar('loss_seg_value1', loss_seg_value1, i_iter)
        writer.add_scalar('loss_seg_value2', loss_seg_value2, i_iter)
        writer.add_scalar('loss_adv_target_value1', loss_adv_target_value1,
                          i_iter)
        writer.add_scalar('loss_adv_target_value2', loss_adv_target_value2,
                          i_iter)
        writer.add_scalar('loss_D_value1', loss_D_value1, i_iter)
        writer.add_scalar('loss_D_value2', loss_D_value2, i_iter)

    writer.close()
コード例 #22
0
def main():
    """Create the model and start the training.

    Builds the segmentation network chosen by ``args.model`` ('ResNet' or
    'VGG'), an ``FCDiscriminator`` that operates on the network's feature
    output, the GTA5 (source) and Cityscapes (target) loaders and both
    optimizers, then runs up to ``args.num_steps`` iterations. Each iteration:

    1. updates the segmentation network with a cross-entropy loss on one
       source batch and an adversarial loss on one target batch while the
       discriminator parameters are frozen;
    2. updates the discriminator on the detached source and target features.

    Losses are printed every iteration, written to TensorBoard every 10
    iterations when ``args.tensorboard`` is set, and checkpoints are written
    to ``args.snapshot_dir``.
    """

    device = torch.device("cpu" if args.cpu else "cuda")

    # Both size arguments are "w,h" strings.
    src_w, src_h = map(int, args.input_size.split(','))
    input_size = (src_w, src_h)

    tgt_w, tgt_h = map(int, args.input_size_target.split(','))
    input_size_target = (tgt_w, tgt_h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            pretrained = model_zoo.load_url(args.restore_from)
        else:
            pretrained = torch.load(args.restore_from)

        merged_state = model.state_dict().copy()
        drop_layer5 = args.num_classes == 19
        for key, value in pretrained.items():
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the leading
            # component is stripped before the entry is copied over.
            parts = key.split('.')
            if drop_layer5 and parts[1] == 'layer5':
                continue
            merged_state['.'.join(parts[1:])] = value
        model.load_state_dict(merged_state)
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D: its input width depends on which backbone produced the features
    disc_in_channels = 2048 if args.model == 'ResNet' else 1024
    model_D = FCDiscriminator(num_classes=disc_in_channels).to(device)
    model_D.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    source_dataset = GTA5DataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN)
    trainloader = data.DataLoader(source_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    target_dataset = cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set)
    targetloader = data.DataLoader(target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample takes (height, width) while the sizes above are (width, height).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    def set_discriminator_trainable(flag):
        """Switch requires_grad on every discriminator parameter."""
        for weight in model_D.parameters():
            weight.requires_grad = flag

    def domain_target(logits, label):
        """Return a float tensor shaped like ``logits`` filled with ``label``."""
        return torch.FloatTensor(logits.data.size()).fill_(label).to(device)

    def save_checkpoint(tag):
        """Store the segmentation network and the discriminator under ``tag``."""
        torch.save(model.state_dict(),
                   osp.join(args.snapshot_dir, 'GTA5_' + str(tag) + '.pth'))
        torch.save(model_D.state_dict(),
                   osp.join(args.snapshot_dir, 'GTA5_' + str(tag) + '_D.pth'))

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # ---- train G ----

        # don't accumulate grads in D
        set_discriminator_trainable(False)

        # source batch: supervised segmentation loss
        _, sample = next(trainloader_iter)
        images, labels, _, _ = sample
        images = images.to(device)
        labels = labels.long().to(device)

        feature, prediction = model(images)
        prediction = interp(prediction)
        objective = seg_loss(prediction, labels)
        objective.backward()
        loss_seg = objective.item()

        # target batch: target features are pushed towards the source label
        _, sample = next(targetloader_iter)
        images, _, _ = sample
        images = images.to(device)

        feature_target, _ = model(images)
        _, D_out = model_D(feature_target)
        loss_adv_target = bce_loss(D_out, domain_target(D_out, source_label))
        objective = args.lambda_adv_target * loss_adv_target
        objective.backward()
        loss_adv_target_value = loss_adv_target.item()

        # ---- train D ----

        # bring back requires_grad
        set_discriminator_trainable(True)

        # source features: classification head + domain head
        feature = feature.detach()
        cla, D_out = model_D(feature)
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)

        loss_D = bce_loss(D_out, domain_target(D_out, source_label)) / 2
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()

        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # target features: domain head only
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        loss_D = bce_loss(D_out, domain_target(D_out, target_label)) / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        # scalars are only written every 10th iteration
        if args.tensorboard and i_iter % 10 == 0:
            writer.add_scalar('loss_seg', loss_seg, i_iter)
            writer.add_scalar('loss_cla', loss_cla_value, i_iter)
            writer.add_scalar('loss_adv_target', loss_adv_target_value,
                              i_iter)
            writer.add_scalar('loss_D', loss_D_value, i_iter)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value))

        # early stop: write the final checkpoint and leave the loop
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_checkpoint(args.num_steps_stop)
            break

        # periodic snapshot (never at iteration 0)
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_checkpoint(i_iter)

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training.

    Trains a two-output ``DeeplabMulti`` segmentation network on
    ``SyntheticSmokeTrain`` (source) batches while two ``FCDiscriminator``
    networks, one per output, learn to separate source predictions from
    ``SimpleSmokeVal`` (target) predictions. Gradients are accumulated over
    ``args.iter_size`` sub-iterations before each optimizer step.

    Scalars are written to TensorBoard unless ``args.no_logging`` is set, and
    checkpoints are written to ``args.snapshot_dir``.

    Fixes relative to the previous version:
      * removed the loop over ``interp_domain.parameters()`` and the
        ``min_class1``/``min_class2`` computation; both referenced names
        whose definitions were commented out (NameError) and their results
        were unused;
      * ``writer`` is only used when logging is enabled, and is closed when
        training ends;
      * the target loader length is no longer reported as "train".
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    criterion = DiceBCELoss()

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from is not None:
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)

            new_params = model.state_dict().copy()
            for key, value in saved_state_dict.items():
                # Keys look like "Scale.layer5.conv2d_list.3.weight": the
                # leading component is dropped, and layer5 entries are
                # skipped when training with 19 classes.
                key_parts = key.split('.')
                if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                    new_params['.'.join(key_parts[1:])] = value
            model.load_state_dict(new_params)

    # TensorBoard writer; stays None when logging is disabled so that every
    # later use can be guarded.
    writer = None
    if not args.no_logging:
        log_dir = os.path.join(args.log_dir, args.exp_dir)
        os.makedirs(log_dir, exist_ok=True)
        if args.exp_name == "":
            exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
        else:
            exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
        writer = SummaryWriter(log_dir)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(SyntheticSmokeTrain(
        args={},
        dataset_limit=args.num_steps * args.iter_size * args.batch_size,
        image_shape=input_size,
        dataset_mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    print("Length of train dataloader: ", len(trainloader))
    targetloader = data.DataLoader(SimpleSmokeVal(args={},
                                                  image_size=input_size_target,
                                                  dataset_mean=IMG_MEAN),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    print("Length of target dataloader: ", len(targetloader))
    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    # Upsample takes (height, width) while the sizes above are (width, height).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def set_requires_grad(module, flag):
        """Switch requires_grad on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = flag

    def domain_labels(d_out, label):
        """Return a float tensor shaped like ``d_out`` filled with ``label``."""
        return torch.FloatTensor(
            d_out.data.size()).fill_(label).cuda(args.gpu)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            set_requires_grad(model_D1, False)
            set_requires_grad(model_D2, False)

            # train with source
            # NOTE: StopIteration propagates if the source loader runs out.
            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            # NOTE: StopIteration propagates if the target loader runs out.
            _, batch = next(targetloader_iter)

            images, _, _ = batch
            images = images.cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Target predictions are labelled as source here so that the
            # segmentation network is trained against the discriminators.
            loss_adv_target1 = bce_loss(D_out1,
                                        domain_labels(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2,
                                        domain_labels(D_out2, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            set_requires_grad(model_D1, True)
            set_requires_grad(model_D2, True)

            # train with source (detached so no gradient reaches the
            # segmentation network)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, source_label))

            # halved because the source and target halves are summed
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))
        if writer is not None:
            writer.add_scalar('loss/train/segmentation/1', loss_seg_value1,
                              i_iter)
            writer.add_scalar('loss/train/segmentation/2', loss_seg_value2,
                              i_iter)
            writer.add_scalar('loss/train/adversarial/1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar('loss/train/adversarial/2',
                              loss_adv_target_value2, i_iter)
            writer.add_scalar('loss/train/domain/1', loss_D_value1, i_iter)
            writer.add_scalar('loss/train/domain/2', loss_D_value2, i_iter)

        # early stop: write the final checkpoint and leave the loop
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # periodic snapshot (never at iteration 0)
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '_D2.pth'))
        if writer is not None:
            writer.flush()

    # The early-stop `break` skips the per-iteration flush, so close (and
    # thereby flush) the writer once training is over.
    if writer is not None:
        writer.close()
コード例 #24
0
ファイル: after.py プロジェクト: Euiyeon-Kim/DLOW-Pytorch
def main():
    """Create the model and start the training.

    Trains a two-output ``DeeplabMulti`` segmentation network on GTA5
    (source) batches while two ``FCDiscriminator`` networks, one per output,
    learn to separate source predictions from Cityscapes (target)
    predictions. The adversarial loss is scaled by ``sqrt(1 - domainess)``,
    where ``domainess`` is delivered with every source batch. Gradients are
    accumulated over ``args.iter_size`` sub-iterations before each optimizer
    step.

    Losses are printed every iteration, written to TensorBoard every 10
    iterations when ``args.tensorboard`` is set, and checkpoints are written
    to ``args.snapshot_dir``.

    Fixes relative to the previous version:
      * the adversarial weight ``adw`` is moved to ``device`` before it is
        multiplied with losses that live on ``device``;
      * ``F.softmax`` is called with an explicit ``dim=1`` instead of the
        deprecated implicit dimension.
    """

    # Device setup
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Parse the "w,h" input sizes for source and target
    # (original note: both are resized to 1280 * 720).
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create the model
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            # download the pretrained weights
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # a .pth file was given directly
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key, value in saved_state_dict.items():
            # Keys look like "Scale.layer5.conv2d_list.3.weight": the leading
            # component is dropped, and layer5 entries are skipped when
            # training with 19 classes.
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = value
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # Create the discriminators
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D2.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list,
                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                          crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample takes (height, width) while the sizes above are (width, height).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    def set_requires_grad(module, flag):
        """Switch requires_grad on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = flag

    def domain_labels(d_out, label):
        """Return a float tensor shaped like ``d_out`` filled with ``label``."""
        return torch.FloatTensor(d_out.data.size()).fill_(label).to(device)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            set_requires_grad(model_D1, False)
            set_requires_grad(model_D2, False)

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _, domainess = batch
            # Adversarial weight; moved to `device` because it is multiplied
            # with losses computed on `device` below.
            # NOTE(review): presumably one value per sample. bce_loss returns
            # a scalar, so the weighted loss has adw's shape, and the
            # backward()/item() calls below then need a single-element tensor
            # (batch size 1) - confirm.
            adw = torch.sqrt(1 - domainess).float().to(device)
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is the dimension the implicit form selects for 4-D input.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Target predictions are labelled as source here so that the
            # segmentation network is trained against the discriminators.
            loss_adv_target1 = adw * bce_loss(
                D_out1, domain_labels(D_out1, source_label))
            loss_adv_target2 = adw * bce_loss(
                D_out2, domain_labels(D_out2, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            set_requires_grad(model_D1, True)
            set_requires_grad(model_D2, True)

            # train with source (detached so no gradient reaches the
            # segmentation network)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, source_label))

            # halved because the source and target halves are summed
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, domain_labels(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, domain_labels(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }

            # scalars are only written every 10th iteration
            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                loss_adv_target_value1, loss_adv_target_value2,
                loss_D_value1, loss_D_value2))

        # early stop: write the final checkpoint and leave the loop
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # periodic snapshot (never at iteration 0)
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
コード例 #25
0
ファイル: train.py プロジェクト: MrtBian/AdvSemiSeg
def main():
    """Adversarially train a Res_Deeplab segmentation net with an FCDiscriminator.

    All settings are read from the module-level ``args``.  Per outer iteration
    the function accumulates gradients over ``args.iter_size`` sub-iterations
    and then steps both optimizers once.  Each sub-iteration:

    1. optionally (``args.lambda_semi > 0`` and ``i_iter >= args.semi_start``)
       runs a semi-supervised step on a batch from the "remain" loader, using
       the argmax of the prediction as label and setting pixels whose
       discriminator sigmoid output is below ``args.mask_T`` to 255;
    2. trains the segmentation net on a labelled batch with the segmentation
       loss plus ``args.lambda_adv_pred`` times the adversarial loss;
    3. trains the discriminator on the detached prediction (label 0) and on
       one-hot ground truth (label 1).

    Snapshots are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and at the final iteration.  The
    elapsed time printed at the end is measured from the module-level
    ``start`` timestamp.
    """
    # Parse the "h,w" string from the arguments into a pair of integers.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = False
    gpu = args.gpu

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like):
    # an entry of the model's state dict (parameters and persistent buffers,
    # keyed by name) is overwritten only when the checkpoint holds a tensor
    # with the same name and the same size.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # switch to training mode
    model.train()
    cudnn.benchmark = True

    model.cuda(gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    if args.partial_data is None:
        # NOTE(review): `trainloader_remain` is only created in the else-branch
        # below, so the semi-supervised step in the training loop raises
        # NameError when partial_data is None and lambda_semi > 0.
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle streams are binary, so the file must be opened with 'rb'
            # (text mode fails on Python 3).
            # NOTE(review): unpickling executes arbitrary code - only load id
            # files from a trusted source.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the id order next to the snapshots; the context manager
        # makes sure the file is flushed and closed.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # The first `partial_size` ids feed the labelled and GT loaders,
        # the rest feed the "remain" loader used by the semi-supervised step.
        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, sampler=train_sampler, num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, sampler=train_gt_sampler, num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): input_size is (h, w), so this passes (w, h) as the output
    # size. Harmless for square inputs; confirm the intended order before
    # training with a non-square input size.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        print("Iter:", i_iter)
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    # loader exhausted: start another pass over the data
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch

                images = Variable(images).cuda(gpu)

                pred = interp(model(images))
                D_out = interp(model_D(F.softmax(pred)))

                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # produce ignore mask: pixels the discriminator scores below the threshold
                semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                # pseudo labels are the argmax class, with masked pixels set to 255
                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt[semi_ignore_mask] = 255

                semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                if semi_ratio == 0.0:
                    # every pixel was masked out: nothing to learn from this batch
                    loss_semi_value += 0
                else:
                    semi_gt = torch.FloatTensor(semi_gt)

                    loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
                    loss_semi = loss_semi / args.iter_size
                    loss_semi.backward()
                    # NOTE(review): `.numpy()[0]` (here and below) requires the
                    # loss to be a 1-element, 1-d tensor; it fails on 0-d loss
                    # tensors - confirm against the PyTorch version in use.
                    loss_semi_value += loss_semi.data.cpu().numpy()[0] / args.lambda_semi
            else:
                loss_semi = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                # loader exhausted: start another pass over the data
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch

            images = Variable(images).cuda(gpu)

            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred)))

            # the generator is rewarded when D labels its prediction as ground truth
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy()[0] / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred (detached so no gradient reaches the segmentation net)
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                # loader exhausted: start another pass over the data
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    # `start` is expected to be a module-level timestamp taken at script start.
    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train the saliency model adversarially against a single discriminator.

    All settings come from the module-level ``args``; tensors are moved to the
    module-level ``device``.  The function

    * creates a fresh ``run-<n>`` directory under ``args.snapshot_dir`` and
      rewrites ``args.file_dir`` (text log) and ``args.snapshot_dir``
      (``run-<n>/models``) to point into it;
    * builds the model via ``build_model()``, restores it from
      ``args.restore_from``, and builds an ``FCDiscriminator``;
    * for every batch, back-propagates the saliency BCE loss, the adversarial
      loss on the edge-image prediction and the two discriminator losses, and
      steps both optimizers every ``args.iter_size`` batches;
    * logs running losses and periodically saves model/discriminator weights;
    * after the epoch with index 7, scales both learning rates by 0.1 and
      rebuilds both optimizers.
    """
    # make dir
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # pick the first unused run index so earlier runs are never overwritten
    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # enable cuDNN and its benchmark mode
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    model.train()
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))

    # create the discriminator
    model_D1 = FCDiscriminator(num_classes=1).to(device)
    model_D1.train()
    model_D1.apply(weights_init)

    # create optimizers; the first one covers every trainable parameter of the model
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # loss used for the adversarial and the discriminator terms
    bce_loss = torch.nn.BCEWithLogitsLoss()

    # write the run header (start time and hyper-parameters) to the log file
    with open(args.file_dir, 'a') as f:
        f.write('strat time: ' + str(datetime.now()) + '\n\n')

        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('wight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')

        f.write('eptch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    # labels for adversarial training: markers for the two domains
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    iter_num = len(picloader.dataset) // args.batch_size
    aveGrad = 0  # number of batches accumulated since the last optimizer step

    # Logging interval in iterations.  Clamped to at least 1 so a batch size
    # larger than show_every cannot turn the modulo below into a division by zero.
    show_interval = max(1, args.show_every // args.batch_size)

    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        model.zero_grad()

        for i_iter, data_batch in enumerate(picloader):

            sal_image, sal_label, edge_image = data_batch[
                'sal_image'], data_batch['sal_label'], data_batch['edge_image']
            # skip samples whose image and label spatial sizes disagree
            if (sal_image.size(2) != sal_label.size(2)) or (
                    sal_image.size(3) != sal_label.size(3)):
                print('IMAGE ERROR, PASSING```')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING```\n')
                continue
            sal_image, sal_label, edge_image = Variable(sal_image), Variable(
                sal_label), Variable(edge_image)
            sal_image, sal_label, edge_image = sal_image.to(
                device), sal_label.to(device), edge_image.to(device)

            sal_pred = model(sal_image)
            edge_pred = model(edge_image)

            # train G (discriminator frozen so it accumulates no gradients)
            for param in model_D1.parameters():
                param.requires_grad = False

            sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred,
                                                               sal_label,
                                                               reduction='sum')
            sal_loss = sal_loss_fuse / (args.iter_size * args.batch_size)
            loss_seg_value1 += sal_loss.data

            sal_loss.backward()

            sD_out = model_D1(edge_pred)
            # BCE loss: while training G the loss is low when the discriminator
            # classifies the edge-image prediction as the saliency label.  The
            # target is a tensor of the same size as sD_out filled with that label.
            loss_adv_target1 = bce_loss(
                sD_out,
                torch.FloatTensor(sD_out.data.size()).fill_(salLabel).to(
                    device))
            sd_loss = loss_adv_target1 / (args.iter_size * args.batch_size)
            loss_adv_target_value1 += sd_loss.data  # unweighted value, kept for logging only

            sd_loss = sd_loss * args.lambda_adv_target2
            sd_loss.backward()

            # train D
            for param in model_D1.parameters():
                param.requires_grad = True

            # detach so the discriminator losses do not back-propagate into the model
            sal_pred = sal_pred.detach()
            edge_pred = edge_pred.detach()

            ss_out = model_D1(sal_pred)
            ss_loss = bce_loss(
                ss_out,
                torch.FloatTensor(
                    ss_out.data.size()).fill_(salLabel).to(device))
            ss_Loss = ss_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += ss_Loss.data

            ss_Loss.backward()

            se_out = model_D1(edge_pred)
            se_loss = bce_loss(
                se_out,
                torch.FloatTensor(
                    se_out.data.size()).fill_(edgeLabel).to(device))
            se_Loss = se_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += se_Loss.data

            se_Loss.backward()

            # step both optimizers once every iter_size accumulated batches
            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                optimizer_D1.step()
                optimizer_D1.zero_grad()
                aveGrad = 0

            if i_iter % show_interval == 0:
                # same message goes to stdout and to the log file
                log_msg = (
                    'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
                    .format(i_iter, iter_num, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size))
                print(log_msg)
                with open(args.file_dir, 'a') as f:
                    f.write(log_msg + '\n')

                # running losses are reset after every report
                loss_seg_value1, loss_adv_target_value1, loss_D_value1 = 0, 0, 0

            # snapshot at the last full batch of the epoch and at every
            # save_pred_every-th iteration except iteration 0
            if i_iter == iter_num - 1 or i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                torch.save(
                    model.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'sal_' + str(i_epotch) + '_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(
                        args.snapshot_dir, 'sal_' + str(i_epotch) + '_' +
                        str(i_iter) + '_D1.pth'))

        if i_epotch == 7:
            # learning-rate drop: both rates are divided by 10 and the
            # optimizers are rebuilt (which also discards their Adam state)
            args.learning_rate = args.learning_rate * 0.1
            args.learning_rate_D = args.learning_rate_D * 0.1
            optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          model.parameters()),
                                   lr=args.learning_rate,
                                   weight_decay=args.weight_decay)
            optimizer_D1 = optim.Adam(model_D1.parameters(),
                                      lr=args.learning_rate_D,
                                      betas=(0.9, 0.99))

    # end
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')
コード例 #27
0
def main():
    """Create the model and start the training.

    All settings are read from the module-level ``args``.  The function
    builds a ``DeeplabMultiFeature`` segmentation model and an
    ``FCDiscriminator``, wraps both with ``amp.initialize`` (opt level O2),
    and for every outer iteration accumulates gradients over
    ``args.iter_size`` sub-iterations of:

    * a supervised step on a source batch, during which pooled features of
      correctly predicted regions are stored round-robin in per-class
      buffers (``BG_LABEL`` classes and ``FG_LABEL`` instances);
    * a target step combining the adversarial loss with L1 distances between
      pooled target features and the stored source features;
    * a discriminator step on the detached source and target predictions.

    Every ``args.save_pred_every`` iterations the model is evaluated with
    ``compute_mIoU`` on the validation list, and the weights are saved (plus
    a ``BestGTA5`` copy when the mIoU improves).  Training stops once
    ``i_iter`` reaches ``args.num_steps_stop - 1``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")
    cudnn.benchmark = True
    cudnn.enabled = True

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    Iter = 0  # first iteration index; advanced below when resuming
    bestIoU = 0

    # Create network
    # init G
    # NOTE(review): any other value of args.model leaves `model` undefined and
    # the function fails with NameError further down.
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        if args.continue_train:
            # Checkpoints whose keys start with 'module.' get that first name
            # component stripped from every key.
            if list(saved_state_dict.keys())[0].split('.')[0] == 'module':
                # Iterate over a snapshot of the keys: popping and inserting
                # while iterating the live key view raises RuntimeError on an
                # OrderedDict and can re-visit renamed keys on a plain dict.
                for key in list(saved_state_dict.keys()):
                    saved_state_dict['.'.join(
                        key.split('.')[1:])] = saved_state_dict.pop(key)
            model.load_state_dict(saved_state_dict)
        else:
            # Copy pretrained weights, dropping the first name component of
            # each key; 'layer5' entries are skipped when num_classes is 19.
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                i_parts = i.split('.')
                if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes).to(device)

    if args.continue_train:
        # The discriminator checkpoint path is the model path with '_D'
        # appended to the part before the file extension.
        model_weights_path = args.restore_from
        temp = model_weights_path.split('.')
        temp[-2] = temp[-2] + '_D'
        model_D_weights_path = '.'.join(temp)
        model_D.load_state_dict(torch.load(model_D_weights_path))
        # Resume after the iteration number parsed from the last 9 characters
        # of the file stem (the text following the '_').
        temp = model_weights_path.split('.')
        temp = temp[-2][-9:]
        Iter = int(temp.split('_')[1]) + 1

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # init data loader
    # NOTE(review): a data_dir ending in neither name leaves `trainset`
    # undefined and the DataLoader construction below fails with NameError.
    if args.data_dir.split('/')[-1] == 'gta5_deeplab':
        trainset = GTA5DataSet(args.data_dir,
                               args.data_list,
                               max_iters=args.num_steps * args.iter_size *
                               args.batch_size,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    elif args.data_dir.split('/')[-1] == 'syn_deeplab':
        trainset = synthiaDataSet(args.data_dir,
                                  args.data_list,
                                  max_iters=args.num_steps * args.iter_size *
                                  args.batch_size,
                                  crop_size=input_size,
                                  scale=args.random_scale,
                                  mirror=args.random_mirror,
                                  mean=IMG_MEAN)

    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # init optimizer
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # mixed-precision wrappers for both networks and their optimizers
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level="O2",
                                      keep_batchnorm_fp32=True,
                                      loss_scale="dynamic")

    model_D, optimizer_D = amp.initialize(model_D,
                                          optimizer_D,
                                          opt_level="O2",
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")

    # init loss
    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    L1_loss = torch.nn.L1Loss(reduction='none')

    # input sizes are stored as (w, h); Upsample takes (height, width)
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    test_interp = nn.Upsample(size=(1024, 2048),
                              mode='bilinear',
                              align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # init prototype
    # Feature buffers of width 2048: `num_prototype` slots per BG_LABEL class
    # and `num_ins` slots per FG_LABEL class.  The *_ptr arrays count how many
    # features were written; the slot index is the counter modulo the size.
    num_prototype = args.num_prototype
    num_ins = args.num_prototype * 10
    src_cls_features = torch.zeros([len(BG_LABEL), num_prototype, 2048],
                                   dtype=torch.float32).to(device)
    src_cls_ptr = np.zeros(len(BG_LABEL), dtype=np.uint64)
    src_ins_features = torch.zeros([len(FG_LABEL), num_ins, 2048],
                                   dtype=torch.float32).to(device)
    src_ins_ptr = np.zeros(len(FG_LABEL), dtype=np.uint64)

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    # start training
    for i_iter in range(Iter, args.num_steps):

        loss_seg_value = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cls_value = 0
        loss_ins_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D

            for param in model_D.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            src_feature, pred = model(images)
            pred_softmax = F.softmax(pred, dim=1)
            pred_idx = torch.argmax(pred_softmax, dim=1)

            # Resize the labels to the prediction resolution and keep only the
            # pixels the model classified correctly; the rest become 255.
            right_label = F.interpolate(labels.unsqueeze(0).float(),
                                        (pred_idx.size(1), pred_idx.size(2)),
                                        mode='nearest').squeeze(0).long()
            right_label[right_label != pred_idx] = 255

            # store one pooled feature per BG_LABEL class present in the batch
            for ii in range(len(BG_LABEL)):
                cls_idx = BG_LABEL[ii]
                mask = right_label == cls_idx
                if torch.sum(mask) == 0:
                    continue
                feature = global_avg_pool(src_feature, mask.float())
                # keep the feature only if layer6 assigns it to this class
                if cls_idx != torch.argmax(
                        torch.squeeze(model.layer6(
                            feature.half()).float())).item():
                    continue
                slot = int(src_cls_ptr[ii] % num_prototype)
                src_cls_features[ii, slot, :] = torch.squeeze(
                    feature).clone().detach()
                src_cls_ptr[ii] += 1

            # store pooled features of up to 10 segments (largest pixel count
            # first) per FG_LABEL class
            seg_ins = seg_label(right_label.squeeze())
            for ii in range(len(FG_LABEL)):
                cls_idx = FG_LABEL[ii]
                segmask, pixelnum = seg_ins[ii]
                if len(pixelnum) == 0:
                    continue
                sortmax = np.argsort(pixelnum)[::-1]
                for i in range(min(10, len(sortmax))):
                    mask = segmask == (sortmax[i] + 1)
                    feature = global_avg_pool(src_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    slot = int(src_ins_ptr[ii] % num_ins)
                    src_ins_features[ii, slot, :] = torch.squeeze(
                        feature).clone().detach()
                    src_ins_ptr[ii] += 1

            pred = interp(pred)
            loss_seg = seg_loss(pred, labels)
            loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            trg_feature, pred_target = model(images)

            pred_target_softmax = F.softmax(pred_target, dim=1)
            pred_target_idx = torch.argmax(pred_target_softmax, dim=1)

            loss_cls = torch.zeros(1).to(device)
            loss_ins = torch.zeros(1).to(device)
            if i_iter > 0:
                for ii in range(len(BG_LABEL)):
                    cls_idx = BG_LABEL[ii]
                    # skip classes whose write counter does not yet exceed the buffer size
                    if src_cls_ptr[ii] / num_prototype <= 1:
                        continue
                    mask = pred_target_idx == cls_idx
                    feature = global_avg_pool(trg_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    # mean L1 distance to the closest stored source feature
                    ext_feature = feature.squeeze().expand(num_prototype, 2048)
                    loss_cls += torch.min(
                        torch.sum(L1_loss(ext_feature,
                                          src_cls_features[ii, :, :]),
                                  dim=1) / 2048.)

                seg_ins = seg_label(pred_target_idx.squeeze())
                for ii in range(len(FG_LABEL)):
                    cls_idx = FG_LABEL[ii]
                    if src_ins_ptr[ii] / num_ins <= 1:
                        continue
                    segmask, pixelnum = seg_ins[ii]
                    if len(pixelnum) == 0:
                        continue
                    sortmax = np.argsort(pixelnum)[::-1]
                    for i in range(min(10, len(sortmax))):
                        mask = segmask == (sortmax[i] + 1)
                        feature = global_avg_pool(trg_feature, mask.float())
                        feature = feature.squeeze().expand(num_ins, 2048)
                        # averaged over the number of segments considered
                        loss_ins += torch.min(
                            torch.sum(L1_loss(feature,
                                              src_ins_features[ii, :, :]),
                                      dim=1) / 2048.) / min(10, len(sortmax))

            pred_target = interp_target(pred_target)

            # the generator is rewarded when D labels the target prediction as source
            D_out = model_D(F.softmax(pred_target, dim=1))
            loss_adv_target = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target * loss_adv_target + args.lambda_adv_cls * loss_cls + args.lambda_adv_ins * loss_ins
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value += loss_adv_target.item() / args.iter_size

            # train D

            # bring back requires_grad

            for param in model_D.parameters():
                param.requires_grad = True

            # train with source
            pred = pred.detach()
            D_out = model_D(F.softmax(pred, dim=1))

            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

            # train with target
            pred_target = pred_target.detach()
            D_out = model_D(F.softmax(pred_target, dim=1))

            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(target_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        # loss_cls / loss_ins shown here are those of the last sub-iteration
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv = {3:.3f} loss_D = {4:.3f} loss_cls = {5:.3f} loss_ins = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value, loss_cls.item(),
                    loss_ins.item()))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            if not os.path.exists(args.save):
                os.makedirs(args.save)
            # evaluate on the validation list: write one predicted label image
            # per input into args.save, then score them with compute_mIoU
            testloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target_test,
                crop_size=(1024, 512),
                mean=IMG_MEAN,
                scale=False,
                mirror=False,
                set='val'),
                                         batch_size=1,
                                         shuffle=False,
                                         pin_memory=True)
            model.eval()
            for index, batch in enumerate(testloader):
                image, _, name = batch
                with torch.no_grad():
                    output1, output2 = model(Variable(image).to(device))
                output = test_interp(output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output = Image.fromarray(output)
                name = name[0].split('/')[-1]
                output.save('%s/%s' % (args.save, name))
            mIoUs = compute_mIoU(osp.join(args.data_dir_target, 'gtFine/val'),
                                 args.save, 'dataset/cityscapes_list')
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            if mIoU > bestIoU:
                bestIoU = mIoU
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5_D.pth'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
            # back to training mode after evaluation
            model.train()

    if args.tensorboard:
        writer.close()
コード例 #28
0
def main():
    """Create the model and start the training.

    Builds the segmentation network selected by ``args.model`` together with
    two ``FCDiscriminator`` instances, then runs ``args.num_steps_stop``
    training iterations. Each iteration accumulates gradients over
    ``args.iter_size`` sub-batches:

    * segmentation loss on a batch from ``trainloader``,
    * adversarial loss on a batch from ``targetloader`` (discriminator
      frozen, prediction pushed towards ``source_label``),
    * discriminator loss on the detached predictions of both batches.

    Only the second head / ``model_D2`` takes part in the losses computed
    here; ``model_D1`` is created and stepped but never receives gradients
    in this function.

    Every ``args.save_pred_every`` iterations the model is evaluated with
    ``proceed_test`` and a checkpoint is written into the logger directory;
    the best-scoring checkpoint is additionally copied to
    ``model_best.pth.tar``.

    Raises:
        ValueError: if ``args.model`` or ``SOURCE_DATA`` is not supported.
    """

    def _scalar(loss):
        # A loss is a 1-element tensor on old PyTorch and a 0-dim tensor on
        # newer releases; ravel() yields a 1-element array in both cases, so
        # indexing [0] never fails.
        return loss.data.cpu().numpy().ravel()[0]

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        logger.info("adopting Deeplabv2 base model..")
        model = Res_Deeplab(num_classes=args.num_classes, multi_scale=False)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # the leading component is dropped, and layer5 is skipped when
            # training with 19 classes so that head keeps its fresh init.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.model == "FCN8S":
        logger.info("adopting FCN8S base model..")
        from pytorchgo.model.MyFCN8s import MyFCN8s
        model = MyFCN8s(n_class=NUM_CLASSES)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    else:
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda()

    model_D2.train()
    model_D2.cuda()

    if SOURCE_DATA == "GTA5":
        trainloader = data.DataLoader(GTA5DataSet(
            args.data_dir,
            args.data_list,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_iter = enumerate(trainloader)
    elif SOURCE_DATA == "SYNTHIA":
        trainloader = data.DataLoader(SynthiaDataSet(
            args.data_dir,
            args.data_list,
            LABEL_LIST_PATH,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size,
            scale=args.random_scale,
            mirror=args.random_mirror,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_iter = enumerate(trainloader)
    else:
        raise ValueError("unsupported source data: {}".format(SOURCE_DATA))

    targetloader = data.DataLoader(cityscapesDataSet(
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # Upsample takes (rows, cols), hence the swapped indices.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    best_mIoU = 0

    model_summary([model, model_D1, model_D2])
    optimizer_summary([optimizer, optimizer_D1, optimizer_D2])

    for i_iter in tqdm(range(args.num_steps_stop),
                       total=args.num_steps_stop,
                       desc="training"):

        # The *_value1 counters are never updated below; they are kept so the
        # log line retains its historical format.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        lr_D1 = adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            ######################### train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            # next() builtin instead of the Python-2-only .next() method,
            # which enumerate objects do not provide on Python 3.
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = loss_calc(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value2 += _scalar(loss_seg2) / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _, _ = batch
            images = Variable(images).cuda()

            pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            D_out2 = model_D2(F.softmax(pred_target2))

            # Adversarial objective: make target predictions look like
            # source ones to the (frozen) discriminator.
            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda())

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += _scalar(
                loss_adv_target2) / args.iter_size

            ################################## train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            # detach() keeps the discriminator loss from reaching the
            # segmentation network.
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda())

            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()

            loss_D_value2 += _scalar(loss_D2)

            # train with target
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda())

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += _scalar(loss_D2)

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 100 == 0:
            logger.info(
                'iter = {}/{},loss_seg1 = {:.3f} loss_seg2 = {:.3f} loss_adv1 = {:.3f}, loss_adv2 = {:.3f} loss_D1 = {:.3f} loss_D2 = {:.3f}, lr={:.7f}, lr_D={:.7f}, best miou16= {:.5f}'
                .format(i_iter, args.num_steps_stop, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2,
                        lr, lr_D1, best_mIoU))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            logger.info("saving snapshot.....")
            cur_miou16 = proceed_test(model, input_size)
            is_best = best_mIoU < cur_miou16
            if is_best:
                best_mIoU = cur_miou16
            checkpoint_path = osp.join(logger.get_logger_dir(),
                                       'checkpoint.pth.tar')
            torch.save(
                {
                    'iteration': i_iter,
                    'optim_state_dict': optimizer.state_dict(),
                    'optim_D1_state_dict': optimizer_D1.state_dict(),
                    'optim_D2_state_dict': optimizer_D2.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'model_D1_state_dict': model_D1.state_dict(),
                    'model_D2_state_dict': model_D2.state_dict(),
                    'best_mean_iu': cur_miou16,
                }, checkpoint_path)
            if is_best:
                import shutil
                shutil.copy(
                    checkpoint_path,
                    osp.join(logger.get_logger_dir(), 'model_best.pth.tar'))

        if i_iter >= args.num_steps_stop - 1:
            break
コード例 #29
0
def main():
    """Train the model jointly on saliency and edge batches with two critics.

    Creates a fresh ``run-N`` directory under ``args.snapshot_dir``, builds
    the model and two ``FCDiscriminator`` instances, and runs
    ``args.epotch_size`` epochs over the loader returned by
    ``get_loader(args)``. Progress is printed and appended to
    ``args.file_dir``; weights are periodically saved under the run's
    ``models`` directory.

    Side effects: rewrites ``args.file_dir``, ``args.snapshot_dir`` and, after
    the epoch with index 7, ``args.learning_rate`` / ``args.learning_rate_D``.
    """

    # make dir
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Pick the first unused run index so earlier runs are never overwritten.
    run = 0
    while os.path.exists("%s/run-%d" % (args.snapshot_dir, run)):
        run += 1
    os.mkdir("%s/run-%d" % (args.snapshot_dir, run))
    os.mkdir("%s/run-%d/models" % (args.snapshot_dir, run))
    # file_dir must be derived before snapshot_dir is redirected below.
    args.file_dir = "%s/run-%d/file.txt" % (args.snapshot_dir, run)
    args.snapshot_dir = "%s/run-%d/models" % (args.snapshot_dir, run)

    # cuDNN tuning (original note, pinyin: "black-magic optimisation")
    cudnn.enabled = True
    cudnn.benchmark = True

    # create the model
    model = build_model()
    model.to(device)
    model.train()
    # weights_init runs first; the checkpoint is loaded afterwards.
    model.apply(weights_init)
    model.load_state_dict(torch.load(args.restore_from))
    # model.base.load_pretrained_model(torch.load(args.pretrained_model))

    # create the discriminators (one per task branch)
    model_D1 = FCDiscriminator(num_classes=1).to(device)
    model_D2 = FCDiscriminator(num_classes=1).to(device)
    model_D1.train()
    model_D2.train()
    model_D1.apply(weights_init)
    model_D2.apply(weights_init)
    # model_D1.load_state_dict(torch.load(args.D_restore_from))
    # model_D2.load_state_dict(torch.load(args.D_restore_from))

    # create optimizer
    # NOTE(review): this optimizer is built without weight_decay, while the
    # one rebuilt after epoch index 7 passes args.weight_decay — confirm which
    # is intended.
    optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=args.learning_rate)  # optimizer for the whole model
    optimizer.zero_grad()
    optimizer_D1 = optim.RMSprop(model_D1.parameters(),
                                 lr=args.learning_rate_D)
    optimizer_D1.zero_grad()
    optimizer_D2 = optim.RMSprop(model_D2.parameters(),
                                 lr=args.learning_rate_D)
    optimizer_D2.zero_grad()

    # start time: record the run's hyper-parameters at the top of the log file
    with open(args.file_dir, 'a') as f:
        f.write('strat time: ' + str(datetime.now()) + '\n\n')

        f.write('learning rate: ' + str(args.learning_rate) + '\n')
        f.write('learning rate D: ' + str(args.learning_rate_D) + '\n')
        f.write('wight decay: ' + str(args.weight_decay) + '\n')
        f.write('lambda_adv_target2: ' + str(args.lambda_adv_target2) + '\n\n')

        f.write('eptch size: ' + str(args.epotch_size) + '\n')
        f.write('batch size: ' + str(args.batch_size) + '\n')
        f.write('iter size: ' + str(args.iter_size) + '\n')
        f.write('num steps: ' + str(args.num_steps) + '\n\n')

    # labels for adversarial training (markers for the two domains);
    # neither name is referenced again in this function.
    salLabel = 0
    edgeLabel = 1

    picloader = get_loader(args)
    # Number of full batches per epoch; used for logging and for detecting
    # the last iteration of an epoch.
    iter_num = len(picloader.dataset) // args.batch_size
    # Counts batches since the last generator step (gradient accumulation).
    aveGrad = 0

    for i_epotch in range(args.epotch_size):
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        loss_adv_target_value1 = 0
        loss_adv_target_value2 = 0
        loss_D_value1 = 0
        loss_D_value2 = 0
        model.zero_grad()

        for i_iter, data_batch in enumerate(picloader):

            sal_image, sal_label, edge_image, edge_label = data_batch[
                'sal_image'], data_batch['sal_label'], data_batch[
                    'edge_image'], data_batch['edge_label']
            # Skip the batch when any image and its label disagree on
            # dimension 2 or 3 (all four comparisons are OR-ed together).
            if (sal_image.size(2) != sal_label.size(2)) or (
                    sal_image.size(3) != sal_label.size(3)
                    or edge_image.size(2) != edge_label.size(2)) or (
                        edge_image.size(3) != edge_label.size(3)):
                print('IMAGE ERROR, PASSING```')
                with open(args.file_dir, 'a') as f:
                    f.write('IMAGE ERROR, PASSING```\n')
                continue

            sal_image, sal_label, edge_image, edge_label = Variable(
                sal_image), Variable(sal_label), Variable(
                    edge_image), Variable(edge_label)
            sal_image, sal_label, edge_image, edge_label = sal_image.to(
                device), sal_label.to(device), edge_image.to(
                    device), edge_label.to(device)

            # Four forward passes: both inputs through both modes. The mode=0
            # outputs are used below as a pair (index 0, plus an iterable at
            # index 1); the mode=1 outputs are used as single tensors.
            s_sal_pred = model(sal_image, mode=1)
            s_edge_pred = model(edge_image, mode=1)
            e_sal_pred = model(sal_image, mode=0)
            e_edge_pred = model(edge_image, mode=0)

            # train G (discriminators frozen so they collect no gradients)
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # sal
            sal_loss_fuse = F.binary_cross_entropy_with_logits(s_sal_pred,
                                                               sal_label,
                                                               reduction='sum')
            sal_loss = sal_loss_fuse / (args.iter_size * args.batch_size)
            loss_seg_value1 += sal_loss.data

            sal_loss.backward()

            loss_adv_target1 = torch.mean(
                model_D1(s_edge_pred))  # original note: the latter acts like a same-sized tensor filled with the correct answers
            sd_loss = loss_adv_target1 / (args.iter_size * args.batch_size)
            loss_adv_target_value1 += sd_loss.data  # for logging only

            # NOTE(review): both adversarial terms are weighted by
            # lambda_adv_target2; no lambda_adv_target1 is read here.
            sd_loss = sd_loss * args.lambda_adv_target2
            sd_loss.backward()

            # edge
            edge_loss_fuse = bce2d(e_edge_pred[0], edge_label, reduction='sum')
            edge_loss_part = []
            for ix in e_edge_pred[1]:
                edge_loss_part.append(bce2d(ix, edge_label, reduction='sum'))
            edge_loss = (edge_loss_fuse + sum(edge_loss_part)) / (
                args.iter_size * args.batch_size)
            loss_seg_value2 += edge_loss.data

            edge_loss.backward()

            loss_adv_target2 = -torch.mean(model_D2(
                e_sal_pred[0]))  # original note: the latter acts like a same-sized tensor filled with the correct answers
            for ix in e_sal_pred[1]:
                loss_adv_target2 += -torch.mean(model_D2(ix))
            # Averaged over the number of outputs scored by D2.
            ed_loss = loss_adv_target2 / (args.iter_size * args.batch_size) / (
                len(e_sal_pred[1]) + 1)
            loss_adv_target_value2 += ed_loss.data

            ed_loss = ed_loss * args.lambda_adv_target2
            ed_loss.backward()

            # train D
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # Detach every prediction so the discriminator losses below do not
            # back-propagate into the model.
            s_sal_pred = s_sal_pred.detach()
            s_edge_pred = s_edge_pred.detach()
            e_sal_pred = [
                e_sal_pred[0].detach(), [x.detach() for x in e_sal_pred[1]]
            ]
            e_edge_pred = [
                e_edge_pred[0].detach(), [x.detach() for x in e_edge_pred[1]]
            ]

            # sal: D1 mean score enters with + sign for the saliency input
            # and - sign for the edge input.
            ss_loss = torch.mean(model_D1(s_sal_pred))
            ss_Loss = ss_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += ss_Loss.data

            ss_Loss.backward()

            se_loss = -torch.mean(model_D1(s_edge_pred))
            se_Loss = se_loss / (args.iter_size * args.batch_size)
            loss_D_value1 += se_Loss.data

            se_Loss.backward()

            # edge: same sign convention for D2, summed over all outputs.
            es_loss = torch.mean(model_D2(e_sal_pred[0]))
            for ix in e_sal_pred[1]:
                es_loss += torch.mean(model_D2(ix))
            es_Loss = es_loss / (args.iter_size *
                                 args.batch_size) / (len(e_sal_pred[1]) + 1)
            loss_D_value2 += es_Loss.data

            es_Loss.backward()

            ee_loss = -torch.mean(model_D2(e_edge_pred[0]))
            for ix in e_edge_pred[1]:
                ee_loss += -torch.mean(model_D2(ix))
            ee_Loss = ee_loss / (args.iter_size *
                                 args.batch_size) / (len(e_edge_pred[1]) + 1)
            loss_D_value2 += ee_Loss.data

            ee_Loss.backward()

            # The model is stepped once every args.iter_size batches ...
            aveGrad += 1
            if aveGrad % args.iter_size == 0:
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            # ... whereas both discriminators are stepped on every batch and
            # their weights clamped to [-clip_value, clip_value].
            optimizer_D1.step()
            for p in model_D1.parameters():
                p.data.clamp_(-args.clip_value, args.clip_value)
            optimizer_D1.zero_grad()
            optimizer_D2.step()
            for p in model_D2.parameters():
                p.data.clamp_(-args.clip_value, args.clip_value)
            optimizer_D2.zero_grad()

            # NOTE(review): the divisor is 0 when show_every < batch_size,
            # which raises ZeroDivisionError here.
            if i_iter % (args.show_every // args.batch_size) == 0:
                print(
                    'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_seg2 = {7:.3f}, loss_adv1 = {3:.3f}, loss_adv2 = {8:.3f}, loss_D1 = {4:.3f}, loss_D2 = {9:.3f}'
                    .format(i_iter, iter_num, loss_seg_value1,
                            loss_adv_target_value1, loss_D_value1, i_epotch,
                            args.epotch_size, loss_seg_value2,
                            loss_adv_target_value2, loss_D_value2))
                with open(args.file_dir, 'a') as f:
                    f.write(
                        'epotch = {5:2d}/{6:2d}, iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_seg2 = {7:.3f}, loss_adv1 = {3:.3f}, loss_adv2 = {8:.3f}, loss_D1 = {4:.3f}, loss_D2 = {9:.3f}\n'
                        .format(i_iter, iter_num, loss_seg_value1,
                                loss_adv_target_value1, loss_D_value1,
                                i_epotch, args.epotch_size, loss_seg_value2,
                                loss_adv_target_value2, loss_D_value2))

                # Running totals are reset after each report, so every log
                # line covers only the batches since the previous one.
                loss_seg_value1, loss_adv_target_value1, loss_D_value1, loss_seg_value2, loss_adv_target_value2, loss_D_value2 = 0, 0, 0, 0, 0, 0

            # `and` binds tighter than `or`: snapshot on the last full batch
            # of the epoch, or on every save_pred_every-th non-zero iteration.
            # File names are fixed, so each snapshot overwrites the previous.
            if i_iter == iter_num - 1 or i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                with open(args.file_dir, 'a') as f:
                    f.write('taking snapshot ...\n')
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_.pth'))
                torch.save(model_D1.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_D1.pth'))
                torch.save(model_D2.state_dict(),
                           osp.join(args.snapshot_dir, 'sal_D2.pth'))

        # One-off learning-rate drop after the epoch with index 7: all three
        # optimizers are replaced by fresh objects at a tenth of the rate.
        if i_epotch == 7:
            args.learning_rate = args.learning_rate * 0.1
            args.learning_rate_D = args.learning_rate_D * 0.1
            optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=args.learning_rate,
                                      weight_decay=args.weight_decay)
            optimizer_D1 = optim.RMSprop(model_D1.parameters(),
                                         lr=args.learning_rate_D)
            optimizer_D2 = optim.RMSprop(model_D2.parameters(),
                                         lr=args.learning_rate_D)

    # end
    with open(args.file_dir, 'a') as f:
        f.write('end time: ' + str(datetime.now()) + '\n')
コード例 #30
0
def main():
    """Create the model and start the training.

    Builds a two-head ``DeeplabMulti`` segmentation network and one
    ``FCDiscriminator`` per head, then runs up to ``args.num_steps``
    iterations, each accumulating gradients over ``args.iter_size``
    sub-batches:

    * segmentation loss on both heads for a batch from ``trainloader``,
    * adversarial loss on both heads for a batch from ``targetloader``
      (discriminators frozen, predictions pushed towards ``source_label``),
    * discriminator losses on the detached predictions of both batches.

    ``args.gan`` selects the adversarial criterion: ``'Vanilla'`` uses
    ``BCEWithLogitsLoss`` and ``'LS'`` uses ``MSELoss``. Weights are saved to
    ``args.snapshot_dir`` every ``args.save_pred_every`` iterations, and
    training stops after a final save once ``args.num_steps_stop`` is reached.

    Raises:
        ValueError: if ``args.model`` or ``args.gan`` is not supported.
    """

    # input sizes are given as "width,height"
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # the leading component is dropped, and layer5 is skipped when
            # training with 19 classes so that head keeps its fresh init.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    else:
        # Fail here rather than with a NameError on the first use of `model`.
        raise ValueError("unsupported model: {}".format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # One discriminator per segmentation head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # The variable keeps the name bce_loss for both criteria.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # Fail here rather than with a NameError inside the training loop.
        raise ValueError("unsupported gan type: {}".format(args.gan))

    # Upsample takes (rows, cols) = (height, width).
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):
            # ---- train G: freeze the discriminators ----
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is the class channel of the 4-D prediction; it is also
            # what the deprecated implicit-dim fallback selects for 4-D input.
            D1_out = model_D1(F.softmax(pred_target1, dim=1))
            D2_out = model_D2(F.softmax(pred_target2, dim=1))

            labels_source1 = Variable(
                torch.FloatTensor(
                    D1_out.data.size()).fill_(source_label)).cuda(args.gpu)
            labels_source2 = Variable(
                torch.FloatTensor(
                    D2_out.data.size()).fill_(source_label)).cuda(args.gpu)

            # Adversarial objective: target predictions should be scored as
            # source by the frozen discriminators.
            loss_adv_target1 = bce_loss(D1_out, labels_source1)
            loss_adv_target2 = bce_loss(D2_out, labels_source2)

            loss = args.lambda_adv_target1 * loss_adv_target1 + \
                   args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size

            loss.backward()

            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # ---- train D: unfreeze the discriminators ----
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source; detach() keeps the discriminator losses from
            # reaching the segmentation network.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D1_out = model_D1(F.softmax(pred1, dim=1))
            D2_out = model_D2(F.softmax(pred2, dim=1))

            labels_source1 = Variable(
                torch.FloatTensor(
                    D1_out.data.size()).fill_(source_label)).cuda(args.gpu)
            labels_source2 = Variable(
                torch.FloatTensor(
                    D2_out.data.size()).fill_(source_label)).cuda(args.gpu)

            # Halved because each discriminator sees a source and a target
            # batch per sub-iteration.
            loss_D1 = bce_loss(D1_out, labels_source1) / args.iter_size / 2
            loss_D2 = bce_loss(D2_out, labels_source2) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D1_out = model_D1(F.softmax(pred_target1, dim=1))
            D2_out = model_D2(F.softmax(pred_target2, dim=1))

            labels_target1 = Variable(
                torch.FloatTensor(
                    D1_out.data.size()).fill_(target_label)).cuda(args.gpu)
            labels_target2 = Variable(
                torch.FloatTensor(
                    D2_out.data.size()).fill_(target_label)).cuda(args.gpu)

            loss_D1 = bce_loss(D1_out, labels_target1) / args.iter_size / 2
            loss_D2 = bce_loss(D2_out, labels_target2) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Was a bare `print` followed by a stray string literal, which
            # prints nothing on Python 3.
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
コード例 #31
0
def _loss_to_float(loss):
    """Return the value of a single-element loss tensor as a Python float.

    ``loss.data.cpu().numpy()`` may be a 0-d array (indexing that with
    ``[0]`` raises ``IndexError``) or a one-element 1-d array depending on
    how the loss was reduced; flattening first handles both shapes.
    """
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def _domain_labels(d_out, domain_label, gpu):
    """Build a float target with the shape of ``d_out`` filled with
    ``domain_label`` and move it to CUDA device ``gpu``."""
    return Variable(
        torch.FloatTensor(d_out.data.size()).fill_(domain_label)).cuda(gpu)


def main():
    """Create the model and start the training.

    All settings are read from the module-level ``args``.  The function
    builds the segmentation network and two discriminators, then runs an
    adversarial training loop: for each step the segmentation network is
    updated with the supervised loss on source batches plus an adversarial
    loss on target batches, and both discriminators are updated to tell
    source predictions from target predictions.  Checkpoints are written to
    ``args.snapshot_dir``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` (the only
            architecture this function knows how to build).
    """

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True

    # Create network
    if args.model != 'DeepLab':
        # Fail early with a clear message; otherwise ``model`` would be
        # unbound and the first use below would raise UnboundLocalError.
        raise ValueError(
            'unsupported model {!r}; only "DeepLab" is implemented'.format(
                args.model))

    model = Res_Deeplab(num_classes=args.num_classes)
    # 'https://...' also starts with 'http', so one prefix test covers both.
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Diagnostic dump of the checkpoint contents and of both key sets.
    # Nothing is copied into ``new_params``; the checkpoint is loaded as-is
    # by ``load_state_dict`` below.
    new_params = model.state_dict().copy()
    for i in saved_state_dict:
        # Example key: Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.')
        if not args.num_classes == 19 or not i_parts[1] == 'layer5':
            print('.'.join(i_parts[1:]), saved_state_dict[i])
    print("Key new")
    print(new_params.keys())
    print("your model new")
    print(saved_state_dict.keys())
    model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(SynthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # The size tuples are passed to Upsample in swapped order.
    # NOTE(review): presumably the CLI sizes are given as "width,height"
    # and Upsample wants (height, width) -- confirm against the arg help.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        # Per-step running totals, only used for the log line below.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over ``iter_size`` sub-batches and
        # applied once by the optimizer steps after this loop.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += _loss_to_float(loss_seg1) / args.iter_size
            loss_seg_value2 += _loss_to_float(loss_seg2) / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # The upsampled predictions are 4-D, for which the implicit
            # softmax dimension is 1; spell it out to avoid the
            # implicit-dim deprecation warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Adversarial term: target predictions are scored against the
            # *source* label so the generator is pushed to fool D.
            loss_adv_target1 = bce_loss(
                D_out1, _domain_labels(D_out1, source_label, args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2, _domain_labels(D_out2, source_label, args.gpu))

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += (_loss_to_float(loss_adv_target1) /
                                       args.iter_size)
            loss_adv_target_value2 += (_loss_to_float(loss_adv_target2) /
                                       args.iter_size)

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            # detach() keeps the discriminator loss from back-propagating
            # into the segmentation network.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1, _domain_labels(D_out1, source_label, args.gpu))

            loss_D2 = bce_loss(
                D_out2, _domain_labels(D_out2, source_label, args.gpu))

            # Halved because D sees two batches (source + target) per
            # sub-iteration.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += _loss_to_float(loss_D1)
            loss_D_value2 += _loss_to_float(loss_D2)

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1, _domain_labels(D_out1, target_label, args.gpu))

            loss_D2 = bce_loss(
                D_out2, _domain_labels(D_out2, target_label, args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += _loss_to_float(loss_D1)
            loss_D_value2 += _loss_to_float(loss_D2)

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            # NOTE(review): the stop condition uses args.num_steps_stop but
            # the file names embed args.num_steps; kept as-is so existing
            # consumers of these paths keep working -- confirm intent.
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'Synthia_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'Synthia_' + str(args.num_steps) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'Synthia_' + str(args.num_steps) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'Synthia_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'Synthia_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'Synthia_' + str(i_iter) + '_D2.pth'))