Example #1
def main():
    """Create the model and start the training."""

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy pretrained weights into the model, skipping the classifier
        # ('layer5') when num_classes differs from the pretrained 19 classes
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # example key: Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if args.num_classes != 19 or i_parts[1] != 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
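    # max_iters is sized so the loader can serve every update
    # (num_steps * iter_size batches) without the iterator running out.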

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)


    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
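    # BCEWithLogitsLoss applies the sigmoid itself, so the discriminators
    # are expected to return raw logits.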


    # labels for adversarial training
    source_label = 0
    target_label = 1
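    # G tries to make target outputs be classified as source (0);
    # D1/D2 learn to separate source from target outputs.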

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
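            # Gradient accumulation: losses below are divided by iter_size,
            # and the optimizers step once after this inner loop.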

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            # Upsample both heads to the (h, w) crop size before the loss
            pred1 = nn.functional.interpolate(pred1, size=input_size, mode='bilinear', align_corners=True)
            pred2 = nn.functional.interpolate(pred2, size=input_size, mode='bilinear', align_corners=True)


            # pred2 is the primary classifier output; pred1 is the auxiliary
            # head, weighted by lambda_seg
            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            # Upsample to the target crop size before the adversarial loss
            pred_target1 = nn.functional.interpolate(pred_target1, size=input_size_target, mode='bilinear', align_corners=True)
            pred_target2 = nn.functional.interpolate(pred_target2, size=input_size_target, mode='bilinear', align_corners=True)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Adversarial loss for G: make target predictions look like
            # source (label 0) to the discriminators
            loss_adv_target1 = bce_loss(D_out1,
                                        Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(
                                            args.gpu))

            loss_adv_target2 = bce_loss(D_out2,
                                        Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(
                                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            # Detach so the D update does not backprop into the generator
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1,
                              Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

            # Halved because each D sees a source and a target batch per step
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1,
                              Variable(torch.FloatTensor(D_out1.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} '
              'loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                  i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1,
                  loss_adv_target_value2, loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps) + '_D2.pth'))
            show_val(model.state_dict(), i_iter)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
            show_val(model.state_dict(), i_iter)
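Both examples call adjust_learning_rate / adjust_learning_rate_D without showing them. Below is a minimal sketch of the polynomial ("poly") decay schedule typically paired with this training loop; args.power and the 10x rate for a second parameter group (the classifier) are assumptions from the usual DeepLab recipe, not taken from the snippets themselves.

def lr_poly(base_lr, i_iter, max_iter, power):
    # Poly schedule: base_lr * (1 - iter / max_iter) ** power
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)


def adjust_learning_rate(optimizer, i_iter):
    # Hypothetical helper: first param group is the backbone; a second
    # group, if present, is the classifier trained at 10x the base rate
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        optimizer.param_groups[1]['lr'] = lr * 10


def adjust_learning_rate_D(optimizer, i_iter):
    # Same schedule, driven by the discriminator base rate
    optimizer.param_groups[0]['lr'] = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power)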
Example #2
def main():
    """Create the model and start the training."""
    model_num = 0
    
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    random.seed(args.random_seed)
    
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        if args.training_option == 1:
            model = Res_Deeplab(num_classes=args.num_classes)
        elif args.training_option == 2:
            model = Res_Deeplab2(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
            
        new_params = model.state_dict().copy()
        
        # Debug: print checkpoint keys and model keys to verify the mapping
        for k in saved_state_dict:
            print(k)

        for k in new_params:
            print(k)
        
        for i in saved_state_dict:
            i_parts = i.split('.')
            
            if '.'.join(i_parts[args.i_parts_index:]) in new_params:
                print("Restored...")
                if args.not_restore_last:
                    # Skip the classifier heads ('layer5'/'layer6')
                    if i_parts[args.i_parts_index] != 'layer5' and i_parts[args.i_parts_index] != 'layer6':
                        new_params['.'.join(i_parts[args.i_parts_index:])] = saved_state_dict[i]
                else:
                    new_params['.'.join(i_parts[args.i_parts_index:])] = saved_state_dict[i]
                
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True
    
    writer = SummaryWriter(log_dir=args.snapshot_dir)
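    # Scalar losses are written to TensorBoard at the end of every iteration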

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes, extra_layers=args.extra_discriminator_layers)
    model_D2 = FCDiscriminator(num_classes=args.num_classes, extra_layers=args.extra_discriminator_layers)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    
    trainloader = data.DataLoader(sourceDataSet(args.data_dir, 
                                                    args.data_list, 
                                                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                    crop_size=input_size,
                                                    random_rotate=False, 
                                                    random_flip=args.augment_1,
                                                    random_lighting=args.augment_1,
                                                    random_blur=args.augment_1,
                                                    random_scaling=args.augment_1,
                                                    mean=IMG_MEAN_SOURCE,
                                                    ignore_label=args.ignore_label),
                                  batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(isprsDataSet(args.data_dir_target, 
                                                args.data_list_target,
                                                max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                crop_size=input_size_target,
                                                random_rotate=False, 
                                                random_flip=args.augment_target,
                                                random_lighting=args.augment_target,
                                                random_blur=args.augment_target,
                                                random_scaling=args.augment_target, 
                                                mean=IMG_MEAN_TARGET,
                                                ignore_label=args.ignore_label),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    
    valloader = data.DataLoader(valDataSet(args.data_dir_val, args.data_list_val, crop_size=input_size_target, mean=IMG_MEAN_TARGET, scale=1, mirror=False),
                                batch_size=1, 
                                shuffle=False, 
                                pin_memory=True)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[0], input_size_target[1]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    
    # Which layers to freeze
    non_trainable(args.dont_train, model)
    
    # List saving all best 5 mIoU's 
    best_mIoUs = [0.0, 0.0, 0.0, 0.0, 0.0]

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            
            while True:
                try:

                    _, batch = next(trainloader_iter)
                    images, labels, _, train_name = batch
                    #print(train_name)
                    images = Variable(images).cuda(args.gpu)

                    pred1, pred2 = model(images)
                    pred1 = interp(pred1)
                    pred2 = interp(pred2)

                    loss_seg1 = loss_calc(pred1, labels, args.gpu, args.ignore_label, train_name, weights)
                    loss_seg2 = loss_calc(pred2, labels, args.gpu, args.ignore_label, train_name, weights)
                    
                    
                    loss = loss_seg2 + args.lambda_seg * loss_seg1

                    # proper normalization
                    loss = loss / args.iter_size
                    loss.backward()

                    # .numpy() on a scalar loss returns a 0-d array (never a
                    # list), so the values can be accumulated directly
                    loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
                    loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size
                    break
                except (RuntimeError, AssertionError, AttributeError):
                    # Skip batches that fail to load or process; try the next one
                    continue
             
            if args.experiment == 1:
                # Which layers to freeze
                non_trainable('0', model)
            
            # train with target
            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
                        
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(D_out1,
                                       Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(
                                           args.gpu))

            loss_adv_target2 = bce_loss(D_out2,
                                        Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(
                                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size
            
            if args.experiment == 1:
                # Which layers to freeze
                non_trainable(args.dont_train, model)


            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            # Detach so the D update does not backprop into the generator
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1,
                              Variable(torch.FloatTensor(D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()
            
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()
            

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1,
                              Variable(torch.FloatTensor(D_out1.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D2 = bce_loss(D_out2,
                               Variable(torch.FloatTensor(D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()



        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            # Rotating snapshots: reuse slots 0..num_models_keep-1
            if model_num != args.num_models_keep:
                torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'model_' + str(model_num) + '.pth'))
                torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'model_' + str(model_num) + '_D1.pth'))
                torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'model_' + str(model_num) + '_D2.pth'))
                model_num += 1
            if model_num == args.num_models_keep:
                model_num = 0

        # Validation: keep the five best checkpoints ranked by mIoU
        if (i_iter % args.val_every == 0 and i_iter != 0) or i_iter == 1:
            mIoU = validation(valloader, model, interp_target, writer, i_iter, [37, 41, 10])
            for i in range(0, len(best_mIoUs)):
                if best_mIoUs[i] < mIoU:
                    torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'bestmodel_' + str(i) + '.pth'))
                    torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'bestmodel_' + str(i) + '_D1.pth'))
                    torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'bestmodel_' + str(i) + '_D2.pth'))
                    best_mIoUs.append(mIoU)
                    print("Saved model at iteration %d as the best %d" % (i_iter, i))
                    best_mIoUs.sort(reverse=True)
                    best_mIoUs = best_mIoUs[:5]
                    break

        # Save for tensorboardx
        writer.add_scalar('loss_seg_value1', loss_seg_value1, i_iter)
        writer.add_scalar('loss_seg_value2', loss_seg_value2, i_iter)
        writer.add_scalar('loss_adv_target_value1', loss_adv_target_value1, i_iter)
        writer.add_scalar('loss_adv_target_value2', loss_adv_target_value2, i_iter)
        writer.add_scalar('loss_D_value1', loss_D_value1, i_iter)
        writer.add_scalar('loss_D_value2', loss_D_value2, i_iter)
        
        
    writer.close()
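Neither example defines FCDiscriminator. Below is a minimal sketch, assuming the fully-convolutional design commonly used for this kind of output-space adaptation: a stack of stride-2 4x4 convolutions with LeakyReLU over the softmax class maps, ending in a one-channel logit map (raw logits, to pair with BCEWithLogitsLoss). The ndf width and the 0.2 slope are assumptions, and the extra_layers option used in Example #2 is omitted.

import torch.nn as nn


class FCDiscriminator(nn.Module):
    """Fully-convolutional discriminator over per-class probability maps."""

    def __init__(self, num_classes, ndf=64):
        super(FCDiscriminator, self).__init__()
        # Each convolution halves the spatial resolution
        self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
        self.classifier = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        x = self.leaky_relu(self.conv4(x))
        return self.classifier(x)  # raw logits; no sigmoid here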