示例#1
0
class AD_Trainer(nn.Module):
    """Adversarial trainer for domain-adaptive semantic segmentation.

    Holds the segmentation generator ``G`` (DeepLab) and two multi-scale
    image discriminators ``D1``/``D2`` together with their optimizers, and
    implements the generator and discriminator update steps.  Running
    class-balance statistics for the source domain live in
    ``class_weight``/``often_weight``; the ``*_t`` twins hold the same
    statistics for the target domain.
    """

    def __init__(self, args):
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Per-class running weights, initialised to 1 (i.e. no re-weighting).
        self.class_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.class_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.often_weight_t = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label
        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes,
                                  use_se=args.use_se,
                                  train_bn=args.train_bn,
                                  norm_style=args.norm_style,
                                  droprate=args.droprate)
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)

            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # Key example: Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if args.restore_from[:4] == 'http':
                    # Downloaded (ImageNet) checkpoint: skip the classifier
                    # head and strip the leading scope name.
                    if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                else:
                    if i_parts[0] == 'module':
                        # Checkpoint was saved from DataParallel; drop the
                        # 'module.' prefix.
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
                    else:
                        new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' % i_parts[0:])
            # BUG FIX: this call previously sat outside the 'DeepLab' branch,
            # raising a NameError for any other args.model.
            self.G.load_state_dict(new_params)

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate, momentum=args.momentum,
                                 nesterov=True, weight_decay=args.weight_decay)

        self.dis1_opt = optim.Adam(self.D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))

        self.dis2_opt = optim.Adam(self.D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))

        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # 'size_average=False' is deprecated; reduction='sum' is equivalent.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        self.interp = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            # Replace the plain optimizers with AMP-managed ones.
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")

    def _balance_stats(self, labels):
        """Compute per-class (weight, often) balance tensors for *labels*.

        Classes covering fewer than 64*64 pixels per image in the batch get
        ``self.max_value`` in *weight*; classes absent from the batch get
        ``self.max_value`` in *often* when often-balancing is enabled.
        (Shared by the source and target criterion updates, which were
        previously copy-pasted.)
        """
        weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels == i)
            if count[i] < 64 * 64 * n:  # small objective; original train size is 512*256
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value
        return weight, often

    def update_class_criterion(self, labels):
        """Refresh source-domain class weights from *labels* and return a
        class-weighted cross-entropy criterion."""
        weight, often = self._balance_stats(labels)
        # Exponential moving average of the 'often' statistics.
        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_class_criterion_t(self, labels):
        """Refresh target-domain class weights from *labels* and return a
        class-weighted cross-entropy criterion."""
        weight, often = self._balance_stats(labels)
        self.often_weight_t = 0.9 * self.often_weight_t + 0.1 * often
        self.class_weight_t = weight * self.often_weight_t
        print(self.class_weight_t)
        return nn.CrossEntropyLoss(weight=self.class_weight_t, ignore_index=255)

    def update_label(self, labels, prediction):
        """Mask out the easiest pixels as ignore-index.

        Pixels whose per-pixel CE loss falls below the
        ``self.only_hard_label`` percentile are set to 255 (ignored).
        Mutates and returns *labels*.
        """
        criterion = nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255, reduction='none')
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def _variance_loss(self, labels, pred1, pred2, class_weight):
        """Cross-entropy on *pred1* down-weighted where the two prediction
        branches disagree (exp of negative per-pixel KL), plus a mean-KL
        regulariser.  Shared by the source/target variants."""
        criterion = nn.CrossEntropyLoss(weight=class_weight, ignore_index=255, reduction='none')
        kl_distance = nn.KLDivLoss(reduction='none')
        loss = criterion(pred1, labels)

        variance = torch.sum(kl_distance(self.log_sm(pred1), self.sm(pred2)), dim=1)
        exp_variance = torch.exp(-variance)

        print(variance.shape)
        print('variance mean: %.4f' % torch.mean(exp_variance[:]))
        print('variance min: %.4f' % torch.min(exp_variance[:]))
        print('variance max: %.4f' % torch.max(exp_variance[:]))
        return torch.mean(loss * exp_variance) + torch.mean(variance)

    def update_variance(self, labels, pred1, pred2):
        """Variance-weighted segmentation loss with source-domain weights."""
        return self._variance_loss(labels, pred1, pred2, self.class_weight)

    def update_variance_t(self, labels, pred1, pred2):
        """Variance-weighted segmentation loss with target-domain weights."""
        return self._variance_loss(labels, pred1, pred2, self.class_weight_t)

    def update_loss(self, loss):
        """Backpropagate *loss*, routing through AMP loss scaling in fp16 mode."""
        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        """One generator optimisation step over a source and a target batch.

        Returns the four segmentation losses, four zero placeholders (kept
        for interface compatibility with callers that unpack twelve values),
        the two source predictions, and two ``None`` placeholders.
        """
        self.gen_opt.zero_grad()

        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)

        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)

        loss_seg1 = self.seg_loss(pred1, labels)
        loss_seg2 = self.seg_loss(pred2, labels)

        loss = loss_seg2 + self.lambda_seg * loss_seg1

        self.update_loss(loss)

        images_t = images_t.cuda()
        labels_t = labels_t.long().cuda()
        pred1_t, pred2_t = self.G(images_t)
        pred1_t = self.interp(pred1_t)
        pred2_t = self.interp(pred2_t)

        if self.class_balance:
            self.seg_loss_t = self.update_class_criterion_t(labels_t)

        # Target segmentation losses weighted by inter-branch KL variance.
        loss_seg1_t = self.update_variance_t(labels_t, pred1_t, pred2_t)
        loss_seg2_t = self.update_variance_t(labels_t, pred2_t, pred1_t)

        loss = loss_seg2_t + self.lambda_seg * loss_seg1_t
        self.update_loss(loss)

        self.gen_opt.step()
        zero_loss = torch.zeros(1).cuda()
        return loss_seg1, loss_seg2, loss_seg1_t, loss_seg2_t, zero_loss, zero_loss, zero_loss, zero_loss, pred1, pred2, None, None

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        """One optimisation step for both discriminators.

        Source predictions (softmax of ``0.5*pred1 + pred2``) act as 'real';
        target-domain predictions act as 'fake'.  Returns the two
        discriminator losses.
        """
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()
        # The discriminator step must not backprop into the generator.
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        if self.multi_gpu:
            loss_D1, reg1 = self.D1.module.calc_dis_loss(self.D1, input_fake=F.softmax(pred_target1, dim=1), input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.module.calc_dis_loss(self.D2, input_fake=F.softmax(pred_target2, dim=1), input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
        else:
            loss_D1, reg1 = self.D1.calc_dis_loss(self.D1, input_fake=F.softmax(pred_target1, dim=1), input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.calc_dis_loss(self.D2, input_fake=F.softmax(pred_target2, dim=1), input_real=F.softmax(0.5 * pred1 + pred2, dim=1))

        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
def main():
    """Create the model and start the training.

    Configuration comes from the module-level ``args`` namespace.  Builds a
    DeepLab segmentation network and two fully-convolutional discriminators,
    then alternates generator / discriminator updates for ``args.num_steps``
    iterations, optionally logging to TensorBoard and periodically saving
    snapshots to ``args.snapshot_dir``.
    """
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    criterion = DiceBCELoss()

    # Create the segmentation network and optionally restore weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from is not None:
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)

            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                # Key example: Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                # Skip the classifier head ('layer5') when restoring a
                # 19-class checkpoint onto a model with a different head.
                if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)
    else:
        # BUG FIX: previously fell through and crashed later with a
        # NameError on `model`; fail fast with a clear message instead.
        raise ValueError('unsupported model: %s' % args.model)

    # BUG FIX: `writer` was referenced unconditionally below even though it
    # is only created when logging is enabled.
    writer = None
    if not args.no_logging:
        if not os.path.isdir(args.log_dir):
            os.mkdir(args.log_dir)
        log_dir = os.path.join(args.log_dir, args.exp_dir)
        if not os.path.isdir(log_dir):
            os.mkdir(log_dir)
        if args.exp_name == "":
            exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
        else:
            exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
        writer = SummaryWriter(log_dir)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # Init discriminators.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(SyntheticSmokeTrain(
        args={},
        dataset_limit=args.num_steps * args.iter_size * args.batch_size,
        image_shape=input_size,
        dataset_mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    print("Length of train dataloader: ", len(trainloader))
    targetloader = data.DataLoader(SimpleSmokeVal(args={},
                                                  image_size=input_size_target,
                                                  dataset_mean=IMG_MEAN),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    # BUG FIX: message previously said "train dataloader" for the target loader.
    print("Length of target dataloader: ", len(targetloader))

    # model.optim_parameters(args) handles per-layer lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # BUG FIX: an unknown gan mode previously surfaced later as a
        # NameError on `bce_loss`.
        raise ValueError('unsupported GAN loss: %s' % args.gan)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # Labels for adversarial training.
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # Don't accumulate grads in D while training G.
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # NOTE: a loop over `interp_domain.parameters()` was removed
            # here — `interp_domain` was never defined (its creation was
            # commented out) and the loop raised a NameError.

            # Train with source.
            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # Normalize over the gradient-accumulation steps.
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.detach().data.cpu().item() / args.iter_size
            loss_seg_value2 += loss_seg2.detach().data.cpu().item() / args.iter_size

            # Train with target.
            _, batch = next(targetloader_iter)

            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # NOTE: a `min_class1`/`min_class2` computation was removed here —
            # it read `w1`/`w2`, which only existed in commented-out lines
            # (NameError), and its results were never used.

            # Fool the discriminators: label target predictions as source.
            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.detach().data.cpu().item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.detach().data.cpu().item() / args.iter_size

            # ---- train D ----

            # Bring back requires_grad.
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # Train with source.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.detach().data.cpu().item()
            loss_D_value2 += loss_D2.detach().data.cpu().item()

            # Train with target.
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.detach().data.cpu().item()
            loss_D_value2 += loss_D2.detach().data.cpu().item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))
        if writer is not None:
            writer.add_scalar('loss/train/segmentation/1', loss_seg_value1,
                              i_iter)
            writer.add_scalar('loss/train/segmentation/2', loss_seg_value2,
                              i_iter)
            writer.add_scalar('loss/train/adversarial/1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar('loss/train/adversarial/2',
                              loss_adv_target_value2, i_iter)
            writer.add_scalar('loss/train/domain/1', loss_D_value1, i_iter)
            writer.add_scalar('loss/train/domain/2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '_D2.pth'))
        if writer is not None:
            writer.flush()
示例#3
0
class AD_Trainer(nn.Module):
    """Adversarial trainer for domain-adaptive semantic segmentation.

    Wraps a two-head DeepLab generator ``G`` and two multi-scale image
    discriminators ``D1``/``D2``.  ``gen_update`` optimises the generator
    with segmentation, adversarial and (after a warm-up) entropy/KL
    consistency losses; ``dis_update`` optimises the discriminators on
    detached predictions.  Supports apex fp16 (``amp``) and synced BN.
    """

    def __init__(self, args):
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Per-class loss weights, initialised to uniform 1.
        self.class_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        # Exponential moving average of the "often" weight that boosts
        # classes absent from recent batches (see update_class_criterion).
        self.often_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label
        if args.model == 'DeepLab':
            self.G = DeeplabMulti(num_classes=args.num_classes,
                                  use_se=args.use_se,
                                  train_bn=args.train_bn,
                                  norm_style=args.norm_style,
                                  droprate=args.droprate)
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)

            new_params = self.G.state_dict().copy()
            for i in saved_state_dict:
                # Checkpoint keys look like: Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                if args.restore_from[:4] == 'http':
                    # Downloaded (ImageNet) checkpoint: drop the leading
                    # scope and skip the classifier ('fc') / ASPP head
                    # ('layer5'), which do not match this model.
                    if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[1:])
                else:
                    # Local checkpoint: strip a possible DataParallel
                    # 'module.' prefix, otherwise copy the key verbatim.
                    if i_parts[0] == 'module':
                        new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[1:])
                    else:
                        new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
                        print('%s is loaded from pre-trained weight.\n' %
                              i_parts[0:])
            # BUGFIX: moved inside the branch -- 'new_params' only exists
            # when args.model == 'DeepLab'; previously this line ran
            # unconditionally and raised NameError for any other model.
            self.G.load_state_dict(new_params)

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate,
                                 momentum=args.momentum,
                                 nesterov=True,
                                 weight_decay=args.weight_decay)

        self.dis1_opt = optim.Adam(self.D1.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        self.dis2_opt = optim.Adam(self.D2.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # reduction='sum' is the modern equivalent of the deprecated
        # size_average=False (summed, not averaged, KL divergence).
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        # Upsample both heads' logits back to the crop resolution.
        self.interp = nn.Upsample(size=args.crop_size,
                                  mode='bilinear',
                                  align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size,
                                         mode='bilinear',
                                         align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G,
                                                  self.gen_opt,
                                                  opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1,
                                                    self.dis1_opt,
                                                    opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2,
                                                    self.dis2_opt,
                                                    opt_level="O1")

    def update_class_criterion(self, labels):
        """Build a class-balanced CrossEntropyLoss from this batch.

        Classes covering fewer than 64*64 pixels per image get weight
        ``max_value``; classes entirely absent get the same boost via the
        EMA ``often_weight`` (when ``often_balance`` is on).

        Args:
            labels: (n, h, w) long tensor of ground-truth class ids.

        Returns:
            nn.CrossEntropyLoss weighted by ``self.class_weight``.
        """
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        print(labels.shape)
        n, h, w = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels == i)
            if count[i] < 64 * 64 * n:  #small objective
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value

        # EMA keeps the absent-class boost smooth across batches.
        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_label(self, labels, prediction):
        """Mask out 'easy' pixels so training focuses on hard ones.

        Pixels whose per-pixel loss falls below the ``only_hard_label``-th
        percentile are set to the ignore index 255 (in place).

        Args:
            labels: (n, h, w) ground-truth tensor (mutated in place).
            prediction: (n, c, h, w) logits.

        Returns:
            The label tensor with easy pixels set to 255.
        """
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255,
                                        reduction='none')
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        loss_data = loss.data.cpu().numpy()
        # Percentile threshold: keep only the hardest pixels.
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        """One generator step on a source batch and a target batch.

        Returns the individual loss terms plus the four prediction maps
        (needed by dis_update) and a target-domain validation loss.
        """
        self.gen_opt.zero_grad()

        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)

        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)

        if self.only_hard_label > 0:
            # Separate hard-pixel masks per head (labels cloned so the
            # two heads do not interfere).
            labels1 = self.update_label(labels.clone(), pred1)
            labels2 = self.update_label(labels.clone(), pred2)
            loss_seg1 = self.seg_loss(pred1, labels1)
            loss_seg2 = self.seg_loss(pred2, labels2)
        else:
            loss_seg1 = self.seg_loss(pred1, labels)
            loss_seg2 = self.seg_loss(pred2, labels)

        # Auxiliary head (pred1) is down-weighted by lambda_seg.
        loss = loss_seg2 + self.lambda_seg * loss_seg1

        # target
        pred_target1, pred_target2 = self.G(images_t)
        pred_target1 = self.interp_target(pred_target1)
        pred_target2 = self.interp_target(pred_target2)

        # Generator tries to make target predictions look "real" to D.
        if self.multi_gpu:
            loss_adv_target1 = self.D1.module.calc_gen_loss(
                self.D1, input_fake=F.softmax(pred_target1, dim=1))
            loss_adv_target2 = self.D2.module.calc_gen_loss(
                self.D2, input_fake=F.softmax(pred_target2, dim=1))
        else:
            loss_adv_target1 = self.D1.calc_gen_loss(self.D1,
                                                     input_fake=F.softmax(
                                                         pred_target1, dim=1))
            loss_adv_target2 = self.D2.calc_gen_loss(self.D2,
                                                     input_fake=F.softmax(
                                                         pred_target2, dim=1))

        loss += self.lambda_adv_target1 * loss_adv_target1 + self.lambda_adv_target2 * loss_adv_target2

        # Entropy / KL regularisers are disabled during the first 15k
        # iterations (warm-up) -- predictions are too noisy before that.
        if i_iter < 15000:
            self.lambda_kl_target_copy = 0
            self.lambda_me_target_copy = 0
        else:
            self.lambda_kl_target_copy = self.lambda_kl_target
            self.lambda_me_target_copy = self.lambda_me_target

        loss_me = 0.0
        if self.lambda_me_target_copy > 0:
            # Confidence-weighted entropy minimisation on the fused
            # (0.5*aux + main) target prediction.
            confidence_map = torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2)**2, 1).detach()
            loss_me = -torch.mean(confidence_map * torch.sum(
                self.sm(0.5 * pred_target1 + pred_target2) *
                self.log_sm(0.5 * pred_target1 + pred_target2), 1))
            loss += self.lambda_me_target * loss_me

        loss_kl = 0.0
        if self.lambda_kl_target_copy > 0:
            n, c, h, w = pred_target1.shape
            with torch.no_grad():
                # Fused prediction serves as a fixed "teacher" target.
                mean_pred = self.sm(0.5 * pred_target1 + pred_target2)
            # Symmetric-ish KL pulling both heads toward the fused
            # prediction; normalised per pixel (sum reduction / n*h*w).
            loss_kl = (self.kl_loss(self.log_sm(pred_target2), mean_pred) +
                       self.kl_loss(self.log_sm(pred_target1), mean_pred)) / (
                           n * h * w)
            print(loss_kl)
            loss += self.lambda_kl_target * loss_kl

        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.gen_opt.step()

        # Monitoring only: segmentation loss on target labels (no backward).
        val_loss = self.seg_loss(pred_target2, labels_t)

        return loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2, val_loss

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        """One discriminator step on detached generator outputs.

        Source predictions (fused 0.5*pred1 + pred2) are "real", target
        predictions are "fake".  Returns (loss_D1, loss_D2).
        """
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()
        # Detach so discriminator gradients never reach the generator.
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        if self.multi_gpu:
            loss_D1, reg1 = self.D1.module.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.module.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
        else:
            loss_D1, reg1 = self.D1.calc_dis_loss(
                self.D1,
                input_fake=F.softmax(pred_target1, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
            loss_D2, reg2 = self.D2.calc_dis_loss(
                self.D2,
                input_fake=F.softmax(pred_target2, dim=1),
                input_real=F.softmax(0.5 * pred1 + pred2, dim=1))

        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss,
                                [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
def main():
    """Create the model and start the training.

    GTA5->BDD adversarial domain adaptation: trains the segmentation
    network (head 1 only) against discriminator D1, optionally under
    torch.distributed.  Snapshots are written to args.snapshot_dir and
    scalars to TensorBoard when enabled (rank 0 only).
    """
    global args
    args = get_arguments()
    if args.dist:
        init_dist(args.launcher, backend=args.backend)
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    # NOTE(review): the model flag here is 'Deeplab' while AD_Trainer
    # checks 'DeepLab' -- any other value leaves `model` undefined and
    # crashes at model.train(); verify the intended spelling upstream.
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # BUGFIX: torch.load() has no 'strict' parameter (that belongs
            # to load_state_dict); passing it raised TypeError.
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            # Skip the 19-class head ('layer5') only when it would clash.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
            model.load_state_dict(saved_state_dict, strict=False)
            del saved_state_dict

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)

    if rank == 0:
        print(model)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D1.to(device)
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D)
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict

    model_D2.train()
    model_D2.to(device)
    if args.dist:
        broadcast_params(model_D2)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_data = GTA5BDDDataSet(args.data_dir,
                                args.data_list,
                                max_iters=args.num_steps * args.iter_size *
                                args.batch_size,
                                crop_size=input_size,
                                scale=args.random_scale,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)
    trainloader = data.DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=False if train_sampler else True,
                                  num_workers=args.num_workers,
                                  pin_memory=False,
                                  sampler=train_sampler)

    # cycle() lets us draw batches indefinitely without StopIteration.
    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target,
                             args.data_list_target,
                             max_iters=args.num_steps * args.iter_size *
                             args.batch_size,
                             crop_size=input_size_target,
                             scale=False,
                             mirror=args.random_mirror,
                             mean=IMG_MEAN,
                             set=args.set)
    target_sampler = None
    if args.dist:
        target_sampler = DistributedSampler(target_data)
    targetloader = data.DataLoader(target_data,
                                   batch_size=args.batch_size,
                                   shuffle=False if target_sampler else True,
                                   num_workers=args.num_workers,
                                   pin_memory=False,
                                   sampler=target_sampler)

    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # NOTE(review): model_D2/optimizer_D2 are created and zeroed but never
    # stepped in this loop -- only head 1 is adversarially trained here.
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard and rank == 0:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()
    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # Upsample logits to the label resolution of this batch.
            interp = nn.Upsample(size=(size[1], size[0]),
                                 mode='bilinear',
                                 align_corners=True)

            pred1 = model(images)
            pred1 = interp(pred1)

            loss_seg1 = seg_loss(pred1, labels)

            loss = loss_seg1

            # proper normalization
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            _, batch = targetloader_iter.__next__()
            # train with target
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = model(images)
            pred_target1 = interp_target(pred_target1)

            # dim=1 made explicit (class channel); the implicit-dim form
            # of F.softmax is deprecated.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            # Generator tries to fool D into labelling target as source.
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size

            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))
            # /2 because D sees each domain once per sub-iteration.
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            if args.dist:
                average_gradients(model)
                average_gradients(model_D1)
                average_gradients(model_D2)

            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()

        if rank == 0:
            if args.tensorboard:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }

                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_D1.pth'))
    print(args.snapshot_dir)
    if args.tensorboard and rank == 0:
        writer.close()
Example #5
0
def main():

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network
    model = DeeplabMulti(num_classes=args.num_classes)
    #model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cuda:0')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)
    #summary(model,(3,7,7))

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)
    #summary(model_D, (21,321,321))
    #quit()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                      args.iter_size * args.batch_size,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                         args.iter_size * args.batch_size,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            train_ids = pickle.load(open(args.partial_id))
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        pickle.dump(train_ids,
                    open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb'))

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             sampler=train_remain_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         sampler=train_gt_sampler,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers,
                                         pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = trainloader_remain_iter.__next__()
                except:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = trainloader_remain_iter.__next__()

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                try:
                    pred = interp(model(images))
                except RuntimeError as exception:
                    if "out of memory" in str(exception):
                        print("WARNING: out of memory")
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise exception

                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred)))
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    np.bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                #loss_semi_adv.backward()
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(
                        semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(
                            pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source

            try:
                _, batch = trainloader_iter.__next__()
            except:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            try:
                pred = interp(model(images))
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred)))

            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            if args.D_remain:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = trainloader_gt_iter.__next__()
            except:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = trainloader_gt_iter.__next__()

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '_D.pth'))
            #torch.cuda.empty_cache()

    end = timeit.default_timer()
    print(end - start, 'seconds')
# ---- Example #6 (separator inherited from the scraped example source) ----
def main():
    """Create the model and start adversarial domain-adaptation training.

    Trains a DeepLab segmentation network on the labelled source city and,
    via two FCDiscriminators (one per segmentation head), pushes its
    target-city predictions to be indistinguishable from source ones.
    All hyper-parameters come from the module-level ``args``/``config``.
    Checkpoints are written to ``args.snapshot_dir``.
    """
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create the segmentation network. Restoring pre-trained weights is
    # intentionally disabled in this variant: the model reloads a copy of
    # its own freshly initialised state dict (a deliberate no-op kept so
    # the restore path can be re-enabled easily).
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        new_params = model.state_dict().copy()
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # One discriminator per segmentation output (auxiliary and main head).
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Labelled source-domain loader.
    train_set = spacenet.Spacenet(city=config.dataset, split='train', img_root=config.img_root)
    trainloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, drop_last=True)
    trainloader_iter = enumerate(trainloader)

    # Unlabelled target-domain loader (labels ignored during training).
    target_set = spacenet.Spacenet(city=config.target, split='train', img_root=config.img_root)
    targetloader = DataLoader(target_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True)
    targetloader_iter = enumerate(targetloader)

    # model.optim_parameters(args) handles the model's per-layer lr setting.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # Fail fast instead of raising NameError at first use of bce_loss.
        raise ValueError('Unsupported GAN type: {}'.format(args.gan))

    # Upsample network outputs back to full image resolution (sizes are H, W).
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear')

    # Labels for adversarial training: D should output source_label for
    # source-domain inputs and target_label for target-domain inputs.
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        # Per-iteration loss accumulators (logging only).
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradient accumulation over args.iter_size sub-batches.
        for sub_i in range(args.iter_size):

            # ---------------- train G ----------------

            # Freeze both discriminators so generator updates don't
            # accumulate gradients in them.
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # Train with source: supervised segmentation loss.
            try:
                _, batch = trainloader_iter.__next__()
            except StopIteration:
                # Loader exhausted: restart the epoch.
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()
            images, labels = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            # Main head dominates; auxiliary head weighted by lambda_seg.
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # Proper normalization across accumulation steps.
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # Train with target: adversarial loss that pushes target
            # predictions to be classified as "source" by D.
            try:
                _, batch = targetloader_iter.__next__()
            except StopIteration:
                targetloader_iter = enumerate(targetloader)
                _, batch = targetloader_iter.__next__()
            images, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # Softmax over the class dimension; explicit dim=1 avoids the
            # deprecated implicit-dim behaviour (legacy resolution for
            # these 4-D logits is also dim=1, so behaviour is unchanged).
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(D_out1,
                                        Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(
                                            args.gpu))

            loss_adv_target2 = bce_loss(D_out2,
                                        Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(
                                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # ---------------- train D ----------------

            # Bring back requires_grad for the discriminator update.
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # Train with source (label: source). Detach so no gradients
            # flow back into the generator.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1,
                               Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

            # Halved because D sees two batches (source + target) per step.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # Train with target (label: target).
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1,
                               Variable(torch.FloatTensor(D_out1.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2))

        # Final checkpoint, then stop early at num_steps_stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'paris_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'paris_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'paris_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot (skip iteration 0).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'paris_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'paris_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'paris_' + str(i_iter) + '_D2.pth'))
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    tau = torch.ones(1) * args.tau
    tau = tau.cuda(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params, False)
    elif args.model == 'DeepLabVGG':
        model = DeeplabVGG(pretrained=True, num_classes=args.num_classes)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    weak_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Resize(1024),
        transforms.ToTensor(),
        #         transforms.Normalize(mean, std),
        #         RandomCrop(768)
    ])

    target_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Normalize(mean, std)
        #         transforms.Resize(1024),
        #         transforms.ToTensor(),
        #         RandomCrop(768)
    ])

    label_set = GTA5(
        root=args.data_dir,
        num_cls=19,
        split='all',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size,
        #              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )
    unlabel_set = Cityscapes(
        root=args.data_dir_target,
        split=args.set,
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )

    test_set = Cityscapes(
        root=args.data_dir_target,
        split='val',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                       crop_transform=RandomCrop(768)
    )

    label_loader = data.DataLoader(label_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=False)

    unlabel_loader = data.DataLoader(unlabel_set,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=False)

    test_loader = data.DataLoader(test_set,
                                  batch_size=2,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=False)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    [model, model_D2,
     model_D2], [optimizer, optimizer_D1, optimizer_D2
                 ] = amp.initialize([model, model_D2, model_D2],
                                    [optimizer, optimizer_D1, optimizer_D2],
                                    opt_level="O1",
                                    num_losses=7)

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = Interpolate(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = Interpolate(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_test = Interpolate(size=(input_size_target[1],
                                    input_size_target[0]),
                              mode='bilinear',
                              align_corners=True)
    #     interp_test = Interpolate(size=(1024, 2048), mode='bilinear', align_corners=True)

    normalize_transform = transforms.Compose([
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    # labels for adversarial training
    source_label = 0
    target_label = 1

    max_mIoU = 0

    total_loss_seg_value1 = []
    total_loss_adv_target_value1 = []
    total_loss_D_value1 = []
    total_loss_con_value1 = []

    total_loss_seg_value2 = []
    total_loss_adv_target_value2 = []
    total_loss_D_value2 = []
    total_loss_con_value2 = []

    hist = np.zeros((num_cls, num_cls))

    #     for i_iter in range(args.num_steps):
    for i_iter, (batch, batch_un) in enumerate(
            zip(roundrobin_infinite(label_loader),
                roundrobin_infinite(unlabel_loader))):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_con_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_con_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D1.parameters():
            param.requires_grad = False

        for param in model_D2.parameters():
            param.requires_grad = False

        # train with source

        images, labels = batch
        images_orig = images
        images = transform_batch(images, normalize_transform)
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model(images)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization
        loss = loss / args.iter_size

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

#         loss.backward()
        loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
        loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

        # train with target

        images_tar, labels_tar = batch_un
        images_tar_orig = images_tar
        images_tar = transform_batch(images_tar, normalize_transform)
        images_tar = Variable(images_tar).cuda(args.gpu)

        pred_target1, pred_target2 = model(images_tar)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_adv_target1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_adv_target2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()
        loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
        ) / args.iter_size
        loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
        ) / args.iter_size

        # train with consistency loss
        # unsupervise phase
        policies = RandAugment().get_batch_policy(args.batch_size)
        rand_p1 = np.random.random(size=args.batch_size)
        rand_p2 = np.random.random(size=args.batch_size)
        random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2])

        images_aug = aug_batch_tensor(images_tar_orig, policies, rand_p1,
                                      rand_p2, random_dir)

        images_aug_orig = images_aug
        images_aug = transform_batch(images_aug, normalize_transform)
        images_aug = Variable(images_aug).cuda(args.gpu)

        pred_target_aug1, pred_target_aug2 = model(images_aug)
        pred_target_aug1 = interp_target(pred_target_aug1)
        pred_target_aug2 = interp_target(pred_target_aug2)

        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        max_pred1, psuedo_label1 = torch.max(F.softmax(pred_target1, dim=1), 1)
        max_pred2, psuedo_label2 = torch.max(F.softmax(pred_target2, dim=1), 1)

        psuedo_label1 = psuedo_label1.cpu().numpy().astype(np.float32)
        psuedo_label1_thre = psuedo_label1.copy()
        psuedo_label1_thre[(max_pred1 < tau).cpu().numpy().astype(
            np.bool)] = 255  # threshold to don't care
        psuedo_label1_thre = aug_batch_numpy(psuedo_label1_thre, policies,
                                             rand_p1, rand_p2, random_dir)
        psuedo_label2 = psuedo_label2.cpu().numpy().astype(np.float32)
        psuedo_label2_thre = psuedo_label2.copy()
        psuedo_label2_thre[(max_pred2 < tau).cpu().numpy().astype(
            np.bool)] = 255  # threshold to don't care
        psuedo_label2_thre = aug_batch_numpy(psuedo_label2_thre, policies,
                                             rand_p1, rand_p2, random_dir)

        psuedo_label1_thre = Variable(psuedo_label1_thre).cuda(args.gpu)
        psuedo_label2_thre = Variable(psuedo_label2_thre).cuda(args.gpu)

        if (psuedo_label1_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con1 = loss_calc(pred_target_aug1, psuedo_label1_thre,
                                  args.gpu)
            loss_con_value1 += loss_con1.data.cpu().numpy() / args.iter_size
        else:
            loss_con1 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        if (psuedo_label2_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con2 = loss_calc(pred_target_aug2, psuedo_label2_thre,
                                  args.gpu)
            loss_con_value2 += loss_con2.data.cpu().numpy() / args.iter_size
        else:
            loss_con2 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2
        # proper normalization
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()

# train D

# bring back requires_grad
        for param in model_D1.parameters():
            param.requires_grad = True

        for param in model_D2.parameters():
            param.requires_grad = True

        # train with source
        pred1 = pred1.detach()
        pred2 = pred2.detach()

        D_out1 = model_D1(F.softmax(pred1, dim=1))
        D_out2 = model_D2(F.softmax(pred2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=3) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=4) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        # train with target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=5) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=6) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_con_value1,
                    loss_con_value2))

        total_loss_seg_value1.append(loss_seg_value1)
        total_loss_adv_target_value1.append(loss_adv_target_value1)
        total_loss_D_value1.append(loss_D_value1)
        total_loss_con_value1.append(loss_con_value1)

        total_loss_seg_value2.append(loss_seg_value2)
        total_loss_adv_target_value2.append(loss_adv_target_value2)
        total_loss_D_value2.append(loss_D_value2)
        total_loss_con_value2.append(loss_con_value2)

        hist += fast_hist(
            labels.cpu().numpy().flatten().astype(int),
            torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int),
            num_cls)

        if i_iter % 10 == 0:
            print('({}/{})'.format(i_iter + 1, int(args.num_steps)))
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(np.mean(iu)))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_train_acc = acc_overall
            avg_train_loss_seg1 = np.mean(total_loss_seg_value1)
            avg_train_loss_adv1 = np.mean(total_loss_adv_target_value1)
            avg_train_loss_dis1 = np.mean(total_loss_D_value1)
            avg_train_loss_con1 = np.mean(total_loss_con_value1)
            avg_train_loss_seg2 = np.mean(total_loss_seg_value2)
            avg_train_loss_adv2 = np.mean(total_loss_adv_target_value2)
            avg_train_loss_dis2 = np.mean(total_loss_D_value2)
            avg_train_loss_con2 = np.mean(total_loss_con_value2)

            print('avg_train_acc      :', avg_train_acc)
            print('avg_train_loss_seg1 :', avg_train_loss_seg1)
            print('avg_train_loss_adv1 :', avg_train_loss_adv1)
            print('avg_train_loss_dis1 :', avg_train_loss_dis1)
            print('avg_train_loss_con1 :', avg_train_loss_con1)
            print('avg_train_loss_seg2 :', avg_train_loss_seg2)
            print('avg_train_loss_adv2 :', avg_train_loss_adv2)
            print('avg_train_loss_dis2 :', avg_train_loss_dis2)
            print('avg_train_loss_con2 :', avg_train_loss_con2)

            writer['train'].add_scalar('log/mIoU', mIoU, i_iter)
            writer['train'].add_scalar('log/acc', avg_train_acc, i_iter)
            writer['train'].add_scalar('log1/loss_seg', avg_train_loss_seg1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_adv', avg_train_loss_adv1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_dis', avg_train_loss_dis1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_con', avg_train_loss_con1,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_seg', avg_train_loss_seg2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_adv', avg_train_loss_adv2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_dis', avg_train_loss_dis2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_con', avg_train_loss_con2,
                                       i_iter)

            hist = np.zeros((num_cls, num_cls))
            total_loss_seg_value1 = []
            total_loss_adv_target_value1 = []
            total_loss_D_value1 = []
            total_loss_con_value1 = []
            total_loss_seg_value2 = []
            total_loss_adv_target_value2 = []
            total_loss_D_value2 = []
            total_loss_con_value2 = []

            fig = plt.figure(figsize=(15, 15))

            labels = labels[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(331)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(337)
            images = images_orig[0].cpu().numpy().transpose((1, 2, 0))
            #             images += IMG_MEAN
            ax.imshow(images)
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(334)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            labels_tar = labels_tar[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(332)
            ax.imshow(print_palette(Image.fromarray(labels_tar).convert('L')))
            ax.axis("off")
            ax.set_title('tar_labels')

            ax = fig.add_subplot(338)
            ax.imshow(images_tar_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('tar_datas')

            _, pred_target2 = torch.max(pred_target2, dim=1)
            pred_target2 = pred_target2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(335)
            ax.imshow(print_palette(
                Image.fromarray(pred_target2).convert('L')))
            ax.axis("off")
            ax.set_title('tar_predicts')

            print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0],
                  'random_dir', random_dir[0])

            psuedo_label2_thre = psuedo_label2_thre[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(333)
            ax.imshow(
                print_palette(
                    Image.fromarray(psuedo_label2_thre).convert('L')))
            ax.axis("off")
            ax.set_title('psuedo_labels')

            ax = fig.add_subplot(339)
            ax.imshow(images_aug_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('aug_datas')

            _, pred_target_aug2 = torch.max(pred_target_aug2, dim=1)
            pred_target_aug2 = pred_target_aug2[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(336)
            ax.imshow(
                print_palette(Image.fromarray(pred_target_aug2).convert('L')))
            ax.axis("off")
            ax.set_title('aug_predicts')

            #             plt.show()
            writer['train'].add_figure('image/',
                                       fig,
                                       global_step=i_iter,
                                       close=True)

        if i_iter % 500 == 0:
            loss1 = []
            loss2 = []
            for test_i, batch in enumerate(test_loader):

                images, labels = batch
                images_orig = images
                images = transform_batch(images, normalize_transform)
                images = Variable(images).cuda(args.gpu)

                pred1, pred2 = model(images)
                pred1 = interp_test(pred1)
                pred1 = pred1.detach()
                pred2 = interp_test(pred2)
                pred2 = pred2.detach()

                loss_seg1 = loss_calc(pred1, labels, args.gpu)
                loss_seg2 = loss_calc(pred2, labels, args.gpu)
                loss1.append(loss_seg1.item())
                loss2.append(loss_seg2.item())

                hist += fast_hist(
                    labels.cpu().numpy().flatten().astype(int),
                    torch.argmax(pred2,
                                 dim=1).cpu().numpy().flatten().astype(int),
                    num_cls)

            print('test')
            fig = plt.figure(figsize=(15, 15))
            labels = labels[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(311)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(313)
            ax.imshow(images_orig[-1].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(312)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            #             plt.show()

            writer['test'].add_figure('test_image/',
                                      fig,
                                      global_step=i_iter,
                                      close=True)

            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(mIoU))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_test_loss1 = np.mean(loss1)
            avg_test_loss2 = np.mean(loss2)
            avg_test_acc = acc_overall
            print('avg_test_loss2 :', avg_test_loss1)
            print('avg_test_loss1 :', avg_test_loss2)
            print('avg_test_acc   :', avg_test_acc)
            writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter)
            writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter)
            writer['test'].add_scalar('log/acc', avg_test_acc, i_iter)
            writer['test'].add_scalar('log/mIoU', mIoU, i_iter)

            hist = np.zeros((num_cls, num_cls))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if max_mIoU < mIoU:
            max_mIoU = mIoU
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D2.pth'))
def main() -> None:
    """Create the models and run the adversarial training loop.

    Builds a two-branch DeepLab segmentation network plus two
    discriminators: ``model_D1`` scores *pairs* of softmax prediction maps
    (concatenated along the channel axis), while ``model_D2`` scores single
    softmax maps.  Each iteration runs a generator phase (segmentation loss
    on source and target batches plus adversarial losses, with discriminator
    gradients frozen) followed by discriminator phases, then periodically
    snapshots all three models to ``args.snapshot_dir``.

    NOTE(review): relies on module-level names defined elsewhere in this
    file (``args``, ``IMG_MEAN``, ``cut_size``, ``loss_calc``,
    ``adjust_learning_rate``, ``adjust_learning_rate_D``, the dataset and
    model classes) — confirm they are all in scope at call time.
    """

    # Segmentation model lives on gpu_id_2; labels, discriminators and all
    # loss computation live on gpu_id_1 (simple two-GPU model parallelism).
    gpu_id_2 = 1
    gpu_id_1 = 0

    # Source-domain crop size, given as "w,h" on the command line.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    # Target-domain crop size.
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    # gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            print("from restore")
            # NOTE(review): this branch ignores args.restore_from and loads a
            # hard-coded checkpoint path with a cuda:3 -> cuda:0 device remap
            # — confirm this is intentional and not leftover debug code.
            saved_state_dict = torch.load(
                '/data2/zhangjunyi/snapshots/snapshots_syn/onlysyn/GTA5_80000.pth',
                map_location={"cuda:3": "cuda:0"})

        # Copy every pretrained tensor except the classifier head
        # (keys under 'layer5' or starting with 'fc' are skipped).
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if (not i_parts[1] == 'layer5') and (not i_parts[0] == 'fc'):
                new_params['.'.join(i_parts)] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D
    # D1 sees concatenated prediction pairs; D2 sees single prediction maps
    # (match=False — semantics defined by FCDiscriminator elsewhere).
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes, match=False)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source-domain loader (SYNTHIA); max_iters sizes the dataset so one
    # epoch covers the whole training schedule.
    trainloader = data.DataLoader(SynthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    # Prime the "previous batch" slot; D1 pairs the current batch's
    # predictions with the previous batch's.
    _, batch_last = trainloader_iter.__next__()

    # Target-domain loader (Cityscapes).
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        cut_size=cut_size,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = targetloader_iter.__next__()

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # BCE-with-logits is used for D2; MSE (least-squares-GAN style) for D1.
    bce_loss = torch.nn.BCEWithLogitsLoss()
    mse_loss = torch.nn.MSELoss()

    def upsample_(input_):
        # Bilinear upsample to the source crop size (size is (h, w)).
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        # Bilinear upsample to the target crop size.
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    # labels for adversarial training
    # NOTE(review): target_label = -1 with MSE gives a three-level
    # least-squares-style objective (+1 source / -1 target / 0 mix);
    # -1 is unusual for the BCE branch — confirm D2 only ever sees
    # source_label and mix_label (it does below, but verify on change).
    source_label = 1
    target_label = -1
    mix_label = 0

    for i_iter in range(args.num_steps):

        # Running loss totals for this iteration (for logging only).
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        # Counts of pixels assigned the ignore index (255) by loss_calc —
        # presumably pixels dropped by the beta thresholding; verify.
        number1 = 0
        number2 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            def result_model(batch, interp_):
                # Forward one batch through the segmenter and move the
                # upsampled logits to the loss/discriminator GPU.
                images, labels, _, name = batch
                images = Variable(images).cuda(gpu_id_2)
                labels = Variable(labels.long()).cuda(gpu_id_1)
                pred1, pred2 = model(images)
                pred1 = interp_(pred1)
                pred2 = interp_(pred2)
                pred1_ = pred1.cuda(gpu_id_1)
                pred2_ = pred2.cuda(gpu_id_1)
                return pred1_, pred2_, labels

            beta = args.beta
            if i_iter == 0:
                print(beta)
            _, batch = trainloader_iter.__next__()
            _, batch_target = targetloader_iter.__next__()
            # --- segmentation loss on the source batch ---
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1, beta)
            labels = new_labels
            number1 = torch.sum(labels == 255).item()
            loss_seg2, new_labels = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()
            # NOTE(review): loss_seg_1/loss_seg_2 are computed but never
            # used (the logged values come from the target batch below).
            loss_seg_1 = loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_2 = loss_seg2.data.cpu().numpy() / args.iter_size

            # --- segmentation loss on the target batch (target labels) ---
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1, new_labels = loss_calc(pred1, labels, gpu_id_1, beta)
            labels = new_labels
            number2 = torch.sum(labels == 255).item()
            # NOTE(review): 'new_lables' is a typo for 'new_labels'; the
            # value is unused so behavior is unaffected, but it shadows
            # nothing useful — worth cleaning up.
            loss_seg2, new_lables = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # --- adversarial: fool D1 on (current, previous) target pairs ---
            pred1_last_target, pred2_last_target, labels_last_target = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)
            # exit()

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            # Channel-wise concat of the two softmax maps forms D1's input.
            fake1_D = torch.cat((pred1_target_D, pred1_last_target_D), dim=1)
            fake2_D = torch.cat((pred2_target_D, pred2_last_target_D), dim=1)
            D_out_fake_1 = model_D1(fake1_D)
            D_out_fake_2 = model_D1(fake2_D)

            # Generator wants D1 to label target pairs as source (+1).
            loss_adv_fake1 = mse_loss(
                D_out_fake_1,
                Variable(
                    torch.FloatTensor(D_out_fake_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_fake2 = mse_loss(
                D_out_fake_2,
                Variable(
                    torch.FloatTensor(D_out_fake_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_target1 = loss_adv_fake1
            loss_adv_target2 = loss_adv_fake2
            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()

            # --- adversarial: fool D1 on (target, source) mixed pairs ---
            pred1, pred2, labels = result_model(batch, interp)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            mix1_D = torch.cat((pred1_target_D, pred1_D), dim=1)
            mix2_D = torch.cat((pred2_target_D, pred2_D), dim=1)

            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D1(mix2_D)

            loss_adv_mix1 = mse_loss(
                D_out_mix_1,
                Variable(
                    torch.FloatTensor(D_out_mix_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_mix2 = mse_loss(
                D_out_mix_2,
                Variable(
                    torch.FloatTensor(D_out_mix_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            # Mixed-pair adversarial losses are weighted 2x — presumably to
            # balance against the fake-pair term above; confirm intent.
            loss_adv_target1 = loss_adv_mix1 * 2
            loss_adv_target2 = loss_adv_mix2 * 2

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            # Both D1-related adversarial terms accumulate into value1.
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value1 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # train D2
            # (still the generator phase: fool D2 into labelling target
            # predictions as source; D2's params remain frozen here.)
            pred1_target, pred2_target, labels_target = result_model(
                batch_target, interp_target)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)
            D_out_target_1 = model_D2(pred1_target_D)
            D_out_target_2 = model_D2(pred2_target_D)

            loss_adv_target1 = bce_loss(
                D_out_target_1,
                Variable(
                    torch.FloatTensor(D_out_target_1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_adv_target2 = bce_loss(
                D_out_target_2,
                Variable(
                    torch.FloatTensor(D_out_target_2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            pred1_last, pred2_last, labels_last = result_model(
                batch_last, interp)

            # train with source

            # Detach so discriminator updates don't backprop into the
            # segmenter.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last = pred1_last.detach().cuda(gpu_id_1)
            pred2_last = pred2_last.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_D = F.softmax((pred1_last), dim=1)
            pred2_last_D = F.softmax((pred2_last), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            # real = (source, source-last); mix = (source-last, target).
            real1_D = torch.cat((pred1_D, pred1_last_D), dim=1)
            real2_D = torch.cat((pred2_D, pred2_last_D), dim=1)
            mix1_D_ = torch.cat((pred1_last_D, pred1_target_D), dim=1)
            mix2_D_ = torch.cat((pred2_last_D, pred2_target_D), dim=1)

            D_out1_real = model_D1(real1_D)
            D_out2_real = model_D1(real2_D)
            D_out1_mix = model_D1(mix1_D_)
            D_out2_mix = model_D1(mix2_D_)

            loss_D1 = mse_loss(
                D_out1_real,
                Variable(
                    torch.FloatTensor(D_out1_real.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D2 = mse_loss(
                D_out2_real,
                Variable(
                    torch.FloatTensor(D_out2_real.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D3 = mse_loss(
                D_out1_mix,
                Variable(
                    torch.FloatTensor(D_out1_mix.data.size()).fill_(
                        mix_label)).cuda(gpu_id_1))

            loss_D4 = mse_loss(
                D_out2_mix,
                Variable(
                    torch.FloatTensor(D_out2_mix.data.size()).fill_(
                        mix_label)).cuda(gpu_id_1))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # Both losses here belong to model_D1 (branch 1 and branch 2),
            # so both accumulate into loss_D_value1.
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # train with target

            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last_target = pred1_last_target.detach().cuda(gpu_id_1)
            pred2_last_target = pred2_last_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_last_target_D = F.softmax((pred1_last_target), dim=1)
            pred2_last_target_D = F.softmax((pred2_last_target), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            # fake = (target-last, target); mix = (source, target-last).
            fake1_D_ = torch.cat((pred1_last_target_D, pred1_target_D), dim=1)
            fake2_D_ = torch.cat((pred2_last_target_D, pred2_target_D), dim=1)
            mix1_D__ = torch.cat((pred1_D, pred1_last_target_D), dim=1)
            mix2_D__ = torch.cat((pred2_D, pred2_last_target_D), dim=1)

            D_out1 = model_D1(fake1_D_)
            D_out2 = model_D1(fake2_D_)
            D_out3 = model_D1(mix1_D__)
            D_out4 = model_D1(mix2_D__)

            loss_D1 = mse_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(D_out1.data.size()).fill_(
                        target_label)).cuda(gpu_id_1))

            loss_D2 = mse_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(D_out2.data.size()).fill_(
                        target_label)).cuda(gpu_id_1))

            loss_D3 = mse_loss(
                D_out3,
                Variable(
                    torch.FloatTensor(
                        D_out3.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D4 = mse_loss(
                D_out4,
                Variable(
                    torch.FloatTensor(
                        D_out4.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # Carry the current batches into the next sub-iteration as the
            # "previous" pair. NOTE(review): this assignment is repeated
            # verbatim below after the D2 phase — the duplicate is harmless
            # but redundant.
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # train model-D2
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax((pred1), dim=1)
            pred2_D = F.softmax((pred2), dim=1)
            pred1_target_D = F.softmax((pred1_target), dim=1)
            pred2_target_D = F.softmax((pred2_target), dim=1)

            # D2 classifies single maps: source -> source_label (1),
            # target -> mix_label (0).
            D_out1 = model_D2(pred1_D)
            D_out2 = model_D2(pred2_D)
            D_out3 = model_D2(pred1_target_D)
            D_out4 = model_D2(pred2_target_D)

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(D_out1.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(D_out2.data.size()).fill_(
                        source_label)).cuda(gpu_id_1))

            loss_D3 = bce_loss(
                D_out3,
                Variable(
                    torch.FloatTensor(
                        D_out3.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D4 = bce_loss(
                D_out4,
                Variable(
                    torch.FloatTensor(
                        D_out4.data.size()).fill_(mix_label)).cuda(gpu_id_1))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            batch_last, batch_last_target = batch, batch_target
            loss_D_value2 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        # Apply all gradients accumulated over iter_size sub-iterations.
        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, number1 = {8}, number2 = {9}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, number1, number2))

        # Final snapshot, then stop early at num_steps_stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot of all three models.
        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
def main(args):
    """Run adversarial domain-adaptation training for semantic segmentation.

    Per outer iteration (repeated ``args.iter_size`` sub-steps):
      1. supervised segmentation loss on labeled source images,
      2. supervised segmentation loss on labeled target images,
      3. adversarial loss pushing predictions on *unlabeled* target images
         toward the source distribution, as judged by the output-space
         discriminator ``model_D2``,
      4. discriminator update on detached source / target predictions.

    Generator and discriminator snapshots are written to
    ``args.snapshot_dir``; scalar losses are logged to TensorBoard.
    """

    mkdir_check(args.snapshot_dir)
    writer = SummaryWriter(log_dir=os.path.join(args.snapshot_dir, 'tb-logs'))

    # Input sizes are parsed from "h,w" command-line strings.
    h, w = map(int, args.src_input_size.split(','))
    src_input_size = (h, w)

    h, w = map(int, args.tgt_input_size.split(','))
    tgt_input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # Create the segmentation network and load pre-trained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http' :
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint key layout example: Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # Copy everything except the classifier head ('layer5') when
            # training the standard 19-class setup (head is re-initialized).
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D: output-space discriminator over the 19 class-probability maps.
    model_D2 = FCDiscriminator(num_classes=19)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Labeled source loader.
    trainloader = data.DataLoader(
            ListDataSet(args.src_data_dir, args.src_img_list, args.src_lbl_list, 
                max_iters=args.num_steps * args.iter_size * args.batch_size, crop_size=src_input_size, mean=IMG_MEAN),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Labeled target loader.
    targetloader = data.DataLoader(
            ListDataSet(args.tgt_data_dir, args.tgt_img_list, args.tgt_lbl_list,
                max_iters=args.num_steps * args.iter_size * args.batch_size, crop_size=tgt_input_size, mean=IMG_MEAN),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            pin_memory=True)

    # Unlabeled target loader (drives the adversarial alignment).
    targetloader_nolabel = data.DataLoader(
            ListDataSet(args.tgt_data_dir, args.tgt_img_nolabel_list, None,
                max_iters=args.num_steps * args.iter_size * args.batch_size, crop_size=tgt_input_size, mean=IMG_MEAN),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    targetloader_nolabel_iter = enumerate(targetloader_nolabel)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # NOTE(review): nn.Upsample takes size as (H, W); given the "h,w" parse
    # above, (src_input_size[1], src_input_size[0]) passes (w, h) — confirm
    # the intended ordering of args.src_input_size / args.tgt_input_size.
    interp = nn.Upsample(size=(src_input_size[1], src_input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(tgt_input_size[1], tgt_input_size[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        # Running loss accumulators for this outer iteration.
        loss_seg_value2 = 0
        loss_tgt_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(args, optimizer, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(args, optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            #pred1 = interp(pred1)
            pred2 = interp(pred2)

            #loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 #+ args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target seg

            _, batch = targetloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            #pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            #loss_tgt_seg1 = loss_calc(pred_target1, labels, args.gpu)
            loss_tgt_seg2 = loss_calc(pred_target2, labels, args.gpu)
            loss = loss_tgt_seg2 #+ args.lambda_seg * loss_tgt_seg1

            # proper normalization
            loss = loss / args.iter_size
            # NOTE(review): retain_graph appears unnecessary (this graph is
            # not reused below) — kept to preserve existing behavior.
            loss.backward(retain_graph=True)
            loss_tgt_seg_value2 += loss_tgt_seg2.data.cpu().numpy() / args.iter_size

            # train with target_nolabel adv
            _, batch = targetloader_nolabel_iter.__next__()
            images, _, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            #pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # Softmax over the class (channel) dimension of the NCHW logits.
            # (Fixed: dim=-1 normalized over image width, not classes.)
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Fool D: target predictions should be classified as "source".
            loss_adv_target2 = bce_loss(D_out2,
                                        Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(
                                            args.gpu))

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detach so G receives no gradients here)
            pred2 = pred2.detach()

            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

            # Halved so source + target D losses average to one step.
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D2.step()

        print(
        'iter = {:5d}/{:8d}, loss_seg2 = {:.3f} loss_tgt_seg2 = {:.3f} loss_adv2 = {:.3f} loss_D2 = {:.3f}'.format(
            i_iter, args.num_steps_stop, loss_seg_value2, loss_tgt_seg_value2,
            loss_adv_target_value2, loss_D_value2))

        writer.add_scalars('loss/seg', {
            'src2': loss_seg_value2, 
            'tgt2': loss_tgt_seg_value2,
            }, i_iter)

        writer.add_scalar('loss/adv', loss_adv_target_value2, i_iter)
        writer.add_scalar('loss/d', loss_D_value2, i_iter)

        # Final checkpoint, then stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))
            break

        # Periodic checkpoint.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))
Example #10
0
def main():
    """Create the model and start the training.

    Adversarial UDA: a DeepLab generator is trained on labeled GTA5 source
    images while two discriminators — a feature-level FCDiscriminator and an
    output-space OutspaceDiscriminator — align unlabeled Cityscapes target
    predictions with the source distribution using a least-squares (MSE)
    GAN objective.
    """
    setup_seed(666)  # fixed seed for reproducibility
    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes are parsed from "w,h" command-line strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network and load pre-trained weights.
    model = DeeplabMulti(num_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()
    for i in saved_state_dict:
        i_parts = i.split('.')
        # Copy all weights except the classifier head when using 19 classes.
        if not args.num_classes == 19 or not i_parts[1] == 'layer5':
            if i_parts[1]=='layer4' and i_parts[2]=='2':
                # Re-map the last layer4 bottleneck (layer4.2.*) onto this
                # architecture's layer5.0.* as well.
                i_parts[1] = 'layer5'
                i_parts[2] = '0'
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            else:
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
    model.load_state_dict(new_params)
    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D: model_D[0] is a feature-level discriminator over 2048-channel
    # features; model_D[1] is an output-space discriminator over the 19 maps.
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([FCDiscriminator(num_classes=num_class_list[i]).train().to(device) if i<1 else OutspaceDiscriminator(num_classes=num_class_list[i]).train().to(device) for i in range(2)])
    
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Labeled source (GTA5) loader.
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Unlabeled target (Cityscapes) loader.
    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)


    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # NOTE: named bce_loss but actually MSE — a least-squares GAN objective.
    bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample predictions back to full crop resolution (H, W).
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source: supervised segmentation loss only.
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        # The discriminators are passed into the forward; the 'source'/'target'
        # mode flag is interpreted inside the model — presumably selecting the
        # alignment path (TODO confirm against DeeplabMulti's forward).
        feat_source, pred_source = model(images, model_D, 'source')
        pred_source = interp(pred_source)

        loss_seg = seg_loss(pred_source, labels)
        loss_seg.backward()

        # train with target: adversarial loss at feature and output space.
        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)

        feat_target, pred_target = model(images, model_D, 'target')
        pred_target = interp_target(pred_target)

        # Fool D: target features/predictions should look like "source".
        loss_adv = 0
        D_out = model_D[0](feat_target)
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        D_out = model_D[1](F.softmax(pred_target, dim=1))
        loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
        loss_adv = loss_adv*0.01  # fixed adversarial weight
        loss_adv.backward()
        
        optimizer.step()

        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source (detached: no gradients reach the generator).
        loss_D_source = 0
        D_out_source = model_D[0](feat_source.detach())
        loss_D_source += bce_loss(D_out_source, torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        D_out_source = model_D[1](F.softmax(pred_source.detach(),dim=1))
        loss_D_source += bce_loss(D_out_source, torch.FloatTensor(D_out_source.data.size()).fill_(source_label).to(device))
        loss_D_source.backward()

        # train with target
        loss_D_target = 0
        D_out_target = model_D[0](feat_target.detach())
        loss_D_target += bce_loss(D_out_target, torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        D_out_target = model_D[1](F.softmax(pred_target.detach(),dim=1))
        loss_D_target += bce_loss(D_out_target, torch.FloatTensor(D_out_target.data.size()).fill_(target_label).to(device))
        loss_D_target.backward()
        
        optimizer_D.step()

        if i_iter % 10 == 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D_s = {4:.3f}, loss_D_t = {5:.3f}'.format(
            i_iter, args.num_steps, loss_seg.item(), loss_adv.item(), loss_D_source.item(), loss_D_target.item()))

        # Final checkpoint, then stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        # Periodic checkpoint (skip iteration 0).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
def main():
    """Create the model and start CDAN-style adversarial training.

    A DeepLab generator is trained on labeled GTA5 source images; a single
    discriminator (``model_D``, fed the concatenation of the two softmaxed
    prediction branches via ``CDAN``) aligns unlabeled Cityscapes target
    predictions with the source domain. Supports restarting from a saved
    snapshot (``RESTART`` / ``RESTART_ITER``) and logs scalars to
    TensorBoard.
    """
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes are parsed from "w,h" command-line strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    # D consumes the two stacked class-probability branches -> 2 * classes.
    model_D = FCDiscriminator(num_classes=2 * args.num_classes).to(device)

    #### restore model_D and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D parameters
        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D)
        model_D.load_state_dict(saved_state_dict)

    #### model_D is randomly initialized, model is pre-trained ResNet on ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint key layout example: Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # Copy everything except the classifier head ('layer5') when
            # training the standard 19-class setup.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Labeled source (GTA5) loader.
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Unlabeled target (Cityscapes) loader.
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # Loss used for the adversarial game; selected by args.gan.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample predictions back to full crop resolution (H, W).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        # Running loss accumulators for this outer iteration.
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        adv_loss_value = 0
        d_loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        # Fixed: previously called adjust_learning_rate here, applying the
        # generator's schedule to the discriminator optimizer.
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # images.size() == [1, 3, 720, 1280]
            pred1, pred2 = model(images)
            # pred1, pred2 size == [1, 19, 91, 161]
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            # size (1, 19, 720, 1280)

            # Supervised loss on both branches; branch 1 is down-weighted.
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = targetloader_iter.__next__()

            for params in model_D.parameters():
                params.requires_grad_(requires_grad=False)

            images, _, _ = batch
            images = images.to(device)
            # images.size() == [1, 3, 720, 1280]
            pred_target1, pred_target2 = model(images)

            # pred_target1, 2 == [1, 19, 91, 161]
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            # pred_target1, 2 == [1, 19, 720, 1280]

            # Fool D: target predictions should be classified as "source".
            # dim=1 (channel/class axis) made explicit; this matches the
            # previous implicit default for 4-D input.
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')
            dc_source = torch.FloatTensor(
                D_out_target.size()).fill_(0).to(device)
            adv_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_source)
            adv_loss = adv_loss / args.iter_size
            adv_loss = args.lambda_adv * adv_loss
            adv_loss.backward()
            adv_loss_value += adv_loss.item()

            # train D: re-enable its gradients and use detached predictions
            # so the generator receives no update from this phase.
            for params in model_D.parameters():
                params.requires_grad_(requires_grad=True)

            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out = CDAN([F.softmax(pred1, dim=1), F.softmax(pred2, dim=1)],
                         model_D,
                         cdan_implement='concat')

            dc_source = torch.FloatTensor(D_out.size()).fill_(0).to(device)
            d_loss = nn.BCEWithLogitsLoss()(D_out, dc_source)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')

            dc_target = torch.FloatTensor(
                D_out_target.size()).fill_(1).to(device)
            d_loss = nn.BCEWithLogitsLoss()(D_out_target, dc_target)
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()

        optimizer.step()
        optimizer_D.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'generator_loss': adv_loss_value,
            'discriminator_loss': d_loss_value,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    adv_loss_value, d_loss_value))

        # Final checkpoint, then stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        # Periodic checkpoint (skip iteration 0).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

            # check_original_discriminator(args, pred_target1, pred_target2, i_iter)
            save_path = args.snapshot_dir + '/eval_{}'.format(i_iter)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            # evaluate(args, save_path, args.snapshot_dir, i_iter)

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            # Fixed: path now includes the '/' separator so it matches the
            # restart-load path above; also removed a stray pdb.set_trace()
            # that halted training at every snapshot.
            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

            ###### also record latest saved iteration #######

    writer.close()
def main():
    """Create the model and start IRM-regularized segmentation training.

    Trains a two-branch DeepLab segmenter on source (GTA5) and target
    (Cityscapes) batches.  Each forward pass multiplies the logits by a
    dummy unit scalar and adds the squared gradient of the loss w.r.t.
    that scalar as an IRM penalty (Invariant Risk Minimization).

    All configuration is read from the module-level ``args`` namespace;
    snapshots are written to ``args.snapshot_dir``.
    """
    # Parse "W,H" strings into (width, height) tuples.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create the segmentation network and initialize from a checkpoint.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight";
            # drop the leading scope, and skip the classifier head ("layer5")
            # whenever the class count differs from the pre-trained 19.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source-domain loader (GTA5).  max_iters sizes the dataset so a single
    # "epoch" spans the whole training run.
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Target-domain loader (Cityscapes).
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # model.optim_parameters(args) handles per-layer lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # Upsamplers from logit resolution back to crop resolution.
    # NOTE(review): _forward_single below resizes BOTH source and target
    # batches with the source-size `interp`; `interp_target` is currently
    # unused.  Confirm whether target crops share the source crop size.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    def _forward_single(_iter):
        """Run one forward/backward pass on the next batch from ``_iter``.

        The logits are scaled by a dummy unit scalar inside
        ``model.forward_irm``; the squared gradient of the loss w.r.t. that
        scalar is the IRM penalty.  Gradients are accumulated into the
        model via ``loss.backward()``.

        Returns the two (un-normalized) per-branch segmentation losses.
        """
        _scalar = torch.tensor(1.).cuda().requires_grad_()

        _, batch = next(_iter)
        images, labels, _, _ = batch
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model.forward_irm(images, _scalar)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # Proper normalization over the gradient-accumulation window.
        loss = loss / args.iter_size

        # IRM penalty: squared gradient of the loss w.r.t. the dummy scalar.
        grad = autograd.grad(loss, [_scalar], create_graph=True)[0]
        penalty = torch.sum(grad**2)

        loss += args.lambda_irm * penalty

        loss.backward()

        return loss_seg1, loss_seg2

    for i_iter in range(args.num_steps):

        loss_seg1_src = 0
        loss_seg2_src = 0

        loss_seg1_tgt = 0
        loss_seg2_tgt = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):

            # Train with source.  .item() extracts the Python scalar; the
            # old `.data.cpu().numpy()[0]` idiom raises on 0-d arrays with
            # modern torch/numpy.
            loss_seg1, loss_seg2 = _forward_single(trainloader_iter)
            loss_seg1_src += loss_seg1.item() / args.iter_size
            loss_seg2_src += loss_seg2.item() / args.iter_size

            # Train with target.
            loss_seg1, loss_seg2 = _forward_single(targetloader_iter)
            loss_seg1_tgt += loss_seg1.item() / args.iter_size
            loss_seg2_tgt += loss_seg2.item() / args.iter_size

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        # Report the values actually accumulated above (the previous code
        # printed loss_seg_value1 / loss_adv_* / loss_D_* names that are
        # never defined in this function and would raise NameError).
        print(
            'iter = {0:8d}/{1:8d}, '
            'src: loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}, '
            'tgt: loss_seg1 = {4:.3f} loss_seg2 = {5:.3f}'.format(
                i_iter, args.num_steps, loss_seg1_src, loss_seg2_src,
                loss_seg1_tgt, loss_seg2_tgt))

        # Final snapshot.  This IRM variant has no discriminators, so only
        # the segmentation model is saved (the previous model_D1/model_D2
        # saves referenced undefined names).
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            break

        # Periodic snapshot.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
# --- Example #13 (dataset extraction marker; original artifact lines "示例#13" / "0" removed) ---
def main():
    """Create the model and start adversarial domain-adaptation training.

    A two-branch DeepLab segmenter (``model``) is trained on labelled
    source (GTA5) data while two fully-convolutional discriminators
    (``model_D1``/``model_D2``, one per output branch) push the target
    (Cityscapes) predictions to look source-like.  Configuration comes
    from the module-level ``args``; snapshots go to ``args.snapshot_dir``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Parse "W,H" strings into (width, height) tuples.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        #model.load_state_dict(saved_state_dict)

        # Remap checkpoint keys by dropping the leading scope; skip the
        # classifier head ("layer5") unless the class count matches the
        # pre-trained 19 classes.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #     model_D1.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D1.pth'))
    #     model_D2.load_state_dict(torch.load('./snapshots/local_00002/GTA5_21000_D2.pth'))

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source-domain loader (GTA5); max_iters sizes one "epoch" to cover the
    # whole training run.
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Target-domain loader (Cityscapes).
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # Load VGG
    #vgg19 = torchvision.models.vgg19(pretrained=True)
    #vgg19.to(device)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # SGD for the generator, Adam for both discriminators.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Adversarial loss on discriminator logits; segmentation loss ignores
    # the 255 "void" label.
    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsamplers from logit resolution back to crop resolution.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(0, args.num_steps):

        # Per-iteration scalar accumulators (averaged over iter_size).
        loss_seg_value = 0
        #         loss_seg_local_value = 0
        loss_adv_target_value = 0
        #         loss_adv_local_value = 0
        loss_D_value = 0
        #         loss_D_local_value = 0
        loss_local_match_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred_s1, pred_s2, _, _ = model(images)
            #f_s2 = normalize(f_s2)
            pred_s1, pred_s2 = interp(pred_s1), interp(pred_s2)

            # Branch 2 is the primary head; branch 1 is weighted by
            # lambda_seg (auxiliary).
            loss_seg = args.lambda_seg * seg_loss(pred_s1, labels) + seg_loss(
                pred_s2, labels)
            del labels

            # proper normalization
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target
            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            pred_t1, pred_t2, _, _ = model(images)
            #f_t2 = normalize(f_t2)
            pred_t1, pred_t2 = interp_target(pred_t1), interp_target(pred_t2)
            del images

            # Fool the discriminators: target predictions should be
            # classified as source (source_label).
            D_out_1 = model_D1(F.softmax(pred_t1, dim=1))
            D_out_2 = model_D2(F.softmax(pred_t2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out_1,
                torch.FloatTensor(
                    D_out_1.data.size()).fill_(source_label).to(device))
            loss_adv_target2 = bce_loss(
                D_out_2,
                torch.FloatTensor(
                    D_out_2.data.size()).fill_(source_label).to(device))
            # NOTE(review): branch 2 is weighted by args.lambda_adv_target
            # here while other variants of this script use
            # args.lambda_adv_target2 — confirm the intended flag name.
            loss_adv_target_value += (
                args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target *
                loss_adv_target2).item() / args.iter_size

            loss = loss_seg + args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target * loss_adv_target2

            del D_out_1, D_out_2

            #             #< Local patch part>#
            #             corres_id2 = get_correspondance(f_s2, f_t2, pred_s2, pred_t2)

            #             #loss_local1 = local_feature_loss(corres_id1, f_s1, f_t1, model, seg_loss)
            #             loss_local2 = local_feature_loss(corres_id2, labels, f_t2, model, seg_loss)
            #             loss_local = args.lambda_match_target2 * loss_local2 #+args.lambda_match_target1 * loss_local1
            #             loss += loss_local
            #             if corres_id2.nelement() > 0:
            #                 loss_local_match_value += loss_local.item()/ args.iter_size

            loss /= args.iter_size
            loss.backward()

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source: detach so D gradients don't reach G.
            # NOTE(review): F.softmax is called without dim here (deprecated
            # implicit-dim form), unlike the dim=1 calls above — confirm.
            pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_s1)), model_D2(
                F.softmax(pred_s2))
            # /2 splits the D loss evenly between the source and target halves.
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(source_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

            # train with target
            pred_t1, pred_t2 = pred_t1.detach(), pred_t2.detach()
            D_out_1, D_out_2 = model_D1(F.softmax(pred_t1)), model_D2(
                F.softmax(pred_t2))
            loss_D_1 = bce_loss(
                D_out_1,
                torch.FloatTensor(D_out_1.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_2 = bce_loss(
                D_out_2,
                torch.FloatTensor(D_out_2.data.size()).fill_(target_label).to(
                    device)) / args.iter_size / 2
            loss_D_1.backward()
            loss_D_2.backward()

            loss_D_value += (loss_D_1 + loss_D_2).item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        # Periodic Cityscapes evaluation.
        # NOTE(review): paths are hard-coded relative to the working
        # directory; mIoU is first assigned at i_iter == 0 and then reused
        # (stale) by the tensorboard block until the next multiple of 1000.
        if i_iter % 1000 == 0:
            val_dir = '../dataset/Cityscapes/leftImg8bit_trainvaltest/'
            val_list = './dataset/cityscapes_list/val.txt'
            save_dir = './results/tmp'
            gt_dir = '../dataset/Cityscapes/gtFine_trainvaltest/gtFine/val'
            evaluate_cityscapes.test_model(model, device, val_dir, val_list,
                                           save_dir)
            mIoU = compute_iou.mIoUforTest(gt_dir, save_dir)

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                #'loss_seg_local': loss_seg_local_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_local_match': loss_local_match_value,
                'loss_D': loss_D_value,
                'mIoU': mIoU
                #'loss_D_local': loss_D_local_value
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f}, loss_D = {4:.3f}, loss_local_match = {5:.3f}, mIoU = {6:3f} '
        #'loss_seg_local = {5:.3f} loss_adv_local = {6:.3f}, loss_D_local = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value, loss_adv_target_value, loss_D_value, loss_local_match_value, mIoU)
                    #loss_seg_local_value, loss_adv_local_value, loss_D_local_value)
        )

        # Final snapshot of generator and both discriminators.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training."""
    or_nyu_dict = {
        0: 255,
        1: 16,
        2: 40,
        3: 39,
        4: 7,
        5: 14,
        6: 39,
        7: 12,
        8: 38,
        9: 40,
        10: 10,
        11: 6,
        12: 40,
        13: 39,
        14: 39,
        15: 40,
        16: 18,
        17: 40,
        18: 4,
        19: 40,
        20: 40,
        21: 5,
        22: 40,
        23: 40,
        24: 30,
        25: 36,
        26: 38,
        27: 40,
        28: 3,
        29: 40,
        30: 40,
        31: 9,
        32: 38,
        33: 40,
        34: 40,
        35: 40,
        36: 34,
        37: 37,
        38: 40,
        39: 40,
        40: 39,
        41: 8,
        42: 3,
        43: 1,
        44: 2,
        45: 22
    }
    or_nyu_map = lambda x: or_nyu_dict.get(x, x) - 1
    or_nyu_map = np.vectorize(or_nyu_map)

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    args.or_nyu_map = or_nyu_map
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from == "":
            saved_state_dict = None
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 40 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True
    if args.mode != "baseline" and args.mode != "baseline_tar":
        # init D
        model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
        model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

        model_D1.train()
        model_D1.to(device)

        model_D2.train()
        model_D2.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    scale_min = 0.5
    scale_max = 2.0
    rotate_min = -10
    rotate_max = 10
    ignore_label = 255
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    args.width = w
    args.height = h
    train_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.RandScale([scale_min, scale_max]),
        transforms.RandRotate([rotate_min, rotate_max],
                              padding=IMG_MEAN_RGB,
                              ignore_label=ignore_label),
        transforms.RandomGaussianBlur(),
        transforms.RandomHorizontalFlip(),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='rand',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        #et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
        #et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        #et.ExtRandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])

    val_transform = transforms.Compose([
        # et.ExtResize( 512 ),
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN_RGB,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    if args.mode != "baseline_tar":
        src_train_dst = OpenRoomsSegmentation(root=args.data_dir,
                                              opt=args,
                                              split='train',
                                              transform=train_transform,
                                              imWidth=args.width,
                                              imHeight=args.height,
                                              remap_labels=args.or_nyu_map)
    else:
        src_train_dst = NYU_Labelled(root=args.data_dir_target,
                                     opt=args,
                                     split='train',
                                     transform=train_transform,
                                     imWidth=args.width,
                                     imHeight=args.height,
                                     phase="TRAIN",
                                     randomize=True)
    tar_train_dst = NYU(root=args.data_dir_target,
                        opt=args,
                        split='train',
                        transform=train_transform,
                        imWidth=args.width,
                        imHeight=args.height,
                        phase="TRAIN",
                        randomize=True,
                        mode=args.mode)
    tar_val_dst = NYU(root=args.data_dir,
                      opt=args,
                      split='val',
                      transform=val_transform,
                      imWidth=args.width,
                      imHeight=args.height,
                      phase="TRAIN",
                      randomize=False)
    trainloader = data.DataLoader(src_train_dst,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(tar_train_dst,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    if args.mode != "baseline" and args.mode != "baseline_tar":
        optimizer_D1 = optim.Adam(model_D1.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D1.zero_grad()

        optimizer_D2 = optim.Adam(model_D2.parameters(),
                                  lr=args.learning_rate_D,
                                  betas=(0.9, 0.99))
        optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1] + 1, input_size[0] + 1),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1] + 1,
                                      input_size_target[0] + 1),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_seg_value1_tar = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_seg_value2_tar = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.zero_grad()
            optimizer_D2.zero_grad()
            adjust_learning_rate_D(optimizer_D1, i_iter)
            adjust_learning_rate_D(optimizer_D2, i_iter)
        sample_src = None
        sample_tar = None
        sample_res_src = None
        sample_res_tar = None
        sample_gt_src = None
        sample_gt_tar = None
        for sub_i in range(args.iter_size):

            # train G
            if args.mode != "baseline" and args.mode != "baseline_tar":
                # don't accumulate grads in D
                for param in model_D1.parameters():
                    param.requires_grad = False

                for param in model_D2.parameters():
                    param.requires_grad = False

            # train with source
            try:
                _, batch = trainloader_iter.__next__()
            except:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()
            images, labels, _ = batch
            sample_src = images.clone()
            sample_gt_src = labels.clone()

            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            sample_pred_src = pred2.detach().cpu()

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target
            try:
                _, batch = targetloader_iter.__next__()
            except:
                targetloader_iter = enumerate(targetloader)
                _, batch = targetloader_iter.__next__()
            images, tar_labels, _, labelled = batch
            n_labelled = labelled.sum().detach().item()
            batch_size = images.shape[0]
            sample_tar = images.clone()
            sample_gt_tar = tar_labels.clone()
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            #print("N_labelled {}".format(n_labelled))
            if args.mode == "sda" and n_labelled != 0:
                labelled = labelled.to(device) == 1
                tar_labels = tar_labels.to(device)
                loss_seg1_tar = seg_loss(pred_target1[labelled],
                                         tar_labels[labelled])
                loss_seg2_tar = seg_loss(pred_target2[labelled],
                                         tar_labels[labelled])
                loss_tar_labelled = loss_seg2_tar + args.lambda_seg * loss_seg1_tar
                loss_tar_labelled = loss_tar_labelled / args.iter_size
                loss_seg_value1_tar += loss_seg1_tar.item() / args.iter_size
                loss_seg_value2_tar += loss_seg2_tar.item() / args.iter_size
            else:
                loss_tar_labelled = torch.zeros(
                    1, requires_grad=True).float().to(device)
            # proper normalization
            sample_pred_tar = pred_target2.detach().cpu()
            if args.mode != "baseline" and args.mode != "baseline_tar":
                D_out1 = model_D1(F.softmax(pred_target1))
                D_out2 = model_D2(F.softmax(pred_target2))

                loss_adv_target1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_adv_target2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size + loss_tar_labelled
                #loss = loss_tar_labelled
                loss.backward()
                loss_adv_target_value1 += loss_adv_target1.item(
                ) / args.iter_size
                loss_adv_target_value2 += loss_adv_target2.item(
                ) / args.iter_size
                # train D

                # bring back requires_grad
                for param in model_D1.parameters():
                    param.requires_grad = True

                for param in model_D2.parameters():
                    param.requires_grad = True

                # train with source
                pred1 = pred1.detach()
                pred2 = pred2.detach()

                D_out1 = model_D1(F.softmax(pred1))
                D_out2 = model_D2(F.softmax(pred2))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

                # train with target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()

                D_out1 = model_D1(F.softmax(pred_target1))
                D_out2 = model_D2(F.softmax(pred_target2))

                loss_D1 = bce_loss(
                    D_out1,
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label).to(device))

                loss_D2 = bce_loss(
                    D_out2,
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label).to(device))

                loss_D1 = loss_D1 / args.iter_size / 2
                loss_D2 = loss_D2 / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        if args.mode != "baseline" and args.mode != "baseline_tar":
            optimizer_D1.step()
            optimizer_D2.step()
        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
                'loss_seg1_tar': loss_seg_value1_tar,
                'loss_seg2_tar': loss_seg_value2_tar,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)
            if i_iter % 1000 == 0:
                img = sample_src.cpu()[:, [2, 1, 0], :, :] + torch.from_numpy(
                    np.array(IMG_MEAN_RGB).reshape(1, 3, 1, 1)).float()
                img = img.type(torch.uint8)
                writer.add_images("Src/Images", img, i_iter)
                label = tar_train_dst.decode_target(sample_gt_src).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Labels", label, i_iter)
                preds = sample_pred_src.permute(0, 2, 3, 1).cpu().numpy()
                preds = np.asarray(np.argmax(preds, axis=3), dtype=np.uint8)
                preds = tar_train_dst.decode_target(preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Src/Preds", preds, i_iter)

                tar_img = sample_tar.cpu()[:,
                                           [2, 1, 0], :, :] + torch.from_numpy(
                                               np.array(IMG_MEAN_RGB).reshape(
                                                   1, 3, 1, 1)).float()
                tar_img = tar_img.type(torch.uint8)
                writer.add_images("Tar/Images", tar_img, i_iter)
                tar_label = tar_train_dst.decode_target(
                    sample_gt_tar).transpose(0, 3, 1, 2)
                writer.add_images("Tar/Labels", tar_label, i_iter)
                tar_preds = sample_pred_tar.permute(0, 2, 3, 1).cpu().numpy()
                tar_preds = np.asarray(np.argmax(tar_preds, axis=3),
                                       dtype=np.uint8)
                tar_preds = tar_train_dst.decode_target(tar_preds).transpose(
                    0, 3, 1, 2)
                writer.add_images("Tar/Preds", tar_preds, i_iter)
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_seg1_tar={8:.3f} loss_seg2_tar={9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_seg_value1_tar,
                    loss_seg_value2_tar))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'OR_' + str(args.num_steps_stop) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'OR_' + str(i_iter) + '.pth'))
            if args.mode != "baseline" and args.mode != "baseline_tar":
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D1.pth'))
                torch.save(
                    model_D2.state_dict(),
                    osp.join(args.snapshot_dir,
                             'OR_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
示例#15
0
def main():
    """Create the model and start the training.

    Builds the DeepLab multi-branch segmentation network plus two
    fully-convolutional discriminators and runs adversarial domain
    adaptation (source = GTA5, target = Cityscapes), periodically saving
    snapshots, the argument dict (for restart), and TensorBoard scalars.

    Depends on module-level names defined elsewhere in this file:
    ``args``, ``RESTART``, ``RESTART_FROM``, ``RESTART_ITER``,
    ``IMG_MEAN``, ``generate_snapshot_name``, ``adjust_learning_rate``
    and ``adjust_learning_rate_D``.
    """
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        # Resume under the exact hyper-parameters that were serialized
        # alongside the snapshot taken at RESTART_ITER.
        args_dict_file = args.snapshot_dir + 'args_dict_{}.json'.format(
            RESTART_ITER)
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes arrive as "W,H" command-line strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    # NOTE(review): `model` is only bound for the 'DeepLab' choice; any
    # other args.model value raises NameError below — confirm 'DeepLab'
    # is the only supported backbone for this script.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #### restore model_D1, D2 and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D1 parameters
        restart_from_D1 = args.restart_from + 'GTA5_{}_D1.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D1)
        model_D1.load_state_dict(saved_state_dict)

        # model_D2 parameters
        restart_from_D2 = args.restart_from + 'GTA5_{}_D2.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D2)
        model_D2.load_state_dict(saved_state_dict)

    #### model_D1, D2 are randomly initialized, model is pre-trained ResNet on ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy every pre-trained tensor except the final classifier
        # ('layer5') when the class count matches the checkpoint's 19;
        # otherwise the classifier is re-initialized by the model itself.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # checkpoint key format, e.g.: Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                # Drop the leading scope ('Scale') to match this model's keys.
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # max_iters pre-sizes both datasets so the iterators never exhaust
    # during args.num_steps * args.iter_size optimization sub-steps.
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # 'Vanilla' GAN uses logistic loss on the discriminator logits;
    # 'LS' uses least-squares (LSGAN).
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample logits back to the crop resolution before computing losses.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.start_steps, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradient accumulation: args.iter_size sub-batches per optimizer step.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            # The auxiliary branch (pred1) is down-weighted by lambda_seg.
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)
            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            # softmax over the class/channel dimension of NCHW logits.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Fool the discriminators: label target predictions as "source".
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source; detach so D gradients don't reach G.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            # Halve each D loss so source + target together count once.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'loss_adv_target1': loss_adv_target_value1,
            'loss_adv_target2': loss_adv_target_value2,
            'loss_D1': loss_D_value1,
            'loss_D2': loss_D_value2,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D1.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = args.snapshot_dir + '/args_dict_{}.json'.format(
                i_iter)
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

            ###### also record latest saved iteration #######

    writer.close()
示例#16
0
文件: train.py 项目: ccd-ss/ccd
def main():
    """Create the model and start the training.

    Feature-level adversarial adaptation variant: the discriminator
    consumes backbone features (2048 channels for ResNet, 1024 for VGG)
    rather than softmaxed segmentation logits, and additionally produces
    a classification map supervised on source labels (loss_cla).

    Depends on module-level names defined elsewhere in this file:
    ``args``, ``IMG_MEAN``, ``adjust_learning_rate`` and
    ``adjust_learning_rate_D``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Input sizes arrive as "W,H" command-line strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy every pre-trained tensor except the final classifier
        # ('layer5') when the class count matches the checkpoint's 19.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # checkpoint key format, e.g.: Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                # Drop the leading scope ('Scale') to match this model's keys.
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    if args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)

    # NOTE(review): if args.model is neither 'ResNet' nor 'VGG', `model`
    # (and `model_D` below) are never bound and this raises NameError —
    # confirm those are the only supported values.
    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    # The discriminator's input channel count matches the backbone's
    # feature width, since it discriminates features, not logits.
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=2048).to(device)
    if args.model == 'VGG':
        model_D = FCDiscriminator(num_classes=1024).to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # max_iters pre-sizes both datasets so the iterators never exhaust
    # over args.num_steps iterations.
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample predictions back to the crop resolution before the losses.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cla_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source

        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        # model returns (backbone feature map, segmentation prediction).
        feature, prediction = model(images)
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        loss.backward()
        loss_seg = loss.item()

        # train with target

        _, batch = targetloader_iter.__next__()
        images, _, _ = batch
        images = images.to(device)

        # Adversarial step: push target features toward the "source"
        # label to fool the (frozen) discriminator.
        # model_D returns (classification map, domain logit); only the
        # domain logit is used here.
        feature_target, _ = model(images)
        _, D_out = model_D(feature_target)
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        #print(args.lambda_adv_target)
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward()
        loss_adv_target_value = loss_adv_target.item()

        # train D

        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source; detach so D gradients don't reach G.
        feature = feature.detach()
        # The discriminator's auxiliary classifier head is supervised
        # with the source segmentation labels (loss_cla).
        cla, D_out = model_D(feature)
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)

        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        # Halve so the source and target D losses together count once.
        loss_D = loss_D / 2
        #print(args.lambda_s)
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()

        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(target_label).to(device))
        loss_D = loss_D / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        #print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

    if args.tensorboard:
        writer.close()
示例#17
0
def main():
    """Create the model and start adversarial domain-adaptation training.

    Trains a two-branch DeepLab segmentation network on labelled source
    (GTA5) images while two fully-convolutional discriminators — one per
    DeepLab output branch — push the unlabelled target (Cityscapes)
    predictions to be indistinguishable from source predictions
    (AdaptSegNet multi-level scheme).

    All configuration comes from the module-level ``args`` namespace.
    Model/discriminator snapshots are written to ``args.snapshot_dir``;
    scalars optionally go to a TensorBoard log in ``args.log_dir``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # "w,h" command-line strings -> (width, height) tuples.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create the segmentation network and load pre-trained weights.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # drop the leading scope prefix; skip the classifier head
            # ("layer5") unless the class count differs from the
            # pre-trained 19-class head (in which case it is re-used as-is).
            i_parts = i.split('.')
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # Init one discriminator per DeepLab output branch.
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D2.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source-domain loader (labelled GTA5).
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Target-domain loader (unlabelled Cityscapes).
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    # 255 is the conventional "ignore" label for unannotated pixels.
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsample network outputs back to full image resolution.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D while updating the generator
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source: plain supervised segmentation loss

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            # Main branch (pred2) plus down-weighted auxiliary branch (pred1).
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization across accumulated sub-iterations
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target: fool the discriminators into predicting
            # "source" for target-domain outputs

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 = class channel; implicit-dim softmax is deprecated.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source; detach so no gradients flow into G
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            # halve so source and target halves sum to one D step
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        # Final save at the stopping iteration, then stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # Periodic snapshot of generator and both discriminators.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()

        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if args.num_classes != 19 and i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        # if args.num_classes !=19:
        # print i_parts
        model.load_state_dict(new_params)
    start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
    writer_dir = os.path.join("./logs/", args.name, start_time)
    writer = tensorboard.SummaryWriter(writer_dir)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        MulitviewSegLoader(
            num_classes=args.num_classes,
            root=args.data_dir,
            number_views=2,
            view_idx=1,
            # max_iters=args.num_steps * args.iter_size * args.batch_size,
            # crop_size=input_size,
            # scale=args.random_scale, mirror=args.random_mirror,
            img_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    trainloader_iter = iter(data_loader_cycle(trainloader))

    targetloader = data.DataLoader(
        MulitviewSegLoader(
            root=args.data_dir_target,
            num_classes=args.num_classes,
            number_views=1,
            view_idx=0,
            # max_iters=args.num_steps * args.iter_size * args.batch_size,
            # crop_size=input_size_target,
            # scale=False,
            # mirror=args.random_mirror,
            img_mean=IMG_MEAN,
            # set=args.set
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    def mdl_val_func(x):
        return interp_target(model(x)[1])

    targetloader_iter = iter(data_loader_cycle(targetloader))
    val_loader = data.DataLoader(
        MulitviewSegLoader(
            root=args.data_dir_val,
            number_views=1,
            view_idx=0,
            num_classes=args.num_classes,
            # max_iters=args.num_steps * args.iter_size * args.batch_size,
            # crop_size=input_size_target,
            # scale=False,
            # mirror=args.random_mirror,
            img_mean=IMG_MEAN,
            # set=args.set
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)

    criterion = CrossEntropyLoss2d().cuda(args.gpu)
    valhelper = ValHelper(gpu=args.gpu,
                          model=mdl_val_func,
                          val_loader=val_loader,
                          loss=criterion,
                          writer=writer)
    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            batch = next(trainloader_iter)
            images, labels, *_ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            batch = next(targetloader_iter)
            images, *_ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1))
            D_out2 = model_D2(F.softmax(pred2))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1))
            D_out2 = model_D2(F.softmax(pred_target2))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()
        if i_iter % args.val_steps == 0 and i_iter:
            model.eval()
            log = valhelper.valid_epoch(i_iter)
            print('log: {}'.format(log))
            model.train()

        if i_iter % 10 == 0 and i_iter:
            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))
            writer.add_scalar(f'train/loss_seg_value1', loss_seg_value1,
                              i_iter)
            writer.add_scalar(f'train/loss_seg_value2', loss_seg_value2,
                              i_iter)
            writer.add_scalar(f'train/loss_adv_target_value1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar(f'train/loss_D_value1', loss_D_value1, i_iter)
            writer.add_scalar(f'train/loss_D_value2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))