Пример #1
0
def main():
    """Pre-compute per-class feature centers and save them to disk.

    The model selected by ``args.model`` is restored from
    ``args.restore_from`` and run in eval mode over the GTA5 data loader with
    batch size 1. For every batch, ``class_center_precal`` is called on the
    first model output and the labels; its two return values are accumulated
    into a ``(19, 256)`` sum array and a ``(19, 1)`` count array. The final
    centers (sum / count) are written to ``./source_center.npy``.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    args = get_arguments()
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args.cuda_device_id)

    # The two architectures are mutually exclusive; fail loudly on anything
    # else instead of hitting an unbound ``model`` further down.
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes)
    else:
        raise ValueError(
            "unsupported model %r (expected 'ResNet' or 'VGG')" % (args.model,))

    device = torch.device("cuda" if not args.cpu else "cpu")

    # map_location lets a checkpoint that was saved on GPU be restored when
    # running with args.cpu set.
    saved_state_dict = torch.load(args.restore_from, map_location=device)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.to(device)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, crop_size=(1024, 512),
                    scale=False, mirror=False, mean=IMG_MEAN),
        batch_size=1, shuffle=False, pin_memory=True)

    # Running accumulators: 19 rows (one per class), 256 feature columns.
    count_class = np.zeros((19, 1))
    class_center_temp = np.zeros((19, 256))

    for index, batch in enumerate(trainloader):
        if index % 100 == 0:
            print('%d processd' % index)
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        with torch.no_grad():
            feature, _ = model(images)

        # NOTE(review): count_class_t is converted with .numpy() directly, so
        # class_center_precal presumably returns it on the CPU - confirm.
        class_center, count_class_t = class_center_precal(feature, labels)
        count_class = count_class + count_class_t.numpy()
        class_center_temp += class_center.cpu().data[0].numpy()

    # Classes that never appeared keep a zero sum; use 1 as the divisor so
    # the division below does not produce NaN/inf for them.
    count_class[count_class == 0] = 1

    class_center = class_center_temp / count_class
    np.save('./source_center.npy', class_center)
Пример #2
0
def get_source_train_dataloader(args):
    """Return a shuffled ``DataLoader`` over the GTA5 source dataset.

    The dataset is built from ``args.source_dir``, ``args.type`` and
    ``args.source_list`` with scaling and mirroring enabled, and is asked for
    ``num_steps * iter_size * batch_size`` items via ``max_iters``. Batch size
    and worker count are taken from ``args``; pinned memory is enabled.
    """
    total_items = args.num_steps * args.iter_size * args.batch_size

    source_dataset = GTA5DataSet(
        args.source_dir,
        args.type,
        args.source_list,
        max_iters=total_items,
        crop_size=args.images_size,
        scale=True,
        mirror=True,
        mean=args.mean,
    )

    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return DataLoader(source_dataset, **loader_options)
Пример #3
0
def _parse_size_pair(text):
    """Parse a ``"a,b"`` string into a tuple of exactly two ints."""
    first, second = map(int, text.split(','))
    return first, second


def main():
    """Create the model and start the training.

    Reads the module-level ``args`` namespace, overlays the ``common`` section
    of the YAML file at ``args.config`` onto it, builds the encoder/decoder
    segmentation model and then alternates, once per iteration, between:

    1. a supervised step on a GTA5 (source) batch, and
    2. a self-training step on a target batch combining a class-balance
       term, a confidence-masked pseudo-label term and, for every entry of
       ``args.box_size``, pooled ("box") versions of both.

    Every ``args.print_freq`` iterations the running averages are logged; when
    the iteration is additionally a non-zero multiple of
    ``args.save_pred_every`` the model is evaluated on the validation loader
    and encoder/decoder checkpoints are saved.
    """
    with open(args.config) as f:
        # safe_load: the config is only read as plain data, and yaml.load()
        # without an explicit Loader is rejected by recent PyYAML releases.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)
    mkdirs(osp.join("logs/" + args.exp_name))

    logger = create_logger('global_logger', "logs/" + args.exp_name + '/log.txt')
    logger.info('{}'.format(args))

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    logger.info("random_scale {}".format(args.random_scale))
    logger.info("is_training {}".format(args.is_training))

    # All sizes are given as "a,b" strings; index 1 is used as the first
    # (height) argument of Resize/Upsample below and index 0 as the second.
    input_size = _parse_size_pair(args.input_size)
    input_size_target = _parse_size_pair(args.input_size_target)
    input_size_test = _parse_size_pair(args.input_size_test)
    com_size = _parse_size_pair(args.com_size)

    cudnn.enabled = True
    cudnn.benchmark = True
    args.snapshot_dir = args.snapshot_dir + args.exp_name
    tb_logger = SummaryWriter("logs/" + args.exp_name)

    # Per-pixel class distribution loaded from disk: keep the first 19
    # channels, normalise them to sum to 1 at every pixel and reshape to
    # (1, 19, 512, 1024). It is the target of the pooled balance term below.
    # NOTE(review): presumably a spatial class prior - confirm with its producer.
    local_array = np.load("local.npy")
    local_array = local_array[:, :, :19]
    local_array = local_array / local_array.sum(2).reshape(512, 1024, 1)
    local_array = local_array.transpose(2, 0, 1)
    local_array = torch.from_numpy(local_array)
    local_array = local_array.view(1, 19, 512, 1024)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    normalize_module = transforms_seg.Normalize(mean=mean,
                                                std=std)

    test_normalize = transforms.Normalize(mean=mean,
                                          std=std)

    test_transform = transforms.Compose([
                         transforms.Resize((input_size_test[1], input_size_test[0])),
                         transforms.ToTensor(),
                         test_normalize])

    # Validation data.
    valloader = data.DataLoader(cityscapesDataSet(
                                       args.data_dir_target,
                                       args.data_list_target_val,
                                       crop_size=input_size_test,
                                       set='train',
                                       transform=test_transform), num_workers=args.num_workers,
                                 batch_size=1, shuffle=False, pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # np.int / np.str were plain aliases of the builtins and no longer exist
    # in current NumPy releases.
    mapping = np.array(info['label2train'], dtype=int)
    name_classes = np.array(info['label'], dtype=str)

    with open(args.label_path_list_val, 'r') as gt_list_file:
        gt_imgs_val = gt_list_file.read().splitlines()
    gt_imgs_val = [osp.join(args.data_dir_target_val, x) for x in gt_imgs_val]

    interp_val = nn.Upsample(size=(com_size[1], com_size[0]), mode='bilinear', align_corners=True)

    ####
    # build model
    ####
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        num_class=args.num_classes,
        weights=args.weights_decoder,
        use_aux=True)

    # Target class distribution for the global balance term, normalised to
    # sum to 1.
    weighted_softmax = pd.read_csv("weighted_loss.txt", header=None)
    weighted_softmax = weighted_softmax.values
    weighted_softmax = torch.from_numpy(weighted_softmax)
    weighted_softmax = weighted_softmax / torch.sum(weighted_softmax)
    weighted_softmax = weighted_softmax.cuda().float()

    model = SegmentationModule(
        net_encoder, net_decoder, args.use_aux)

    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
        patch_replication_callback(model)
    model.cuda()

    # Only the encoder and decoder are trained; the other two slots are None.
    nets = (net_encoder, net_decoder, None, None)
    optimizer_encoder, optimizer_decoder, _, _ = create_optimizer(nets, args)
    model.train()

    os.makedirs(args.snapshot_dir, exist_ok=True)

    # Only resize + tensor conversion + normalisation are applied; no random
    # augmentation is part of either pipeline.
    source_transform = transforms_seg.Compose([
                             transforms_seg.Resize([input_size[1], input_size[0]]),
                             transforms_seg.ToTensor(),
                             normalize_module])

    target_transform = transforms_seg.Compose([
                             transforms_seg.Resize([input_size_target[1], input_size_target[0]]),
                             transforms_seg.ToTensor(),
                             normalize_module])

    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=max_iters,
                    crop_size=input_size, transform=source_transform),
        batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(fake_cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=max_iters,
                                                     crop_size=input_size_target,
                                                     set=args.set,
                                                     transform=target_transform),
                                   batch_size=args.batch_size, shuffle=True, num_workers=5,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # Per-pixel (unreduced) cross entropy so that confidence/validity masks
    # can be applied before averaging.
    criterion_seg = torch.nn.CrossEntropyLoss(ignore_index=255, reduction='none')

    # One pooling module per configured box size, plus the pooled prior for
    # that size. Both are loop-invariant, so build them once up front.
    box_poolers = [torch.nn.AvgPool2d(box_size) for box_size in args.box_size]
    local_priors = [pooler(local_array).float().cuda() for pooler in box_poolers]

    batch_time = AverageMeter(10)
    loss_seg_value1 = AverageMeter(10)
    loss_seg_value2 = AverageMeter(10)
    loss_reconst_source_value = AverageMeter(10)
    loss_reconst_target_value = AverageMeter(10)
    loss_balance_value = AverageMeter(10)
    loss_eq_att_value = AverageMeter(10)
    loss_pseudo_value = AverageMeter(10)
    bounding_num = AverageMeter(10)
    pseudo_num = AverageMeter(10)
    loss_bbx_att_value = AverageMeter(10)
    best_test_mIoUs = 0

    for i_iter in range(args.num_steps):
        end = time.time()

        # ------------------------------------------------------------------
        # Supervised step on a source batch.
        # ------------------------------------------------------------------
        _, batch = next(trainloader_iter)
        images, labels, _ = batch
        # `async` became a reserved word in Python 3.7; `non_blocking` is the
        # current name of the same option.
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        results = model(images, labels)
        # The last two entries of the model output are used as the two
        # segmentation losses.
        loss_seg2 = torch.mean(results[-2])
        loss_seg1 = torch.mean(results[-1])
        loss = args.lambda_trade_off * (loss_seg2 + args.lambda_seg * loss_seg1)

        loss_seg_value1.update(loss_seg1.data.cpu().numpy())
        loss_seg_value2.update(loss_seg2.data.cpu().numpy())
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        # ------------------------------------------------------------------
        # Self-training step on a target batch.
        # ------------------------------------------------------------------
        _, batch = next(targetloader_iter)
        images, fake_labels, _ = batch
        images = images.cuda(non_blocking=True)
        fake_labels = fake_labels.cuda()
        results = model(images, None)
        target_seg = results[0]
        target_prob = nn.functional.softmax(target_seg, dim=1)
        conf_tea, pseudo_label = torch.max(target_prob, dim=1)
        pseudo_label = pseudo_label.detach()

        # Hard pseudo-label loss, kept only where the prediction is confident
        # and the provided label is not the ignore value (255). Entries that
        # end up exactly zero are dropped before averaging.
        loss_pseudo = criterion_seg(target_seg, pseudo_label)
        fake_mask = (fake_labels != 255).float().detach()
        conf_mask = torch.gt(conf_tea, args.conf_threshold).float().detach()
        loss_pseudo = loss_pseudo * conf_mask * fake_mask
        loss_pseudo = loss_pseudo.view(-1)
        loss_pseudo = loss_pseudo[loss_pseudo != 0]

        # Global class-balance term: mean predicted class distribution over
        # batch and spatial dims versus the reference distribution.
        predict_class_mean = target_prob.mean(dim=0).mean(1).mean(1)
        equalise_cls_loss = robust_binary_crossentropy(predict_class_mean, weighted_softmax)
        equalise_cls_loss = torch.mean(equalise_cls_loss)

        # Pooled ("box") variants of the two terms, one per box size.
        loss_bbx_att = []
        loss_eq_att = []
        for pooling, local_i in zip(box_poolers, local_priors):
            pooling_result_i = pooling(target_seg)
            pooling_conf_mask, pooling_pseudo = torch.max(
                nn.functional.softmax(pooling_result_i, dim=1), dim=1)
            pooling_conf_mask = torch.gt(pooling_conf_mask, args.conf_threshold).float().detach()
            fake_mask_i = pooling(fake_labels.unsqueeze(1).float())
            fake_mask_i = fake_mask_i.squeeze(1)
            fake_mask_i = (fake_mask_i != 255).float().detach()
            loss_bbx_att_i = criterion_seg(pooling_result_i, pooling_pseudo)
            loss_bbx_att_i = loss_bbx_att_i * pooling_conf_mask * fake_mask_i
            loss_bbx_att_i = loss_bbx_att_i.view(-1)
            loss_bbx_att_i = loss_bbx_att_i[loss_bbx_att_i != 0]
            loss_bbx_att.append(loss_bbx_att_i)

            # Batch-averaged pooled logits compared against the pooled prior.
            pooling_result_i = pooling_result_i.mean(0).unsqueeze(0)
            equalise_cls_loss_i = robust_binary_crossentropy(
                nn.functional.softmax(pooling_result_i, dim=1), local_i)
            equalise_cls_loss_i = equalise_cls_loss_i.mean(1)
            equalise_cls_loss_i = equalise_cls_loss_i * pooling_conf_mask * fake_mask_i
            equalise_cls_loss_i = equalise_cls_loss_i.view(-1)
            equalise_cls_loss_i = equalise_cls_loss_i[equalise_cls_loss_i != 0]
            loss_eq_att.append(equalise_cls_loss_i)

        # NOTE(review): 560*480 is a hard-coded normaliser for the two
        # "fraction of pixels kept" meters - presumably the training
        # resolution; confirm against the configured input size.
        if box_poolers:
            if args.merge_1x1:
                loss_bbx_att.append(loss_pseudo)
            loss_bbx_att = torch.cat(loss_bbx_att, dim=0)
            bounding_num.update(loss_bbx_att.size(0) / float(560 * 480 * args.batch_size))
            loss_bbx_att = torch.mean(loss_bbx_att)

            loss_eq_att = torch.cat(loss_eq_att, dim=0)
            loss_eq_att = torch.mean(loss_eq_att)
            loss_eq_att_value.update(loss_eq_att.item())
        else:
            loss_bbx_att = torch.mean(loss_pseudo)
            loss_eq_att = 0
        loss_bbx_att_value.update(loss_bbx_att.item())

        pseudo_num.update(loss_pseudo.size(0) / float(560 * 480 * args.batch_size))
        loss_pseudo = torch.mean(loss_pseudo)

        # Start the target loss from the balance term and only then add the
        # other terms. Previously the pseudo-label term was added to the stale
        # source loss and then overwritten, so it never reached backward().
        loss = args.lambda_balance * equalise_cls_loss
        if not args.merge_1x1:
            loss += args.lambda_pseudo * loss_pseudo
        loss += args.lambda_pseudo * loss_bbx_att
        loss += args.lambda_eq * loss_eq_att
        loss_pseudo_value.update(loss_pseudo.item())
        loss_balance_value.update(equalise_cls_loss.item())

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        batch_time.update(time.time() - end)

        remain_iter = args.num_steps - i_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

        if i_iter == args.decrease_lr:
            adjust_learning_rate(optimizer_encoder, i_iter, args.lr_encoder, args)
            adjust_learning_rate(optimizer_decoder, i_iter, args.lr_decoder, args)
        if i_iter % args.print_freq == 0:
            lr_encoder = optimizer_encoder.param_groups[0]['lr']
            lr_decoder = optimizer_decoder.param_groups[0]['lr']
            logger.info('exp = {}'.format(args.snapshot_dir))
            logger.info('Iter = [{0}/{1}]\t'
                        'Time = {batch_time.avg:.3f}\t'
                        'loss_seg1 = {loss_seg1.avg:4f}\t'
                        'loss_seg2 = {loss_seg2.avg:.4f}\t'
                        'loss_reconst_source = {loss_reconst_source.avg:.4f}\t'
                        'loss_bbx_att = {loss_bbx_att.avg:.4f}\t'
                        'loss_reconst_target = {loss_reconst_target.avg:.4f}\t'
                        'loss_pseudo = {loss_pseudo.avg:.4f}\t'
                        'loss_eq_att = {loss_eq_att.avg:.4f}\t'
                        'loss_balance = {loss_balance.avg:.4f}\t'
                        'bounding_num = {bounding_num.avg:.4f}\t'
                        'pseudo_num = {pseudo_num.avg:4f}\t'
                        'lr_encoder = {lr_encoder:.8f} lr_decoder = {lr_decoder:.8f}'.format(
                         i_iter, args.num_steps, batch_time=batch_time,
                         loss_seg1=loss_seg_value1, loss_seg2=loss_seg_value2,
                         loss_pseudo=loss_pseudo_value,
                         loss_bbx_att=loss_bbx_att_value,
                         bounding_num=bounding_num,
                         loss_eq_att=loss_eq_att_value,
                         pseudo_num=pseudo_num,
                         loss_reconst_source=loss_reconst_source_value,
                         loss_balance=loss_balance_value,
                         loss_reconst_target=loss_reconst_target_value,
                         lr_encoder=lr_encoder,
                         lr_decoder=lr_decoder))

            logger.info("remain_time: {}".format(remain_time))
            if tb_logger is not None:
                tb_logger.add_scalar('loss_seg_value1', loss_seg_value1.avg, i_iter)
                tb_logger.add_scalar('loss_seg_value2', loss_seg_value2.avg, i_iter)
                tb_logger.add_scalar('bounding_num', bounding_num.avg, i_iter)
                tb_logger.add_scalar('pseudo_num', pseudo_num.avg, i_iter)
                tb_logger.add_scalar('loss_pseudo', loss_pseudo_value.avg, i_iter)
                tb_logger.add_scalar('lr', lr_encoder, i_iter)
                tb_logger.add_scalar('loss_balance', loss_balance_value.avg, i_iter)

            # Validation + checkpointing. This is nested in the logging
            # branch, so it only runs on iterations that are multiples of
            # both print_freq and save_pred_every.
            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                logger.info('taking snapshot ...')
                model.eval()
                hist = np.zeros((19, 19))
                for index, batch in tqdm(enumerate(valloader)):
                    with torch.no_grad():
                        image, _ = batch
                        results = model(image.cuda(), None)
                        output2 = results[0]
                        pred = interp_val(output2)
                        del output2
                        pred = pred.cpu().data[0].numpy()
                        pred = pred.transpose(1, 2, 0)
                        pred = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8)
                        label = np.array(Image.open(gt_imgs_val[index]))
                        label = label_mapping(label, mapping)
                        hist += fast_hist(label.flatten(), pred.flatten(), 19)
                mIoUs = per_class_iu(hist)
                for ind_class in range(args.num_classes):
                    logger.info('===>' + name_classes[ind_class] + ':\t' + str(round(mIoUs[ind_class] * 100, 2)))
                    tb_logger.add_scalar(name_classes[ind_class] + '_mIoU', mIoUs[ind_class], i_iter)

                mIoUs = round(np.nanmean(mIoUs) * 100, 2)
                logger.info(mIoUs)
                tb_logger.add_scalar('test mIoU', mIoUs, i_iter)
                is_best_test = mIoUs > best_test_mIoUs
                if is_best_test:
                    best_test_mIoUs = mIoUs
                logger.info("best test mIoU {}".format(best_test_mIoUs))
                save_checkpoint(net_encoder, 'encoder', i_iter, args, is_best_test)
                save_checkpoint(net_decoder, 'decoder', i_iter, args, is_best_test)
            # Back to training mode after a possible evaluation pass.
            model.train()
Пример #4
0
def _labels_to_one_hot(labels, num_classes):
    """One-hot encode an index label tensor along dim 1.

    ``labels`` is read as a 4-D tensor: sizes 0, 2 and 3 are reused for the
    output and dim 1 is the scatter axis. Returns a float32 CPU tensor with
    ``num_classes`` channels in dim 1.
    """
    size = labels.size()
    one_hot = torch.zeros(size[0], num_classes, size[2], size[3],
                          dtype=torch.float32)
    return one_hot.scatter_(1, labels.long(), 1.0)


def main():
    """Create the model and start the training.

    Trains ``FCDiscriminator`` to tell apart one-hot label maps from the GTA5
    loader (domain label 0) and from the Cityscapes loader (domain label 1)
    using a BCE-with-logits loss. Reads the module-level ``args`` namespace.
    Every ``args.save_pred_every`` iterations a snapshot is saved and the
    discriminator is evaluated on both validation loaders; training stops
    after ``args.num_steps_stop`` iterations with a final snapshot.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(0)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    gta_trainloader = data.DataLoader(GTA5DataSet(
        args.data_dir,
        args.data_list,
        args.num_classes,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

    gta_trainloader_iter = enumerate(gta_trainloader)

    # The validation loaders are iterated directly during evaluation; no
    # enumerate() object is created up front, since that would build a loader
    # iterator that is never consumed.
    gta_valloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                                args.valdata_list,
                                                args.num_classes,
                                                max_iters=None,
                                                crop_size=input_size,
                                                scale=args.random_scale,
                                                mirror=args.random_mirror,
                                                mean=IMG_MEAN),
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.num_workers,
                                    pin_memory=True)

    cityscapes_targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        args.num_classes,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set='train'),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    cityscapes_targetloader_iter = enumerate(cityscapes_targetloader)

    cityscapes_valtargetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.valdata_list_target,
        args.num_classes,
        max_iters=None,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set='val'),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.num_workers,
                                                 pin_memory=True)

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_D_value = 0

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Only the label maps are used; the images are discarded.
        _, batch = next(gta_trainloader_iter)
        _, labels, _, _ = batch
        source_maps = _labels_to_one_hot(labels, args.num_classes).cuda(0)

        _, batch = next(cityscapes_targetloader_iter)
        _, labels, _, _ = batch
        target_maps = _labels_to_one_hot(labels, args.num_classes).cuda(0)

        # One forward pass over both domains stacked along the batch dim.
        D_out = model_D(torch.cat((source_maps, target_maps), 0))

        # Domain targets sized from the actual per-domain batch sizes rather
        # than "half of the output batch": the old float division broke
        # torch.Size on Python 3, and halving breaks on unequal batches.
        per_sample_shape = tuple(D_out.size()[1:])
        domain_targets = torch.cat((
            torch.full((source_maps.size(0),) + per_sample_shape,
                       float(source_label), dtype=torch.float32),
            torch.full((target_maps.size(0),) + per_sample_shape,
                       float(target_label), dtype=torch.float32),
        ), 0).cuda(0)

        loss_out = bce_loss(D_out, domain_targets)

        loss = loss_out / args.iter_size
        loss.backward()

        # .item() replaces .numpy()[0], which fails on a 0-dim loss tensor.
        loss_D_value += loss_out.item() / args.iter_size

        optimizer_D.step()

        if i_iter % 100 == 0:
            print('iter = {0:8d}/{1:8d}, loss_D = {2:.3f}'.format(
                i_iter, args.num_steps, loss_D_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'Classify_' + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'Classify_' + str(i_iter) + '.pth'))

            model_D.eval()
            loss_valD_value = 0
            correct = 0
            wrong = 0
            # No gradients are needed while evaluating.
            with torch.no_grad():
                # Source maps: a negative logit counts as correct.
                for _, labels, _, _ in gta_valloader:
                    D_out1 = model_D(
                        _labels_to_one_hot(labels, args.num_classes).cuda(0))
                    loss_out1 = bce_loss(
                        D_out1, torch.full_like(D_out1, float(source_label)))
                    loss_valD_value += loss_out1.item()
                    D_out1 = D_out1.cpu()
                    # NOTE(review): only the source-side counts are divided
                    # by 100 - presumably to offset a larger source
                    # validation set; confirm this weighting is intended.
                    correct = correct + (D_out1 < 0).sum().item() / 100
                    wrong = wrong + (D_out1 >= 0).sum().item() / 100

                # Target maps: a non-negative logit counts as correct.
                for _, labels, _, _ in cityscapes_valtargetloader:
                    D_out2 = model_D(
                        _labels_to_one_hot(labels, args.num_classes).cuda(0))
                    loss_out2 = bce_loss(
                        D_out2, torch.full_like(D_out2, float(target_label)))
                    loss_valD_value += loss_out2.item()
                    D_out2 = D_out2.cpu()
                    wrong = wrong + (D_out2 < 0).sum().item()
                    correct = correct + (D_out2 >= 0).sum().item()

            total = wrong + correct
            # Both validation loaders empty: report NaN instead of raising.
            accuracy = 1.0 * correct / total if total else float('nan')
            print('accuracy:%f' % accuracy)

            print('iter = {0:8d}/{1:8d}, loss_valD = {2:.3f}'.format(
                i_iter, args.num_steps, loss_valD_value))

            model_D.train()
Пример #5
0
def main():
    """Create the model and start the training.

    A student ``SEANet`` is trained with a segmentation loss on source
    batches plus an MSE consistency loss against a teacher ``SEANet`` on
    noise-perturbed target batches. The teacher has gradients disabled and is
    updated through ``WeightEMA`` after every student step.

    Side effects: creates ``args.snapshot_dir`` if needed, writes
    ``GTA2Cityscapes_log.txt`` there, saves the teacher whenever the score
    returned by ``test_mIoU`` improves, and (when the loop finishes normally)
    saves ``GTA_TeacherNet.pth`` and ``GTA_loss.npz``.
    """
    args = get_arguments()
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    # Joined (not concatenated) so the log lands inside the snapshot directory
    # even when the argument has no trailing separator, matching how the
    # checkpoints below are written.
    log_path = os.path.join(args.snapshot_dir, 'GTA2Cityscapes_log.txt')

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    # Create network
    student_net = SEANet(num_classes=args.num_classes)
    teacher_net = SEANet(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)

    # Checkpoint entries are paired with model entries by iteration order
    # (not by name); entries whose checkpoint key starts with 'f', 's' or 'u'
    # keep the model's own initial values.
    new_params = student_net.state_dict().copy()
    for i, j in zip(saved_state_dict, new_params):
        if i[0] not in ('f', 's', 'u'):
            new_params[j] = saved_state_dict[i]

    student_net.load_state_dict(new_params)
    teacher_net.load_state_dict(new_params)

    # The teacher is never updated by back-propagation.
    for param in teacher_net.parameters():
        param.requires_grad = False

    teacher_net = teacher_net.cuda()
    student_net = student_net.cuda()

    src_loader = data.DataLoader(GTA5DataSet(args.data_dir_source,
                                             args.data_list_source,
                                             crop_size=input_size,
                                             scale=False,
                                             mirror=False,
                                             mean=IMG_MEAN),
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_loader = data.DataLoader(cityscapesDataSet(args.data_dir_target,
                                                   args.data_list_target,
                                                   max_iters=24966,
                                                   crop_size=input_size,
                                                   scale=False,
                                                   mirror=False,
                                                   mean=IMG_MEAN,
                                                   set='val'),
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    val_loader = data.DataLoader(cityscapesDataSet(args.data_dir_target,
                                                   args.data_list_target,
                                                   max_iters=None,
                                                   crop_size=input_size,
                                                   scale=False,
                                                   mirror=False,
                                                   mean=IMG_MEAN,
                                                   set='val'),
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # zip() below stops at the shorter loader, so this is the real epoch length.
    num_batches = min(len(src_loader), len(tgt_loader))

    optimizer = optim.Adam(student_net.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    optimizer.zero_grad()

    student_params = list(student_net.parameters())
    teacher_params = list(teacher_net.parameters())

    teacher_optimizer = WeightEMA(
        teacher_params,
        student_params,
        alpha=args.teacher_alpha,
    )

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    n_class = args.num_classes
    num_steps = args.num_epoch * num_batches
    # Columns: 0 total loss, 1 segmentation loss, 2 consistency loss,
    # 3 mean IoU on the source batch (column 4 is never written).
    loss_hist = np.zeros((num_steps, 5))
    index_i = -1
    OA_hist = 0.2  # a checkpoint is only saved once the score exceeds this
    aug_loss = torch.nn.MSELoss()

    # The context manager closes the log on every exit path, including the
    # early return below and any exception raised during training.
    with open(log_path, 'w') as f:
        for epoch in range(args.num_epoch):
            if epoch == 6:
                # NOTE: returning here also skips the final teacher / loss
                # history saves at the end of this function.
                return
            for batch_index, (src_data, tgt_data) in enumerate(
                    zip(src_loader, tgt_loader)):
                index_i += 1

                tem_time = time.time()
                student_net.train()
                optimizer.zero_grad()

                # train with source
                images, src_label, _, _ = src_data
                images = images.cuda()
                src_label = src_label.cuda()
                _, src_output = student_net(images)
                src_output = interp(src_output)
                # Segmentation Loss
                cls_loss_value = loss_calc(src_output, src_label)
                _, predict_labels = torch.max(src_output, 1)
                lbl_pred = predict_labels.detach().cpu().numpy()
                lbl_true = src_label.detach().cpu().numpy()
                metrics_batch = []
                for lt, lp in zip(lbl_true, lbl_pred):
                    _, _, mean_iu, _ = label_accuracy_score(
                        lt, lp, n_class=args.num_classes)
                    metrics_batch.append(mean_iu)
                miu = np.mean(metrics_batch, axis=0)

                # train with target (the target labels are not used here)
                images, _, _, _ = tgt_data
                images = images.cuda()
                # Teacher and student see independently perturbed copies.
                tgt_t_input = images + torch.randn(
                    images.size()).cuda() * args.noise
                tgt_s_input = images + torch.randn(
                    images.size()).cuda() * args.noise

                _, tgt_s_output = student_net(tgt_s_input)
                t_confidence, tgt_t_output = teacher_net(tgt_t_input)

                t_confidence = t_confidence.squeeze()

                # self-ensembling Loss: move the class axis last so the
                # predictions can be flattened to one row per pixel.
                tgt_t_predicts = F.softmax(
                    tgt_t_output, dim=1).transpose(1, 2).transpose(2, 3)
                tgt_s_predicts = F.softmax(
                    tgt_s_output, dim=1).transpose(1, 2).transpose(2, 3)

                # Only pixels whose teacher confidence exceeds the threshold
                # contribute to the consistency loss.
                mask = t_confidence > args.attention_threshold
                mask = mask.view(-1)
                num_pixel = mask.shape[0]

                mask_rate = torch.sum(mask).float() / num_pixel

                tgt_s_predicts = tgt_s_predicts.contiguous().view(-1, n_class)
                tgt_s_predicts = tgt_s_predicts[mask]
                tgt_t_predicts = tgt_t_predicts.contiguous().view(-1, n_class)
                tgt_t_predicts = tgt_t_predicts[mask]
                aug_loss_value = aug_loss(tgt_s_predicts, tgt_t_predicts)
                aug_loss_value = args.st_weight * aug_loss_value

                # TOTAL LOSS
                # With no selected pixel the loss above was computed over an
                # empty selection, so it is replaced by an explicit zero.
                if mask_rate == 0.:
                    aug_loss_value = torch.tensor(0.).cuda()

                total_loss = cls_loss_value + aug_loss_value

                total_loss.backward()
                loss_hist[index_i, 0] = total_loss.item()
                loss_hist[index_i, 1] = cls_loss_value.item()
                loss_hist[index_i, 2] = aug_loss_value.item()
                loss_hist[index_i, 3] = miu

                optimizer.step()
                teacher_optimizer.step()
                batch_time = time.time() - tem_time

                if (batch_index + 1) % 10 == 0:
                    # Report averages over the last ten recorded steps.
                    recent = loss_hist[index_i - 9:index_i + 1]
                    msg = (
                        'epoch %d/%d:  %d/%d time: %.2f miu = %.1f cls_loss = %.3f st_loss = %.3f \n'
                        % (epoch + 1, args.num_epoch, batch_index + 1,
                           num_batches, batch_time,
                           np.mean(recent[:, 3]) * 100,
                           np.mean(recent[:, 1]),
                           np.mean(recent[:, 2])))
                    print(msg)
                    f.write(msg)
                    f.flush()

                if (batch_index + 1) % 500 == 0:
                    OA_new = test_mIoU(f,
                                       teacher_net,
                                       val_loader,
                                       epoch + 1,
                                       input_size_target,
                                       print_per_batches=10)

                    # Saving the models
                    if OA_new > OA_hist:
                        f.write('Save Model\n')
                        print('Save Model')
                        model_name = (
                            'GTA2Cityscapes_epoch%dbatch%dtgt_miu_%d.pth' %
                            (epoch + 1, batch_index + 1, int(OA_new * 1000)))
                        torch.save(teacher_net.state_dict(),
                                   os.path.join(args.snapshot_dir, model_name))
                        OA_hist = OA_new

    torch.save(teacher_net.state_dict(),
               os.path.join(args.snapshot_dir, 'GTA_TeacherNet.pth'))
    np.savez(os.path.join(args.snapshot_dir, 'GTA_loss.npz'),
             loss_hist=loss_hist)
Пример #6
0
def main():
    """Create the model and start the evaluation process.

    Pipeline (each stage is skipped when its cache file already exists):

    1. Run the restored ``Res_Deeplab`` over the Cityscapes split and over
       GTA5, down-sample the second network output to 16x32 and cache the
       results together with the image names.
    2. For every Cityscapes image, rank all GTA5 images by mean absolute
       difference of those down-sampled outputs and keep the three closest
       ones that no earlier Cityscapes image has claimed.
    3. Write the names of the selected GTA5 images, one per line, to
       ``src_train_imgs_txt``.
    """

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    model = Res_Deeplab(num_classes=args.num_classes)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    def extract_feat_distrs(loader):
        """Return ``(outputs, names)`` arrays for every batch of ``loader``.

        Each batch must hold the image tensor first and the image name(s)
        last. Only the first sample of a batch is kept, which matches the
        ``batch_size=1`` loaders built below.
        """
        interp_down = nn.Upsample(size=(16, 32),
                                  mode='bilinear',
                                  align_corners=True)
        feat_distrs = []
        img_paths = []
        for index, batch in enumerate(loader):
            if index % 100 == 0:
                print('%d processd of %d' % (index, len(loader)))
            image, name = batch[0], batch[-1]
            # `volatile=True` is ignored by PyTorch >= 0.4; no_grad() is the
            # supported way to run inference without building a graph.
            with torch.no_grad():
                _, output2 = model(image.cuda(gpu0))
                output = interp_down(output2).cpu().data[0].numpy()
            # channels-first -> channels-last
            feat_distrs.append(output.transpose(1, 2, 0))
            img_paths.extend(name)
        return np.array(feat_distrs), np.array(img_paths)

    if not os.path.isfile(citys_feat_distr_path):
        testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                       args.data_list,
                                                       crop_size=(1024, 512),
                                                       mean=CITY_IMG_MEAN,
                                                       scale=False,
                                                       mirror=False,
                                                       set=args.set),
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)
        citys_feat_distrs_np, citys_img_paths_np = extract_feat_distrs(
            testloader)
        np.save(citys_feat_distr_path, citys_feat_distrs_np)
        np.save(citys_imgpaths_path, citys_img_paths_np)
    else:
        citys_feat_distrs_np = np.load(citys_feat_distr_path)
        citys_img_paths_np = np.load(citys_imgpaths_path)

    if not os.path.isfile(gta_feat_distr_path):
        gtaloader = data.DataLoader(GTA5DataSet(GTA_DATA_DIRECTORY,
                                                GTA_DATA_LIST_PATH,
                                                crop_size=(1024, 512),
                                                mean=GTA_IMG_MEAN,
                                                scale=False,
                                                mirror=False),
                                    batch_size=1,
                                    shuffle=False,
                                    pin_memory=True)
        gta_feat_distrs_np, gta_img_paths_np = extract_feat_distrs(gtaloader)
        np.save(gta_feat_distr_path, gta_feat_distrs_np)
        np.save(gta_imgpaths_path, gta_img_paths_np)
    else:
        gta_feat_distrs_np = np.load(gta_feat_distr_path)
        gta_img_paths_np = np.load(gta_imgpaths_path)

    if not os.path.isfile(closest_imgs_path):
        # Number of values in one per-image output; used to normalise the
        # summed absolute difference into a mean.
        pixel_amount = citys_feat_distrs_np[0, :].size
        closest_imgs_locs = []
        for i in range(citys_img_paths_np.shape[0]):
            cur_citys_feat = citys_feat_distrs_np[i, :]
            if i % 10 == 0:
                print(i)
            distances = []
            for j in range(gta_img_paths_np.shape[0]):
                cur_gta_feat = gta_feat_distrs_np[j, :]
                dist_abs = abs(cur_citys_feat - cur_gta_feat)
                distances.append(np.sum(dist_abs) / pixel_amount)
            min_loc = np.argsort(distances)

            # Start from the three nearest GTA5 images, then swap out any
            # that were already selected for an earlier Cityscapes image,
            # pulling replacements from further down the ranking until the
            # selection no longer overlaps the accumulated list.
            top_ord = 3
            closest_imgs_loc = min_loc[:top_ord]
            intersect_imgs = np.intersect1d(closest_imgs_loc,
                                            closest_imgs_locs)
            while intersect_imgs.size:
                inters_num = len(intersect_imgs)
                closest_imgs_loc_confirm = np.setdiff1d(
                    closest_imgs_loc, intersect_imgs)  # find the difference
                closest_imgs_loc_candi = min_loc[top_ord:top_ord + inters_num]
                top_ord = top_ord + inters_num
                closest_imgs_loc = np.concatenate(
                    [closest_imgs_loc_confirm, closest_imgs_loc_candi])
                intersect_imgs = np.intersect1d(closest_imgs_loc,
                                                closest_imgs_locs)

            closest_imgs_locs.extend(closest_imgs_loc)
        np.save(closest_imgs_path, closest_imgs_locs)
    else:
        closest_imgs_locs = np.load(closest_imgs_path)
    closest_imgs_locs_uni = np.unique(closest_imgs_locs)

    # get file_names
    with open(src_train_imgs_txt, 'w') as f_train:
        for img_num in closest_imgs_locs_uni:
            line = gta_img_paths_np[img_num] + '\n'
            f_train.write(line)
    def __getitem__(self, index):
        """Load one sample and return ``(image, size, name)``.

        The image file referenced by ``self.files[index]["img"]`` is decoded
        as RGB, resized to ``self.crop_size`` with bicubic interpolation,
        converted to float32, flipped to BGR channel order, shifted by
        ``self.mean`` and transposed to channels-first layout. ``size`` is
        the array shape taken after resizing and before the transpose.
        """
        entry = self.files[index]

        picture = Image.open(entry["img"]).convert('RGB')
        name = entry["name"]

        # resize
        picture = picture.resize(self.crop_size, Image.BICUBIC)

        pixels = np.asarray(picture, np.float32)
        size = pixels.shape

        # Reverse the channel axis (RGB -> BGR), then subtract the mean
        # in place so the array stays float32.
        pixels = pixels[:, :, ::-1]
        pixels -= self.mean

        # Height/width/channel -> channel/height/width; copy() turns the
        # strided view into a contiguous array.
        channels_first = pixels.transpose((2, 0, 1))
        return channels_first.copy(), np.array(size), name


if __name__ == '__main__':
    # Manual smoke test: iterate the dataset in batches of four and display
    # the first batch as an image grid.
    dst = GTA5DataSet("./data", is_transform=True)
    trainloader = data.DataLoader(dst, batch_size=4)
    # The loop variable is deliberately not called `data`: that would rebind
    # the module-level `data` name used for `data.DataLoader` above and break
    # any later use of it.
    for i, batch in enumerate(trainloader):
        imgs, labels = batch
        if i == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for matplotlib
            img = img[:, :, ::-1]  # reverse the channel order for display
            plt.imshow(img)
            plt.show()
def main():
    """Create the model and start the training.

    Trains a two-output segmentation network (on GPU ``gpu_id_2``) together
    with two discriminators (on GPU ``gpu_id_1``). Each discriminator input
    is the channel-wise concatenation of the softmax predictions of two
    batches: source+source pairs are labelled ``source_label``,
    target+target pairs ``target_label`` and mixed pairs ``mix_label``.
    The segmentation network is pushed to make target and mixed pairs look
    like source pairs.

    Reads its configuration from the module-level ``args``. The iteration
    counter starts at 10000 and the weights come from a fixed checkpoint
    path, i.e. this function resumes an earlier run. Checkpoints are written
    to ``args.snapshot_dir``.
    """

    gpu_id_2 = 3  # segmentation network
    gpu_id_1 = 2  # discriminators, labels and all losses

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            print("from restore")
            saved_state_dict = torch.load(args.restore_from)
            # NOTE: the checkpoint loaded from args.restore_from is replaced
            # right away by this fixed resume checkpoint.
            saved_state_dict = torch.load(
                'snapshots/GTA2Cityscapes_multi_54/GTA5_10000.pth')
            model.load_state_dict(saved_state_dict)

        # NOTE: new_params is assembled but not loaded (the load call below
        # is disabled), so on the 'http' path the model keeps its initial
        # weights.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        # model.load_state_dict(new_params)

    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # The builtin next() works on both Python 2 and 3; the iterator's
    # .next() method only exists on Python 2.
    trainloader_iter = enumerate(trainloader)
    _, batch_last = next(trainloader_iter)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = next(targetloader_iter)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Despite the name this is a mean-squared-error objective.
    bce_loss = torch.nn.MSELoss()

    def upsample_(input_):
        """Resize ``input_`` to the source input size."""
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        """Resize ``input_`` to the target input size."""
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    def result_model(batch, interp_):
        """Run the segmentation network on ``batch``.

        Returns both up-sampled predictions and the labels, all moved to the
        discriminator GPU.
        """
        images, labels, _, _ = batch
        images = Variable(images).cuda(gpu_id_2)
        labels = Variable(labels.long()).cuda(gpu_id_1)
        pred1, pred2 = model(images)
        pred1 = interp_(pred1)
        pred2 = interp_(pred2)
        return pred1.cuda(gpu_id_1), pred2.cuda(gpu_id_1), labels

    def filled_like(d_out, value):
        """Return a float tensor shaped like ``d_out``, filled with ``value``."""
        return Variable(torch.FloatTensor(
            d_out.data.size()).fill_(value)).cuda(gpu_id_1)

    # labels for adversarial training
    source_label = 1
    target_label = -1
    mix_label = 0

    for i_iter in range(10000, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---------------- train G ----------------

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            _, batch = next(trainloader_iter)
            _, batch_target = next(targetloader_iter)

            # Segmentation loss on the source batch.
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()

            # Segmentation loss on the target batch, using the labels that
            # the target loader yields. Only these values are logged.
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = loss_seg2 + args.lambda_seg * loss_seg1
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # Adversarial loss, pair = (current target, previous target):
            # the generator wants D to answer `source_label`.
            pred1_last_target, pred2_last_target, _ = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax(pred1_target, dim=1)
            pred2_target_D = F.softmax(pred2_target, dim=1)
            pred1_last_target_D = F.softmax(pred1_last_target, dim=1)
            pred2_last_target_D = F.softmax(pred2_last_target, dim=1)
            fake1_D = torch.cat((pred1_target_D, pred1_last_target_D), dim=1)
            fake2_D = torch.cat((pred2_target_D, pred2_last_target_D), dim=1)
            D_out_fake_1 = model_D1(fake1_D)
            # D2 judges the second head's pair (it used to receive fake1_D,
            # which left fake2_D unused).
            D_out_fake_2 = model_D2(fake2_D)

            loss_adv_target1 = bce_loss(
                D_out_fake_1, filled_like(D_out_fake_1, source_label))
            loss_adv_target2 = bce_loss(
                D_out_fake_2, filled_like(D_out_fake_2, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()

            # Adversarial loss, pair = (current target, current source).
            # Fresh forward passes are needed because the previous backward
            # released the graph.
            pred1, pred2, labels = result_model(batch, interp)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            pred1_target_D = F.softmax(pred1_target, dim=1)
            pred2_target_D = F.softmax(pred2_target, dim=1)
            pred1_D = F.softmax(pred1, dim=1)
            pred2_D = F.softmax(pred2, dim=1)
            mix1_D = torch.cat((pred1_target_D, pred1_D), dim=1)
            mix2_D = torch.cat((pred2_target_D, pred2_D), dim=1)

            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D2(mix2_D)

            loss_adv_mix1 = bce_loss(
                D_out_mix_1, filled_like(D_out_mix_1, source_label))
            loss_adv_mix2 = bce_loss(
                D_out_mix_2, filled_like(D_out_mix_2, source_label))

            loss_adv_target1 = loss_adv_mix1 * 2
            loss_adv_target2 = loss_adv_mix2 * 2

            loss = args.lambda_adv_target1 * loss_adv_target1.cuda(
                gpu_id_1) + args.lambda_adv_target2 * loss_adv_target2.cuda(
                    gpu_id_1)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # ---------------- train D ----------------

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            pred1_last, pred2_last, _ = result_model(batch_last, interp)

            # Detach everything so the discriminator losses do not reach
            # the segmentation network.
            pred1 = pred1.detach().cuda(gpu_id_1)
            pred2 = pred2.detach().cuda(gpu_id_1)
            pred1_target = pred1_target.detach().cuda(gpu_id_1)
            pred2_target = pred2_target.detach().cuda(gpu_id_1)
            pred1_last = pred1_last.detach().cuda(gpu_id_1)
            pred2_last = pred2_last.detach().cuda(gpu_id_1)
            pred1_last_target = pred1_last_target.detach().cuda(gpu_id_1)
            pred2_last_target = pred2_last_target.detach().cuda(gpu_id_1)

            pred1_D = F.softmax(pred1, dim=1)
            pred2_D = F.softmax(pred2, dim=1)
            pred1_last_D = F.softmax(pred1_last, dim=1)
            pred2_last_D = F.softmax(pred2_last, dim=1)
            pred1_target_D = F.softmax(pred1_target, dim=1)
            pred2_target_D = F.softmax(pred2_target, dim=1)
            pred1_last_target_D = F.softmax(pred1_last_target, dim=1)
            pred2_last_target_D = F.softmax(pred2_last_target, dim=1)

            # train with source: (source, previous source) -> source_label,
            # (previous source, target) -> mix_label
            real1_D = torch.cat((pred1_D, pred1_last_D), dim=1)
            real2_D = torch.cat((pred2_D, pred2_last_D), dim=1)
            mix1_D_ = torch.cat((pred1_last_D, pred1_target_D), dim=1)
            mix2_D_ = torch.cat((pred2_last_D, pred2_target_D), dim=1)

            D_out1_real = model_D1(real1_D)
            D_out2_real = model_D2(real2_D)
            D_out1_mix = model_D1(mix1_D_)
            D_out2_mix = model_D2(mix2_D_)

            loss_D1 = bce_loss(D_out1_real,
                               filled_like(D_out1_real, source_label))
            loss_D2 = bce_loss(D_out2_real,
                               filled_like(D_out2_real, source_label))
            loss_D3 = bce_loss(D_out1_mix, filled_like(D_out1_mix, mix_label))
            loss_D4 = bce_loss(D_out2_mix, filled_like(D_out2_mix, mix_label))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target: (target, target) -> target_label,
            # (source, previous target) -> mix_label
            fake1_D_ = torch.cat((pred1_target_D, pred1_target_D), dim=1)
            fake2_D_ = torch.cat((pred2_target_D, pred2_target_D), dim=1)
            mix1_D__ = torch.cat((pred1_D, pred1_last_target_D), dim=1)
            mix2_D__ = torch.cat((pred2_D, pred2_last_target_D), dim=1)

            D_out1 = model_D1(fake1_D_)
            D_out2 = model_D2(fake2_D_)
            D_out3 = model_D1(mix1_D__)
            D_out4 = model_D2(mix2_D__)

            loss_D1 = bce_loss(D_out1, filled_like(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, filled_like(D_out2, target_label))
            loss_D3 = bce_loss(D_out3, filled_like(D_out3, mix_label))
            loss_D4 = bce_loss(D_out4, filled_like(D_out4, mix_label))

            loss_D1 = (loss_D1 + loss_D3) / args.iter_size / 2
            loss_D2 = (loss_D2 + loss_D4) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # The current batches become the "previous" ones of the next step.
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
Пример #9
0
def main():
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` segmentation network on the source loader
    (``GTA5DataSet``) with a cross-entropy loss on both of its outputs, while
    two ``FCDiscriminator`` networks (one per output) are trained against
    predictions on the target loader (``cityscapesDataSet``).  Gradients are
    accumulated over ``args.iter_size`` sub-iterations before every optimizer
    step.  Snapshots are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations, and training stops once
    ``args.num_steps_stop`` iterations have run.

    Reads its whole configuration from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network.  Fail early with a clear message instead of hitting a
    # NameError on ``model`` further down.
    if args.model != 'DeepLab':
        raise ValueError('unsupported model: {!r}'.format(args.model))

    model = DeeplabMulti(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # The model is still on the CPU here, so load the weights there too;
        # this keeps --cpu runs working with checkpoints saved from a GPU.
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight": the
        # first component is dropped to match this model's parameter names.
        key_parts = key.split('.')
        # Entries under "layer5" are skipped only when num_classes == 19.
        if not args.num_classes == 19 or not key_parts[1] == 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D: one discriminator per segmentation output
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    model_D2.train()

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # input sizes are given as (w, h); Upsample expects (h, w)
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is what the deprecated implicit-dim call selected for
            # these 4-D predictions; spelling it out silences the warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Target predictions are labelled as *source* here so that the
            # segmentation network is pushed to fool the discriminators.
            loss_adv_target1 = bce_loss(
                D_out1, torch.full_like(D_out1, source_label))

            loss_adv_target2 = bce_loss(
                D_out2, torch.full_like(D_out2, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached: no gradient reaches the
            # segmentation network from the discriminator losses)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, torch.full_like(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, torch.full_like(D_out2, source_label))

            # halved because each discriminator sees a source and a target
            # batch per sub-iteration
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, torch.full_like(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, torch.full_like(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg1': loss_seg_value1,
                'loss_seg2': loss_seg_value2,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D1': loss_D_value1,
                'loss_D2': loss_D_value2,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # final snapshot, named after the stop step rather than i_iter
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

    if args.tensorboard:
        writer.close()
Пример #10
0
def main():
    """Create the model and start the training.

    Behaviour depends on the module-level ``STAGE`` constant:

    * ``STAGE == 1``: source batches alternate between the translated and the
      stylized GTA5 loaders, and a single ``FCDiscriminator`` is trained
      against predictions on unlabeled target images.
    * otherwise: source batches come from the translated loader only and the
      target loader (``cityscapesDataSetLabel``) also yields labels, which
      are used for a supervised segmentation loss on the target images.

    Training resumes from ``checkpoint_<CHECKPOINT>.pth.tar`` inside
    ``args.snapshot_dir``.  Every ``args.save_pred_every`` iterations a
    checkpoint is saved, the project-level ``eval`` / ``compute_mIoU`` helpers
    are run, and one line is appended to ``args.output_file``.

    Reads its whole configuration from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    w, h = map(int, args.input_size_source.split(','))
    input_size_source = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # the first component is dropped to match this model's names.
            key_parts = key.split('.')
            # Entries under "layer5" are skipped only when num_classes == 19.
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           pretrained=True,
                           vgg16_caffe_path=args.restore_from)

    else:
        # Fail early with a clear message instead of a NameError on ``model``.
        raise ValueError('unsupported model: {!r}'.format(args.model))

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    model.train()
    model.cuda(args.gpu)
    cudnn.benchmark = True

    # Discriminator setting
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(args.gpu)

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_adv_label = 0
    target_adv_label = 1

    # Dataloader
    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.translated_data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size_source,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    style_trainloader = data.DataLoader(GTA5DataSet(
        args.stylized_data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_source,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)

    style_trainloader_iter = enumerate(style_trainloader)

    if STAGE == 1:
        targetloader = data.DataLoader(cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    else:
        # Dataloader for self-training.
        # NOTE: label_folder is a placeholder string and must be replaced
        # with the real pseudo-label directory before running this stage.
        targetloader = data.DataLoader(cityscapesDataSetLabel(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set,
            label_folder='Path to generated pseudo labels'),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    # input sizes are given as (w, h); Upsample expects (h, w)
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # load checkpoint
    # Built with osp.join, like the snapshot paths written below, so resuming
    # also works when args.snapshot_dir has no trailing separator.
    model, model_D, optimizer, start_iter = load_checkpoint(
        model,
        model_D,
        optimizer,
        filename=osp.join(args.snapshot_dir,
                          'checkpoint_' + CHECKPOINT + '.pth.tar'))

    for i_iter in range(start_iter, args.num_steps):
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train segmentation network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        if STAGE == 1:
            # alternate translated / stylized source images
            if i_iter % 2 == 0:
                _, batch = next(trainloader_iter)
            else:
                _, batch = next(style_trainloader_iter)

        else:
            _, batch = next(trainloader_iter)

        image_source, label, _, _ = batch
        image_source = image_source.cuda(args.gpu)

        pred_source = model(image_source)
        pred_source = interp(pred_source)

        loss_seg_source = loss_calc(pred_source, label, args.gpu)
        loss_seg_source_value = loss_seg_source.item()
        loss_seg_source.backward()

        if STAGE == 2:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, target_label, _, _ = batch
            image_target = image_target.cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # target segmentation loss
            loss_seg_target = loss_calc(pred_target,
                                        target_label,
                                        gpu=args.gpu)
            loss_seg_target.backward()

        # optimize
        optimizer.step()

        if STAGE == 1:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, _, _ = batch
            image_target = image_target.cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # output-level adversarial training
            # dim=1 is what the deprecated implicit-dim call selected for
            # these 4-D predictions; spelling it out silences the warning.
            D_output_target = model_D(F.softmax(pred_target, dim=1))
            loss_adv = bce_loss(
                D_output_target,
                torch.full_like(D_output_target, source_adv_label))
            loss_adv = loss_adv * args.lambda_adv
            # NOTE(review): optimizer.step() has already run above and the
            # next iteration starts with optimizer.zero_grad(), so the
            # gradient produced here never seems to reach the segmentation
            # network.  Statement order is kept as-is; confirm whether the
            # step was meant to come after this backward pass.
            loss_adv.backward()

            # train discriminator
            for param in model_D.parameters():
                param.requires_grad = True

            pred_source = pred_source.detach()
            pred_target = pred_target.detach()

            D_output_source = model_D(F.softmax(pred_source, dim=1))
            D_output_target = model_D(F.softmax(pred_target, dim=1))

            loss_D_source = bce_loss(
                D_output_source,
                torch.full_like(D_output_source, source_adv_label))
            loss_D_target = bce_loss(
                D_output_target,
                torch.full_like(D_output_target, target_adv_label))

            loss_D_source = loss_D_source / 2
            loss_D_target = loss_D_target / 2

            loss_D_source.backward()
            loss_D_target.backward()

            # optimize
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg_source = {2:.5f}'.format(
            i_iter, args.num_steps, loss_seg_source_value))

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            state = {
                'iter': i_iter,
                'model': model.state_dict(),
                'model_D': model_D.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            model_path = osp.join(args.snapshot_dir,
                                  'GTA5_' + str(i_iter) + '.pth')
            torch.save(
                state,
                osp.join(args.snapshot_dir,
                         'checkpoint_' + str(i_iter) + '.pth.tar'))
            torch.save(model.state_dict(), model_path)
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_D_' + str(i_iter) + '.pth'))

            cityscapes_eval_dir = osp.join(args.cityscapes_eval_dir,
                                           str(i_iter))
            os.makedirs(cityscapes_eval_dir, exist_ok=True)

            # ``eval`` here is called with (weights path, output dir, iter),
            # so it is a project-level function, not the Python builtin.
            eval(model_path, cityscapes_eval_dir, i_iter)

            iou19, iou13, iou = compute_mIoU(cityscapes_eval_dir, i_iter)
            # context manager closes the log even if the write raises
            with open(args.output_file, 'a') as outputfile:
                outputfile.write(
                    str(i_iter) + '\t' + str(iou19) + '\t' +
                    str(iou.replace('\n', ' ')) + '\n')
def main():
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` segmentation network on the source loader
    (``GTA5DataSet``) together with a single ``FCDiscriminator`` that is fed,
    through the project-level ``CDAN`` helper, the softmax of both
    segmentation outputs.  Gradients are accumulated over ``args.iter_size``
    sub-iterations before every optimizer step.

    When the module-level ``RESTART`` flag is set, the arguments saved as
    ``args_dict_<RESTART_ITER>.json`` in ``RESTART_FROM`` replace the current
    ones and the model / discriminator weights are reloaded from
    ``args.restart_from``.  Otherwise the segmentation network is initialised
    from ``args.restore_from``.

    Every ``args.save_pred_every`` iterations the weights and the current
    arguments (as JSON) are written to ``args.snapshot_dir``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """
    import json

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # vars() returns the live namespace dict: writing to args_dict updates
    # the attributes of ``args`` as well.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    # Fail early with a clear message instead of a NameError on ``model``.
    if args.model != 'DeepLab':
        raise ValueError('unsupported model: {!r}'.format(args.model))
    model = DeeplabMulti(num_classes=args.num_classes)

    # The discriminator takes 2 * num_classes channels.
    # presumably CDAN(..., cdan_implement='concat') concatenates the two
    # softmax maps along the channel axis - verify against CDAN.
    model_D = FCDiscriminator(num_classes=2 * args.num_classes).to(device)

    #### restore model_D and model
    if RESTART:
        # model parameters
        restart_from_model = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_model, map_location='cpu')
        model.load_state_dict(saved_state_dict)

        # model_D parameters
        restart_from_D = args.restart_from + 'GTA5_{}_D.pth'.format(
            RESTART_ITER)
        saved_state_dict = torch.load(restart_from_D, map_location='cpu')
        model_D.load_state_dict(saved_state_dict)

    #### model_D is randomly initialized, model starts from args.restore_from
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # The model is still on the CPU here; loading there keeps --cpu
            # runs working with checkpoints saved from a GPU.
            saved_state_dict = torch.load(args.restore_from,
                                          map_location='cpu')

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # the first component is dropped to match this model's names.
            key_parts = key.split('.')
            # Entries under "layer5" are skipped only when num_classes == 19.
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # NOTE(review): bce_loss is selected from args.gan but never used below;
    # the adversarial and discriminator losses always use BCEWithLogitsLoss,
    # so args.gan == 'LS' currently has no effect.  Confirm the intent.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    # Built once instead of three times per sub-iteration (it is stateless).
    adv_criterion = nn.BCEWithLogitsLoss()

    # input sizes are given as (w, h); Upsample expects (h, w)
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    os.makedirs(args.log_dir, exist_ok=True)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value2 = 0
        adv_loss_value = 0
        d_loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        # NOTE(review): the discriminator optimizer goes through
        # adjust_learning_rate, the same schedule function used for the
        # segmentation optimizer - confirm a discriminator-specific schedule
        # was not intended.
        adjust_learning_rate(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D; this stays in force for both the
            # source and the target generator passes below
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)

            images, _, _ = batch
            images = images.to(device)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is what the deprecated implicit-dim call selected for
            # these 4-D predictions; spelling it out silences the warning.
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')

            # Target predictions are labelled as *source* here so that the
            # segmentation network is pushed to fool the discriminator.
            adv_loss = adv_criterion(
                D_out_target, torch.full_like(D_out_target, source_label))
            adv_loss = adv_loss / args.iter_size
            adv_loss = args.lambda_adv * adv_loss
            adv_loss.backward()
            adv_loss_value += adv_loss.item()

            # train D

            for param in model_D.parameters():
                param.requires_grad_(requires_grad=True)

            # train with source (detached: no gradient reaches the
            # segmentation network from the discriminator losses)
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            D_out = CDAN(
                [F.softmax(pred1, dim=1),
                 F.softmax(pred2, dim=1)],
                model_D,
                cdan_implement='concat')

            d_loss = adv_criterion(D_out,
                                   torch.full_like(D_out, source_label))
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()
            D_out_target = CDAN(
                [F.softmax(pred_target1, dim=1),
                 F.softmax(pred_target2, dim=1)],
                model_D,
                cdan_implement='concat')

            d_loss = adv_criterion(
                D_out_target, torch.full_like(D_out_target, target_label))
            d_loss = d_loss / args.iter_size
            d_loss.backward()
            d_loss_value += d_loss.item()

        optimizer.step()
        optimizer_D.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'generator_loss': adv_loss_value,
            'discriminator_loss': d_loss_value,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} generator = {4:.3f}, discriminator = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    adv_loss_value, d_loss_value))

        if i_iter >= args.num_steps_stop - 1:
            # final snapshot, named after the stop step rather than i_iter
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

            # Directory for evaluation output; the evaluation call itself is
            # not made here.
            save_path = osp.join(args.snapshot_dir, 'eval_{}'.format(i_iter))
            os.makedirs(save_path, exist_ok=True)

            ###### also record latest saved iteration #######
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            # Same osp.join layout as the restart branch at the top of this
            # function, so a saved file is found again on restart.
            args_dict_file = osp.join(args.snapshot_dir,
                                      'args_dict_{}.json'.format(i_iter))
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

    writer.close()
def main():
    """Split the dumped Cityscapes and GTA5 arrays into three .npy files.

    Reads ``dump_cityscape5.npy`` and ``dump_gta5.npy`` from the current
    directory and writes:

    * ``source.npy`` -- the first 1000 rows of the GTA5 dump,
    * ``target.npy`` -- the first 1000 rows of the Cityscapes dump,
    * ``mixed.npy``  -- the remaining rows of both dumps, Cityscapes first.

    The shapes of the three arrays are printed afterwards.

    The original body called ``exit()`` right after these prints, which made
    all the code that followed unreachable.  That unreachable code has been
    cleaned up: the part that produced ``dump_gta5.npy`` now lives in
    ``_dump_gta5_features`` (not invoked from here), and the leftover
    experiments that referenced undefined names were removed.
    """
    city = np.load("dump_cityscape5.npy")
    gta = np.load("dump_gta5.npy")

    city_scape = city[:1000, :]
    gta5 = gta[:1000, :]

    combined = np.concatenate((city[1000:, :], gta[1000:, :]))

    np.save('source.npy', gta5)
    np.save('target.npy', city_scape)
    np.save('mixed.npy', combined)
    print(city_scape.shape)
    print(gta5.shape)
    print(combined.shape)


def _dump_gta5_features(num_batches=2000, out_path='dump_gta5.npy'):
    """Run the restored network over GTA5 batches and save its pooled output.

    For each batch the second output of the model is average-pooled with a
    4x4 window, flattened to a single row and collected; all rows are stacked
    and written to ``out_path``.

    NOTE(review): ``dump_cityscape5.npy`` was presumably produced by the same
    procedure with a ``cityscapesDataSet`` loader -- confirm before relying
    on it.

    Args:
        num_batches: Maximum number of batches (batch size 1) to process.
            Fewer are processed if the loader runs out first.
        out_path: Destination of the stacked array.

    Raises:
        RuntimeError: If the loader yields no batch at all.
    """
    import torch.nn.functional as F

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    model = Res_Deeplab(num_classes=args.num_classes)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              crop_size=(1024, 512),
                                              mean=IMG_MEAN,
                                              scale=False,
                                              mirror=False),
                                  batch_size=1,
                                  shuffle=False,
                                  pin_memory=True)

    # Collect rows in a list and concatenate once at the end; growing the
    # array with np.concatenate inside the loop copies it on every step.
    rows = []
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        for itr, batch in zip(range(num_batches), trainloader):
            print(itr)
            images, _, _, _ = batch

            _, output2 = model(images.cuda(gpu0))
            output2 = F.avg_pool2d(output2, (4, 4))
            rows.append(np.reshape(output2.cpu().numpy(), (1, -1)))

    if not rows:
        raise RuntimeError('data loader yielded no batches; nothing to dump')

    np.save(out_path, np.concatenate(rows))
Пример #13
0
def main():
    """Create the model and run adversarial domain-adaptation training.

    A ``DeeplabMulti`` segmentation network is trained with cross-entropy on
    GTA5 batches while two discriminators (held in one ``nn.ModuleList``) are
    trained against it on Cityscapes batches: ``model_D[0]`` receives the
    feature output of the network, ``model_D[1]`` the softmax of its
    upsampled prediction.  Both adversarial objectives use ``MSELoss``.

    Progress is printed every 10 iterations.  Snapshots of the network and
    the discriminators are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations, and once more when
    ``args.num_steps_stop`` is reached, at which point training stops.
    """
    setup_seed(666)
    device = torch.device("cpu" if args.cpu else "cuda")

    src_w, src_h = (int(v) for v in args.input_size.split(','))
    input_size = (src_w, src_h)

    tgt_w, tgt_h = (int(v) for v in args.input_size_target.split(','))
    input_size_target = (tgt_w, tgt_h)

    cudnn.enabled = True

    # ---- segmentation network -------------------------------------------
    model = DeeplabMulti(num_classes=args.num_classes)
    checkpoint = torch.load(args.restore_from)
    merged_state = model.state_dict().copy()
    for key in checkpoint:
        parts = key.split('.')
        # With 19 classes the checkpoint's own layer5 weights are not reused.
        if args.num_classes == 19 and parts[1] == 'layer5':
            continue
        # Weights stored under layer4.2 are loaded into layer5.0 instead.
        if parts[1] == 'layer4' and parts[2] == '2':
            parts[1] = 'layer5'
            parts[2] = '0'
        # The first component of the checkpoint key is dropped.
        merged_state['.'.join(parts[1:])] = checkpoint[key]
    model.load_state_dict(merged_state)
    model.train()
    model.to(device)

    cudnn.benchmark = True

    # ---- discriminators: [0] on features, [1] on output space -----------
    num_class_list = [2048, 19]
    model_D = nn.ModuleList([
        FCDiscriminator(num_classes=num_class_list[0]).train().to(device),
        OutspaceDiscriminator(num_classes=num_class_list[1]).train().to(device),
    ])

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # ---- data -----------------------------------------------------------
    source_dataset = GTA5DataSet(args.data_dir,
                                 args.data_list,
                                 max_iters=args.num_steps * args.batch_size,
                                 crop_size=input_size,
                                 scale=args.random_scale,
                                 mirror=args.random_mirror,
                                 mean=IMG_MEAN)
    trainloader = data.DataLoader(source_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    target_dataset = cityscapesDataSet(args.data_dir_target,
                                       args.data_list_target,
                                       max_iters=args.num_steps * args.batch_size,
                                       crop_size=input_size_target,
                                       scale=False,
                                       mirror=args.random_mirror,
                                       mean=IMG_MEAN,
                                       set=args.set)
    targetloader = data.DataLoader(target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # ---- optimisers -----------------------------------------------------
    # model.optim_parameters(args) supplies the per-group lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # ---- losses and upsamplers ------------------------------------------
    adv_criterion = torch.nn.MSELoss()
    seg_criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # Domain labels used by the discriminators.
    source_label = 0
    target_label = 1

    def domain_target(d_out, value):
        """Return a float tensor shaped like ``d_out`` filled with ``value``."""
        return torch.FloatTensor(d_out.data.size()).fill_(value).to(device)

    def set_discriminator_grad(enabled):
        """Switch gradient tracking of every discriminator parameter."""
        for p in model_D.parameters():
            p.requires_grad = enabled

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # ================= generator step =================
        # The discriminators must not accumulate gradients here.
        set_discriminator_grad(False)

        # Source batch: supervised segmentation loss.
        _, batch = next(trainloader_iter)
        inputs, labels, _, _ = batch
        inputs = inputs.to(device)
        labels = labels.long().to(device)

        feat_source, pred_source = model(inputs, model_D, 'source')
        pred_source = interp(pred_source)

        loss_seg = seg_criterion(pred_source, labels)
        loss_seg.backward()

        # Target batch: the discriminators should label it as source.
        _, batch = next(targetloader_iter)
        inputs, _, _ = batch
        inputs = inputs.to(device)

        feat_target, pred_target = model(inputs, model_D, 'target')
        pred_target = interp_target(pred_target)

        loss_adv = 0
        d_out = model_D[0](feat_target)
        loss_adv = loss_adv + adv_criterion(d_out, domain_target(d_out, source_label))
        d_out = model_D[1](F.softmax(pred_target, dim=1))
        loss_adv = loss_adv + adv_criterion(d_out, domain_target(d_out, source_label))
        loss_adv = loss_adv * 0.01
        loss_adv.backward()

        optimizer.step()

        # ================= discriminator step =================
        set_discriminator_grad(True)

        # Source samples (detached so the network receives no gradient).
        loss_D_source = 0
        d_out_src = model_D[0](feat_source.detach())
        loss_D_source = loss_D_source + adv_criterion(
            d_out_src, domain_target(d_out_src, source_label))
        d_out_src = model_D[1](F.softmax(pred_source.detach(), dim=1))
        loss_D_source = loss_D_source + adv_criterion(
            d_out_src, domain_target(d_out_src, source_label))
        loss_D_source.backward()

        # Target samples.
        loss_D_target = 0
        d_out_tgt = model_D[0](feat_target.detach())
        loss_D_target = loss_D_target + adv_criterion(
            d_out_tgt, domain_target(d_out_tgt, target_label))
        d_out_tgt = model_D[1](F.softmax(pred_target.detach(), dim=1))
        loss_D_target = loss_D_target + adv_criterion(
            d_out_tgt, domain_target(d_out_tgt, target_label))
        loss_D_target.backward()

        optimizer_D.step()

        if i_iter % 10 == 0:
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D_s = {4:.3f}, loss_D_t = {5:.3f}'.format(
                i_iter, args.num_steps, loss_seg.item(), loss_adv.item(),
                loss_D_source.item(), loss_D_target.item()))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            final_stem = 'GTA5_' + str(args.num_steps_stop)
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, final_stem + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, final_stem + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            snap_stem = 'GTA5_' + str(i_iter)
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snap_stem + '.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, snap_stem + '_D.pth'))
def main():
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` network with segmentation losses on batches
    from both the GTA5 loader and the Cityscapes loader.  For every batch the
    loss is augmented with ``args.lambda_irm`` times the squared gradient of
    the loss with respect to a dummy scalar that is passed to
    ``model.forward_irm``.  Gradients are accumulated over ``args.iter_size``
    sub-iterations before each optimiser step.

    After every step the averaged segmentation losses of both loaders are
    printed.  Network snapshots are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and once more when
    ``args.num_steps_stop`` is reached, at which point training stops.

    Fixes relative to the previous version:

    * the progress print and the checkpoint code referenced names that are
      never defined in this function (``loss_seg_value1``, ``loss_D_value1``,
      ``model_D1``, ``model_D2`` ...) and raised ``NameError``; they now use
      the losses actually accumulated here and save only the network;
    * scalar losses are read with ``.item()`` instead of
      ``.data.cpu().numpy()[0]``, which fails on 0-dim tensors;
    * unused locals and a large commented-out code block were removed.
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the first
            # component is dropped, and with 19 classes the checkpoint's
            # layer5 weights are not reused.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    def _forward_single(_iter):
        """Forward/backward one batch from ``_iter``; return both seg losses as floats.

        NOTE(review): batches from both loaders are unpacked as 4-tuples with
        labels and upsampled to the *source* input size -- confirm that the
        Cityscapes dataset used here yields labels in that layout.
        """
        # Dummy scalar whose gradient defines the penalty term; it is placed
        # on the same device as the model and the images.
        _scalar = torch.tensor(1.).cuda(args.gpu).requires_grad_()

        _, batch = next(_iter)
        images, labels, _, _ = batch
        images = images.cuda(args.gpu)

        pred1, pred2 = model.forward_irm(images, _scalar)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization
        loss = loss / args.iter_size

        # compute penalty: squared gradient of the loss w.r.t. the dummy
        # scalar; create_graph=True so the penalty itself is differentiable.
        grad = autograd.grad(loss, [_scalar], create_graph=True)[0]
        penalty = torch.sum(grad**2)

        loss = loss + args.lambda_irm * penalty

        loss.backward()

        return loss_seg1.item(), loss_seg2.item()

    for i_iter in range(args.num_steps):

        loss_seg1_src = 0.0
        loss_seg2_src = 0.0

        loss_seg1_tgt = 0.0
        loss_seg2_tgt = 0.0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):

            # train with source
            loss_seg1, loss_seg2 = _forward_single(trainloader_iter)
            loss_seg1_src += loss_seg1 / args.iter_size
            loss_seg2_src += loss_seg2 / args.iter_size

            # train with target
            loss_seg1, loss_seg2 = _forward_single(targetloader_iter)
            loss_seg1_tgt += loss_seg1 / args.iter_size
            loss_seg2_tgt += loss_seg2 / args.iter_size

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1_src = {2:.3f} loss_seg2_src = {3:.3f} loss_seg1_tgt = {4:.3f} loss_seg2_tgt = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg1_src, loss_seg2_src,
                    loss_seg1_tgt, loss_seg2_tgt))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
Пример #15
0
def main():
    """Create the model and start the training.

    Trains ``Res_Deeplab`` (two prediction heads) on GTA5 with a
    segmentation loss, an adversarial loss against a single
    ``FCDiscriminator`` fed with the softmax of the summed target
    predictions, and a weight discrepancy loss between ``model.layer5`` and
    ``model.layer6``.  After ``PREHEAT_STEPS`` iterations the adversarial
    terms switch from plain BCE to ``WeightedBCEWithLogitsLoss`` using the
    map returned by ``weightmap``.

    Progress is logged every 10 iterations via ``log_message``.  Weights of
    the network and the discriminator are saved every
    ``args.save_pred_every`` iterations and after the last iteration, and
    the total training time is logged at the end.

    Fixes relative to the previous version:

    * the checkpoint branch referenced ``self.num_steps`` inside a plain
      function, raising ``NameError`` at every save; it now uses
      ``args.num_steps`` and no longer reassigns the loop variable;
    * the elapsed-time message printed the *total* number of hours next to
      the day count; hours are now reported modulo 24.
    """

    cudnn.enabled = True
    cudnn.benchmark = True

    device = torch.device("cuda" if not args.cpu else "cpu")

    random.seed(args.random_seed)

    snapshot_dir = os.path.join(args.snapshot_dir, args.experiment)
    log_dir = os.path.join(args.log_dir, args.experiment)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(snapshot_dir, exist_ok=True)

    log_file = os.path.join(log_dir, 'log.txt')

    init_log(log_file, args)

    # =============================================================================
    # INIT G
    # =============================================================================
    # NOTE(review): this checks the module constant MODEL while the weight
    # discrepancy loss below checks args.model -- confirm they always agree.
    if MODEL == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes,
                            restore_from=args.restore_from)
    model.train()
    model.to(device)

    # =============================================================================
    # INIT D
    # =============================================================================

    model_D = FCDiscriminator(num_classes=args.num_classes)

    # saved_state_dict_D = torch.load(RESTORE_FROM_D) #for retrain
    # model_D.load_state_dict(saved_state_dict_D)

    model_D.train()
    model_D.to(device)

    # DataLoaders
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=args.input_size_source,
                                              scale=True,
                                              mirror=True,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=args.input_size_target,
        scale=True,
        mirror=True,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # Optimizers
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # Losses
    bce_loss = torch.nn.BCEWithLogitsLoss()
    weighted_bce_loss = WeightedBCEWithLogitsLoss()

    interp_source = nn.Upsample(size=(args.input_size_source[1],
                                      args.input_size_source[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_target = nn.Upsample(size=(args.input_size_target[1],
                                      args.input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # Labels for Adversarial Training
    source_label = 0
    target_label = 1

    # ======================================================================================
    # Start training
    # ======================================================================================
    print('###########   TRAINING STARTED  ############')
    start = time.time()

    for i_iter in range(args.start_from_iter, args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Linearly decaying factor applied to the adversarial and weight
        # losses.  NOTE(review): uses the module constant NUM_STEPS rather
        # than args.num_steps -- confirm this is intended.
        damping = (1 - (i_iter) / NUM_STEPS)

        # ======================================================================================
        # train G
        # ======================================================================================

        # Remove Grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # Train with Source
        _, batch = next(trainloader_iter)
        images_s, labels_s, _, _ = batch
        images_s = images_s.to(device)
        pred_source1, pred_source2 = model(images_s)

        pred_source1 = interp_source(pred_source1)
        pred_source2 = interp_source(pred_source2)

        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, device) +
                    loss_calc(pred_source2, labels_s, device))
        loss_seg.backward()

        # Train with Target
        _, batch = next(targetloader_iter)
        images_t, _, _ = batch
        images_t = images_t.to(device)

        pred_target1, pred_target2 = model(images_t)

        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        weight_map = weightmap(F.softmax(pred_target1, dim=1),
                               F.softmax(pred_target2, dim=1))

        D_out = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        if i_iter > PREHEAT_STEPS:
            loss_adv = weighted_bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device),
                weight_map, Epsilon, Lambda_local)
        else:
            loss_adv = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))

        loss_adv = loss_adv * Lambda_adv * damping
        loss_adv.backward()

        # Weight Discrepancy Loss
        W5 = None
        W6 = None
        if args.model == 'ResNet':

            # Flatten and concatenate all parameters of both heads.
            for (w5, w6) in zip(model.layer5.parameters(),
                                model.layer6.parameters()):
                if W5 is None and W6 is None:
                    W5 = w5.view(-1)
                    W6 = w6.view(-1)
                else:
                    W5 = torch.cat((W5, w5.view(-1)), 0)
                    W6 = torch.cat((W6, w6.view(-1)), 0)

        loss_weight = (torch.matmul(W5, W6) /
                       (torch.norm(W5) * torch.norm(W6)) + 1
                       )  # +1 is for a positive loss
        loss_weight = loss_weight * Lambda_weight * damping * 2
        loss_weight.backward()

        # ======================================================================================
        # train D
        # ======================================================================================

        # Bring back Grads in D
        for param in model_D.parameters():
            param.requires_grad = True

        # Train with Source
        pred_source1 = pred_source1.detach()
        pred_source2 = pred_source2.detach()

        D_out_s = interp_source(
            model_D(F.softmax(pred_source1 + pred_source2, dim=1)))

        loss_D_s = bce_loss(
            D_out_s,
            torch.FloatTensor(
                D_out_s.data.size()).fill_(source_label).to(device))

        loss_D_s.backward()

        # Train with Target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        weight_map = weight_map.detach()

        D_out_t = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        if (i_iter > PREHEAT_STEPS):
            loss_D_t = weighted_bce_loss(
                D_out_t,
                torch.FloatTensor(
                    D_out_t.data.size()).fill_(target_label).to(device),
                weight_map, Epsilon, Lambda_local)
        else:
            loss_D_t = bce_loss(
                D_out_t,
                torch.FloatTensor(
                    D_out_t.data.size()).fill_(target_label).to(device))

        loss_D_t.backward()

        optimizer.step()
        optimizer_D.step()

        if (i_iter) % 10 == 0:
            log_message(
                'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_adv = {3:.4f}, loss_weight = {4:.4f}, loss_D_s = {5:.4f} loss_D_t = {6:.4f}'
                .format(i_iter, args.num_steps, loss_seg, loss_adv,
                        loss_weight, loss_D_s, loss_D_t), log_file)

        is_last_iter = i_iter == args.num_steps - 1
        if (i_iter % args.save_pred_every == 0
                and i_iter != 0) or is_last_iter:
            # The final checkpoint is named after the total number of steps
            # rather than the zero-based index of the last iteration.
            save_iter = i_iter + 1 if is_last_iter else i_iter
            print('saving weights...')
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, 'GTA5_' + str(save_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'GTA5_' + str(save_iter) + '_D.pth'))

    end = time.time()
    # Break the elapsed time into days / hours / minutes / seconds so that
    # each field only holds the remainder of the larger unit.
    elapsed = int(end - start)
    days, remainder = divmod(elapsed, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    log_message(
        'Total training time: {} days, {} hours, {} min, {} sec '.format(
            days, hours, minutes, seconds), log_file)
    print('### Experiment: ' + args.experiment + ' finished ###')
Пример #16
0
def main():
    """Create the model and start the training.

    Builds the segmentation network (generator) and one discriminator,
    optionally resumes both from a checkpoint, and then runs adversarial
    training on a labelled source set and an unlabelled target set.

    In addition to the segmentation and adversarial losses, pooled source
    features are stored in per-class ring buffers (``src_cls_features`` for
    ``BG_LABEL`` classes, ``src_ins_features`` for ``FG_LABEL`` classes) and
    pooled target features are pulled towards their nearest stored entry
    (L1 distance) through ``loss_cls`` / ``loss_ins``.

    Every ``args.save_pred_every`` iterations the model is evaluated on the
    validation list, the mIoU is computed and checkpoints are written to
    ``args.snapshot_dir``. Training stops once ``args.num_steps_stop`` is
    reached.

    All configuration is read from the module-level ``args``; nothing is
    returned.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")
    cudnn.benchmark = True
    cudnn.enabled = True

    # Sizes are passed as "w,h" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    Iter = 0  # first iteration to run; overwritten when resuming
    bestIoU = 0

    # Create network
    # init G
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        if args.continue_train:
            # Strip a leading 'module.' from every key when present.
            if list(saved_state_dict.keys())[0].split('.')[0] == 'module':
                # Iterate over a snapshot of the keys: the loop re-keys the
                # dict in place, and mutating it while iterating its live
                # key view raises RuntimeError.
                for key in list(saved_state_dict.keys()):
                    saved_state_dict['.'.join(
                        key.split('.')[1:])] = saved_state_dict.pop(key)
            model.load_state_dict(saved_state_dict)
        else:
            # Copy pretrained weights, dropping the first name component of
            # each key; 'layer5' entries are skipped when num_classes == 19.
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                i_parts = i.split('.')
                if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
            model.load_state_dict(new_params)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes).to(device)

    if args.continue_train:
        # The discriminator checkpoint sits next to the generator one, with
        # '_D' appended before the extension.
        model_weights_path = args.restore_from
        temp = model_weights_path.split('.')
        temp[-2] = temp[-2] + '_D'
        model_D_weights_path = '.'.join(temp)
        model_D.load_state_dict(torch.load(model_D_weights_path))
        # Recover the iteration from the last 9 characters before the
        # extension, which must look like '<prefix>_<iter>'.
        # NOTE(review): assumes names such as 'GTA5_<iter>.pth' whose
        # '_<iter>' part fits in those 9 characters - confirm.
        temp = model_weights_path.split('.')
        temp = temp[-2][-9:]
        Iter = int(temp.split('_')[1]) + 1

    model.train()
    model.to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # init data loader
    # The source dataset class is selected from the last path component of
    # args.data_dir; any other directory name leaves `trainset` undefined.
    if args.data_dir.split('/')[-1] == 'gta5_deeplab':
        trainset = GTA5DataSet(args.data_dir,
                               args.data_list,
                               max_iters=args.num_steps * args.iter_size *
                               args.batch_size,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    elif args.data_dir.split('/')[-1] == 'syn_deeplab':
        trainset = synthiaDataSet(args.data_dir,
                                  args.data_list,
                                  max_iters=args.num_steps * args.iter_size *
                                  args.batch_size,
                                  crop_size=input_size,
                                  scale=args.random_scale,
                                  mirror=args.random_mirror,
                                  mean=IMG_MEAN)

    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # init optimizer
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # Mixed-precision setup for both networks.
    # NOTE(review): `amp` is presumably NVIDIA apex amp - confirm.
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level="O2",
                                      keep_batchnorm_fp32=True,
                                      loss_scale="dynamic")

    model_D, optimizer_D = amp.initialize(model_D,
                                          optimizer_D,
                                          opt_level="O2",
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")

    # init loss
    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)
    L1_loss = torch.nn.L1Loss(reduction='none')

    # nn.Upsample takes (height, width), hence the swapped indices.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    test_interp = nn.Upsample(size=(1024, 2048),
                              mode='bilinear',
                              align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # init prototype
    # One ring buffer of 2048-d features per class, plus a write counter
    # per class; instance buffers hold 10x as many entries.
    num_prototype = args.num_prototype
    num_ins = args.num_prototype * 10
    src_cls_features = torch.zeros([len(BG_LABEL), num_prototype, 2048],
                                   dtype=torch.float32).to(device)
    src_cls_ptr = np.zeros(len(BG_LABEL), dtype=np.uint64)
    src_ins_features = torch.zeros([len(FG_LABEL), num_ins, 2048],
                                   dtype=torch.float32).to(device)
    src_ins_ptr = np.zeros(len(FG_LABEL), dtype=np.uint64)

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    # start training
    for i_iter in range(Iter, args.num_steps):

        loss_seg_value = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cls_value = 0
        loss_ins_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Gradients are accumulated over iter_size sub-iterations before
        # the optimizers step.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D

            for param in model_D.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            src_feature, pred = model(images)
            pred_softmax = F.softmax(pred, dim=1)
            pred_idx = torch.argmax(pred_softmax, dim=1)

            # Resize the labels to the prediction resolution and keep only
            # the pixels the network classified correctly; every other
            # pixel becomes 255.
            right_label = F.interpolate(labels.unsqueeze(0).float(),
                                        (pred_idx.size(1), pred_idx.size(2)),
                                        mode='nearest').squeeze(0).long()
            right_label[right_label != pred_idx] = 255

            # Store one pooled feature per BG_LABEL class present, but only
            # if model.layer6 assigns that pooled feature to the same class.
            for ii in range(len(BG_LABEL)):
                cls_idx = BG_LABEL[ii]
                mask = right_label == cls_idx
                if torch.sum(mask) == 0:
                    continue
                feature = global_avg_pool(src_feature, mask.float())
                if cls_idx != torch.argmax(
                        torch.squeeze(model.layer6(
                            feature.half()).float())).item():
                    continue
                # Ring-buffer write: the counter modulo capacity picks the
                # slot, so old entries are overwritten.
                src_cls_features[ii,
                                 int(src_cls_ptr[ii] %
                                     num_prototype), :] = torch.squeeze(
                                         feature).clone().detach()
                src_cls_ptr[ii] += 1

            # Same for FG_LABEL classes, per segment returned by seg_label,
            # using up to the 10 segments with the largest pixel counts.
            seg_ins = seg_label(right_label.squeeze())
            for ii in range(len(FG_LABEL)):
                cls_idx = FG_LABEL[ii]
                segmask, pixelnum = seg_ins[ii]
                if len(pixelnum) == 0:
                    continue
                sortmax = np.argsort(pixelnum)[::-1]
                for i in range(min(10, len(sortmax))):
                    # Segment ids in segmask are 1-based relative to pixelnum.
                    mask = segmask == (sortmax[i] + 1)
                    feature = global_avg_pool(src_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    src_ins_features[ii, int(src_ins_ptr[ii] %
                                             num_ins), :] = torch.squeeze(
                                                 feature).clone().detach()
                    src_ins_ptr[ii] += 1

            pred = interp(pred)
            loss_seg = seg_loss(pred, labels)
            loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value += loss_seg.item() / args.iter_size

            # train with target

            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)

            trg_feature, pred_target = model(images)

            pred_target_softmax = F.softmax(pred_target, dim=1)
            pred_target_idx = torch.argmax(pred_target_softmax, dim=1)

            loss_cls = torch.zeros(1).to(device)
            loss_ins = torch.zeros(1).to(device)
            if i_iter > 0:
                for ii in range(len(BG_LABEL)):
                    cls_idx = BG_LABEL[ii]
                    # Skip this class until more than num_prototype source
                    # features have been written to its buffer.
                    if src_cls_ptr[ii] / num_prototype <= 1:
                        continue
                    mask = pred_target_idx == cls_idx
                    feature = global_avg_pool(trg_feature, mask.float())
                    if cls_idx != torch.argmax(
                            torch.squeeze(
                                model.layer6(feature.half()).float())).item():
                        continue
                    # Mean L1 distance to every stored entry; the smallest
                    # one is added to the loss.
                    ext_feature = feature.squeeze().expand(num_prototype, 2048)
                    loss_cls += torch.min(
                        torch.sum(L1_loss(ext_feature,
                                          src_cls_features[ii, :, :]),
                                  dim=1) / 2048.)

                seg_ins = seg_label(pred_target_idx.squeeze())
                for ii in range(len(FG_LABEL)):
                    cls_idx = FG_LABEL[ii]
                    if src_ins_ptr[ii] / num_ins <= 1:
                        continue
                    segmask, pixelnum = seg_ins[ii]
                    if len(pixelnum) == 0:
                        continue
                    sortmax = np.argsort(pixelnum)[::-1]
                    for i in range(min(10, len(sortmax))):
                        mask = segmask == (sortmax[i] + 1)
                        feature = global_avg_pool(trg_feature, mask.float())
                        feature = feature.squeeze().expand(num_ins, 2048)
                        # Averaged over the number of segments considered.
                        loss_ins += torch.min(
                            torch.sum(L1_loss(feature,
                                              src_ins_features[ii, :, :]),
                                      dim=1) / 2048.) / min(10, len(sortmax))

            pred_target = interp_target(pred_target)

            # Generator adversarial loss: target predictions are scored
            # against the source label.
            D_out = model_D(F.softmax(pred_target, dim=1))
            loss_adv_target = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target * loss_adv_target + args.lambda_adv_cls * loss_cls + args.lambda_adv_ins * loss_ins
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value += loss_adv_target.item() / args.iter_size

            # train D

            # bring back requires_grad

            for param in model_D.parameters():
                param.requires_grad = True

            # train with source
            # Detach so the discriminator loss does not reach the generator.
            pred = pred.detach()
            D_out = model_D(F.softmax(pred, dim=1))

            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

            # train with target
            pred_target = pred_target.detach()
            D_out = model_D(F.softmax(pred_target, dim=1))

            loss_D = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(target_label).to(device))
            loss_D = loss_D / args.iter_size / 2
            amp_backward(loss_D, optimizer_D)
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg': loss_seg_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        # loss_cls / loss_ins printed here are those of the last
        # sub-iteration only.
        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv = {3:.3f} loss_D = {4:.3f} loss_cls = {5:.3f} loss_ins = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_target_value, loss_D_value, loss_cls.item(),
                    loss_ins.item()))

        # Early stop: save the final weights and leave the training loop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        # Periodic evaluation and snapshot.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            if not os.path.exists(args.save):
                os.makedirs(args.save)
            testloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target_test,
                crop_size=(1024, 512),
                mean=IMG_MEAN,
                scale=False,
                mirror=False,
                set='val'),
                                         batch_size=1,
                                         shuffle=False,
                                         pin_memory=True)
            model.eval()
            for index, batch in enumerate(testloader):
                image, _, name = batch
                with torch.no_grad():
                    output1, output2 = model(Variable(image).to(device))
                # Upsample the second output, take the per-pixel argmax and
                # save it as an 8-bit label image under args.save.
                output = test_interp(output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output = Image.fromarray(output)
                name = name[0].split('/')[-1]
                output.save('%s/%s' % (args.save, name))
            mIoUs = compute_mIoU(osp.join(args.data_dir_target, 'gtFine/val'),
                                 args.save, 'dataset/cityscapes_list')
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            # Keep a separate copy of the best-scoring weights so far.
            if mIoU > bestIoU:
                bestIoU = mIoU
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5_D.pth'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
            model.train()

    if args.tensorboard:
        writer.close()
Пример #17
0
def main():
    """Create the model and start the training.

    Trains a two-output segmentation network against two discriminators
    (one per output) on a labelled source set and an unlabelled target set.

    When the module-level ``RESTART`` flag is set, the argument values saved
    for ``RESTART_ITER`` are reloaded from ``RESTART_FROM`` and the generator
    and both discriminators are restored from ``args.restart_from``;
    otherwise the generator is initialised from ``args.restore_from``.

    Every ``args.save_pred_every`` iterations the three networks are saved to
    ``args.snapshot_dir`` together with a JSON dump of the arguments (with
    the current learning rates and iteration), which is what a later restart
    reads back. Training stops once ``args.num_steps_stop`` is reached.

    All configuration is read from the module-level ``args``; nothing is
    returned.
    """
    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # vars() returns the namespace's own dict, so writes to args_dict are
    # visible through args.
    args_dict = vars(args)
    import json

    ###### load args for restart ######
    if RESTART:
        # osp.join copes with RESTART_FROM given with or without a trailing
        # separator and matches how the file is written at snapshot time.
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        # Every current argument is overwritten by its saved value.
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes are passed as "w,h" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    #### restore model_D1, D2 and model
    if RESTART:
        # model parameters
        restart_from_model = osp.join(args.restart_from,
                                      'GTA5_{}.pth'.format(RESTART_ITER))
        saved_state_dict = torch.load(restart_from_model)
        model.load_state_dict(saved_state_dict)

        # model_D1 parameters
        restart_from_D1 = osp.join(args.restart_from,
                                   'GTA5_{}_D1.pth'.format(RESTART_ITER))
        saved_state_dict = torch.load(restart_from_D1)
        model_D1.load_state_dict(saved_state_dict)

        # model_D2 parameters
        restart_from_D2 = osp.join(args.restart_from,
                                   'GTA5_{}_D2.pth'.format(RESTART_ITER))
        saved_state_dict = torch.load(restart_from_D2)
        model_D2.load_state_dict(saved_state_dict)

    #### model_D1, D2 are randomly initialized, model is pre-trained ResNet on ImageNet
    else:
        # model parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy pretrained weights, dropping the first name component of each
        # key (e.g. 'Scale.layer5.conv2d_list.3.weight'); 'layer5' entries
        # are skipped when num_classes == 19.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    model_D1.train()
    model_D1.to(device)

    model_D2.train()
    model_D2.to(device)

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)
    # args.snapshot_dir = generate_snapshot_name()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Adversarial criterion; any other args.gan value leaves bce_loss
    # undefined.
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample takes (height, width), hence the swapped indices.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.start_steps, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-iterations before
        # the optimizers step.
        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = targetloader_iter.__next__()
            images, _, _ = batch
            images = images.to(device)
            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)
            # dim=1 is stated explicitly; it is the dimension the deprecated
            # implicit form selects for 4-D input.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Generator adversarial loss: target predictions are scored
            # against the source label.
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            # Detach so the discriminator losses do not reach the generator.
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            # Halved because each discriminator sees a source and a target
            # batch per sub-iteration.
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        scalar_info = {
            'loss_seg1': loss_seg_value1,
            'loss_seg2': loss_seg_value2,
            'loss_adv_target1': loss_adv_target_value1,
            'loss_adv_target2': loss_adv_target_value2,
            'loss_D1': loss_D_value1,
            'loss_D2': loss_D_value2,
        }

        if i_iter % 10 == 0:
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        # Early stop: save the final weights and leave the training loop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))

            ###### also record latest saved iteration #######
            # Persist the current learning rates and iteration so that a
            # restart can pick them up from the JSON file.
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['learning_rate_D'] = optimizer_D1.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            args_dict_file = osp.join(args.snapshot_dir,
                                      'args_dict_{}.json'.format(i_iter))
            with open(args_dict_file, 'w') as f:
                json.dump(args_dict, f)

            ###### also record latest saved iteration #######

    writer.close()
Пример #18
0
def main():
    """Create the model and start the training.

    Trains the segmentation network with three signals:

    * a supervised segmentation loss on source (``GTA5DataSet``) batches,
    * an adversarial loss from the output-space discriminator ``model_D2``
      on target (``cityscapesDataSet``) batches, and
    * from iteration ``cls_begin_iter`` onwards, an adversarial loss from one
      extra discriminator per class (``model_clsD``), restricted to target
      pixels predicted as that class with a probability at or above a
      per-class threshold.

    Every model/optimizer pair is wrapped with ``amp.initialize`` and every
    backward pass goes through ``amp_backward``. Checkpoints are written to
    ``args.snapshot_dir``; scalars are sent to TensorBoard when
    ``args.tensorboard`` is set.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes are given as "width,height" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight";
            # the leading component is dropped to match this model's keys.
            i_parts = i.split('.')
            # With 19 classes the 'layer5' weights are not copied from the
            # checkpoint, so that layer keeps its fresh initialisation.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    else:
        # Without this the function would fail later on an unbound `model`.
        raise ValueError('unsupported model: {}'.format(args.model))

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2.train()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    cityset = cityscapesDataSet(args.data_dir_target,
                                args.data_list_target,
                                max_iters=args.num_steps * args.iter_size *
                                args.batch_size,
                                crop_size=input_size_target,
                                scale=False,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN,
                                set=args.set)
    targetloader = data.DataLoader(cityset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # init cls D: one discriminator (with its own optimizer) per class
    model_clsD = []
    optimizer_clsD = []
    for i in range(args.num_classes):
        model_temp = FCDiscriminatorCLS(
            num_classes=args.num_classes).to(device).train()
        optimizer_temp = optim.Adam(model_temp.parameters(),
                                    lr=args.learning_rate_D,
                                    betas=(0.9, 0.99))
        optimizer_temp.zero_grad()
        model_temp, optimizer_temp = amp.initialize(model_temp,
                                                    optimizer_temp,
                                                    opt_level="O1",
                                                    keep_batchnorm_fp32=None,
                                                    loss_scale="dynamic")
        model_clsD.append(model_temp)
        optimizer_clsD.append(optimizer_temp)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level="O1",
                                      keep_batchnorm_fp32=None,
                                      loss_scale="dynamic")

    model_D2, optimizer_D2 = amp.initialize(model_D2,
                                            optimizer_D2,
                                            opt_level="O1",
                                            keep_batchnorm_fp32=None,
                                            loss_scale="dynamic")

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample takes (height, width), hence the swapped indices.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # Iteration from which the per-class discriminators take part.
    cls_begin_iter = 10000

    # Cache of the latest predicted class / probability for every target
    # image; used to derive the per-class confidence thresholds.
    # NOTE(review): 2975 is hard-coded; presumably the size of the target
    # training list -- confirm against args.data_list_target.
    num_target_imgs = 2975
    predicted_label = np.zeros(
        (num_target_imgs, 1, input_size_target[1], input_size_target[0]),
        dtype=np.uint8)
    predicted_prob = np.zeros(
        (num_target_imgs, 1, input_size_target[1], input_size_target[0]),
        dtype=np.float16)
    name2idxmap = {}
    for i in range(num_target_imgs):
        name2idxmap[cityset.files[i]['name']] = i
    thres = []

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        # Running values used only for logging.
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_cls_adv_value = 0
        loss_cls_D_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        if i_iter >= cls_begin_iter:
            for i in range(args.num_classes):
                optimizer_clsD[i].zero_grad()
                adjust_learning_rate_D(optimizer_clsD[i], i_iter)

        for sub_i in range(args.iter_size):

            # Per-class adversarial loss of THIS sub-iteration only. It is
            # reset here so a later sub-iteration never backpropagates
            # through a graph that has already been consumed.
            loss_cls_adv = 0

            # train G

            # don't accumulate grads in D
            for param in model_D2.parameters():
                param.requires_grad = False

            if i_iter >= cls_begin_iter:
                for i in range(args.num_classes):
                    for param in model_clsD[i].parameters():
                        param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            _, pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, name = batch
            images = images.to(device)
            # Only the first name of the batch is used for the cache index.
            # NOTE(review): this presumably assumes batch_size == 1 -- confirm.
            name = name[0]
            img_idx = name2idxmap[name]

            _, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            pred_target_score = F.softmax(pred_target2, dim=1)
            D_out2 = model_D2(pred_target_score)
            # Target predictions are labelled as "source" so the generator is
            # pushed to fool the discriminator.
            loss_adv_target2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))

            target_pred_prob, target_pred_cls = torch.max(pred_target_score,
                                                          dim=1)
            predicted_label[img_idx,
                            ...] = target_pred_cls.cpu().data.numpy().astype(
                                np.uint8)
            predicted_prob[img_idx,
                           ...] = target_pred_prob.cpu().data.numpy().astype(
                               np.float16)

            # Re-estimate the per-class thresholds every 5000 iterations from
            # the cached target predictions.
            if i_iter >= cls_begin_iter and i_iter % 5000 == 0:
                thres = []
                for i in range(args.num_classes):
                    x = predicted_prob[predicted_label == i]
                    if len(x) == 0:
                        thres.append(0)
                        continue
                    x = np.sort(x)
                    # Builtin int: the `np.int` alias was removed in
                    # NumPy 1.24.
                    thres.append(x[int(np.round(len(x) * 0.5))])
                print(thres)
                thres = np.array(thres)
                thres[thres > 0.9] = 0.9
                np.save(osp.join(args.snapshot_dir, 'predicted_label'),
                        predicted_label)
                np.save(osp.join(args.snapshot_dir, 'predicted_prob'),
                        predicted_prob)

            if i_iter >= cls_begin_iter:
                target_pred_cls = target_pred_cls.long().detach()
                for i in range(args.num_classes):
                    # Pixels predicted as class i whose *probability* reaches
                    # the class threshold (the threshold is a probability,
                    # so it must not be compared with the class index).
                    cls_mask = (target_pred_cls == i) * (
                        target_pred_prob >= float(thres[i]))
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = target_pred_cls.clone()
                    cls_gt[~cls_mask] = 255  # ignored by seg_loss
                    cls_gt[cls_mask] = source_label
                    cls_out = model_clsD[i](pred_target_score)
                    loss_cls_adv += seg_loss(cls_out, cls_gt)
                # loss_cls_adv stays the int 0 when no class had confident
                # pixels; only a tensor has .item().
                if torch.is_tensor(loss_cls_adv):
                    loss_cls_adv_value += (loss_cls_adv.item() /
                                           args.iter_size)

            loss = args.lambda_adv_target2 * loss_adv_target2 + LAMBDA_CLS_ADV * loss_cls_adv
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target2 = pred_target2.detach()
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D2 = bce_loss(
                D_out2,
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            if i_iter >= cls_begin_iter:
                for i in range(args.num_classes):
                    for param in model_clsD[i].parameters():
                        param.requires_grad = True

                # Source side: pixels of class i that were predicted
                # correctly are labelled "source".
                pred_source_score = F.softmax(pred2, dim=1)
                source_pred_prob, source_pred_cls = torch.max(
                    pred_source_score, dim=1)
                source_pred_cls = source_pred_cls.long().detach()
                for i in range(args.num_classes):
                    cls_mask = (source_pred_cls == i) * (labels == i)
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = source_pred_cls.clone()
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = source_label
                    cls_out = model_clsD[i](pred_source_score)
                    loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                    amp_backward(loss_cls_D, optimizer_clsD[i])
                    loss_cls_D_value += loss_cls_D.item()

                # Target side: confident pixels of class i are labelled
                # "target".
                pred_target_score = F.softmax(pred_target2, dim=1)
                target_pred_prob, target_pred_cls = torch.max(
                    pred_target_score, dim=1)
                target_pred_cls = target_pred_cls.long().detach()
                for i in range(args.num_classes):
                    cls_mask = (target_pred_cls == i) * (
                        target_pred_prob >= float(thres[i]))
                    if torch.sum(cls_mask) == 0:
                        continue
                    cls_gt = target_pred_cls.clone()
                    cls_gt[~cls_mask] = 255
                    cls_gt[cls_mask] = target_label
                    cls_out = model_clsD[i](pred_target_score)
                    # Only the discriminator loss is computed here; the
                    # generator-side loss_cls_adv is not touched in this phase.
                    loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                    amp_backward(loss_cls_D, optimizer_clsD[i])
                    loss_cls_D_value += loss_cls_D.item()

        optimizer.step()
        optimizer_D2.step()
        if i_iter >= cls_begin_iter:
            for i in range(args.num_classes):
                optimizer_clsD[i].step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg2': loss_seg_value2,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D2': loss_D_value2,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg2 = {2:.3f}, loss_adv2 = {3:.3f} loss_D2 = {4:.3f} loss_cls_adv = {5:.3f} loss_cls_D = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value2,
                    loss_adv_target_value2, loss_D_value2, loss_cls_adv_value,
                    loss_cls_D_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            for i in range(args.num_classes):
                # The class index is part of the file name so each per-class
                # discriminator gets its own file instead of all of them
                # overwriting a single one.
                torch.save(
                    model_clsD[i].state_dict(),
                    osp.join(
                        args.snapshot_dir, 'GTA5_' +
                        str(args.num_steps_stop) + '_clsD' + str(i) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
            # Per-class discriminator snapshots are not tagged with the
            # iteration, so each snapshot replaces the previous one.
            for i in range(args.num_classes):
                torch.save(
                    model_clsD[i].state_dict(),
                    osp.join(args.snapshot_dir, 'GTA5_clsD' + str(i) + '.pth'))

    if args.tensorboard:
        writer.close()
Пример #19
0
def main():
    """Create the model and run source-only segmentation training.

    Resolves the snapshot directory (reusing ``RESTART_FROM`` when
    ``RESTART`` is set), optionally reloads the saved argument values and
    model weights of iteration ``RESTART_ITER``, then trains the two-output
    ``DeeplabMulti`` model on ``GTA5DataSet`` batches with a cross-entropy
    loss on both outputs. Loss scalars go to TensorBoard every 10 iterations,
    and model weights plus a JSON dump of the arguments are written every
    ``args.save_pred_every`` iterations.
    """
    import json

    args.snapshot_dir = (RESTART_FROM
                         if RESTART else generate_snapshot_name(args))

    # vars() exposes the namespace's own dict, so writes below also update args.
    args_dict = vars(args)

    # ---- reload the argument values stored at the restart iteration ----
    if RESTART:
        saved_args_path = args.snapshot_dir + '/args_dict_{}.json'.format(
            RESTART_ITER)
        with open(saved_args_path) as saved_args_fh:
            saved_args = json.load(saved_args_fh)
        for arg_name in args_dict:
            args_dict[arg_name] = saved_args[arg_name]

    device = torch.device("cpu" if args.cpu else "cuda")

    # Sizes are passed as "width,height" strings.
    src_w, src_h = map(int, args.input_size.split(','))
    input_size = (src_w, src_h)

    tgt_w, tgt_h = map(int, args.input_size_target.split(','))
    input_size_target = (tgt_w, tgt_h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)

    # ---- restore model weights ----
    if RESTART:
        # Resume from the checkpoint saved at RESTART_ITER.
        checkpoint_path = args.restart_from + 'GTA5_{}.pth'.format(
            RESTART_ITER)
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        # Start from pretrained weights (remote URL or local file).
        if args.restore_from.startswith('http'):
            pretrained = model_zoo.load_url(args.restore_from)
        else:
            pretrained = torch.load(args.restore_from)

        merged_state = model.state_dict().copy()
        for key in pretrained:
            # Keys look like "Scale.layer5.conv2d_list.3.weight"; the leading
            # component is dropped to match this model's parameter names.
            key_parts = key.split('.')
            if args.num_classes == 19 and key_parts[1] == 'layer5':
                # With 19 classes the 'layer5' weights are not copied.
                continue
            merged_state['.'.join(key_parts[1:])] = pretrained[key]
        model.load_state_dict(merged_state)

    model.train()
    model.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    source_dataset = GTA5DataSet(args.data_dir,
                                 args.data_list,
                                 max_iters=args.num_steps * args.iter_size *
                                 args.batch_size,
                                 crop_size=input_size,
                                 scale=args.random_scale,
                                 mirror=args.random_mirror,
                                 mean=IMG_MEAN)
    trainloader = data.DataLoader(source_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    # model.optim_parameters(args) supplies the model-specific lr groups.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample expects (height, width).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # ---- tensorboard ----
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_seg_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        # Gradients are accumulated over iter_size source batches.
        for _ in range(args.iter_size):
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            # Two outputs, both upsampled to the input resolution.
            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = seg_loss(pred1, labels)
            loss_seg2 = seg_loss(pred2, labels)

            # Normalise so the accumulated gradient matches one full batch.
            total_loss = (loss_seg2 +
                          args.lambda_seg * loss_seg1) / args.iter_size
            total_loss.backward()

            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

        optimizer.step()

        if i_iter % 10 == 0:
            writer.add_scalar('loss_seg1', loss_seg_value1, i_iter)
            writer.add_scalar('loss_seg2', loss_seg_value2, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2))

        # Final checkpoint, then stop.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            final_path = osp.join(
                args.snapshot_dir,
                'GTA5_noadapt_' + str(args.num_steps_stop) + '.pth')
            torch.save(model.state_dict(), final_path)
            break

        # Periodic snapshot of the weights and of the arguments.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            snapshot_path = osp.join(args.snapshot_dir,
                                     'GTA5_' + str(i_iter) + '.pth')
            torch.save(model.state_dict(), snapshot_path)

            # Record the current lr and iteration next to the weights.
            args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
            args_dict['start_steps'] = i_iter

            snapshot_args_path = (args.snapshot_dir +
                                  '/args_dict_{}.json'.format(i_iter))
            with open(snapshot_args_path, 'w') as snapshot_args_fh:
                json.dump(args_dict, snapshot_args_fh)

    writer.close()
Пример #20
0
def main():
    """Create the model and start the training.

    Trains ``Res_Deeplab`` on ``GTA5DataSet`` batches only: both model
    outputs are upsampled to the source input size and their ``loss_calc``
    losses are summed. Progress is logged every 10 iterations through
    ``log_message``; weights are saved to
    ``<args.snapshot_dir>/<args.experiment>`` every ``args.save_pred_every``
    iterations and at the last iteration. The total training time is logged
    at the end.

    Raises:
        ValueError: if the module-level ``MODEL`` is not ``'ResNet'``.
    """

    cudnn.enabled = True
    cudnn.benchmark = True

    device = torch.device("cuda" if not args.cpu else "cpu")

    snapshot_dir = os.path.join(args.snapshot_dir, args.experiment)
    os.makedirs(snapshot_dir, exist_ok=True)

    log_file = os.path.join(args.log_dir, '%s.txt' % args.experiment)
    init_log(log_file, args)

    # =============================================================================
    # INIT G
    # =============================================================================
    if MODEL == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes,
                            restore_from=args.restore_from)
    else:
        # Without this the function would fail below on an unbound `model`.
        raise ValueError('unsupported MODEL: {}'.format(MODEL))
    model.train()
    model.to(device)

    # DataLoaders
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=args.input_size_source,
                                              scale=True,
                                              mirror=True,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Alternative loader (kept for reference):
    #   trainloader = data.DataLoader(cityscapesDataSetLabel(args.data_dir_target, './dataset/cityscapes_list/info.json', args.data_list_target,args.data_list_label_target,
    #                                                   max_iters=args.num_steps * args.iter_size * args.batch_size, crop_size=args.input_size_target,
    #                                                   mean=IMG_MEAN, set=args.set), batch_size=args.batch_size, shuffle=True,num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Optimizers
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # nn.Upsample takes (height, width), hence the swapped indices.
    interp = nn.Upsample(size=(args.input_size_source[1],
                               args.input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    #interp = nn.Upsample(size=(args.input_size_target[1], args.input_size_target[0]), mode='bilinear', align_corners=True)

    # ======================================================================================
    # Start training
    # ======================================================================================
    log_message('###########   TRAINING STARTED  ############', log_file)
    start = time.time()

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        # ======================================================================================
        # train G
        # ======================================================================================

        # Train with Source
        _, batch = next(trainloader_iter)
        # The batch is unpacked as an (images, labels) pair; labels are
        # handed to loss_calc together with the device.
        images_s, labels_s = batch
        images_s = images_s.to(device)

        pred_source1, pred_source2 = model(images_s)

        pred_source1 = interp(pred_source1)
        pred_source2 = interp(pred_source2)
        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, device) +
                    loss_calc(pred_source2, labels_s, device))
        loss_seg.backward()

        optimizer.step()

        if i_iter % 10 == 0:
            log_message(
                'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f}'.format(
                    i_iter, args.num_steps - 1, loss_seg), log_file)

        if (i_iter % args.save_pred_every == 0
                and i_iter != 0) or i_iter == args.num_steps - 1:
            print('saving weights...')
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))

    end = time.time()
    # Break the elapsed time down so that every field is the remainder of the
    # next larger unit (hours are hours within the day, not total hours).
    elapsed = int(end - start)
    days, remainder = divmod(elapsed, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    log_message(
        'Total training time: {} days, {} hours, {} min, {} sec '.format(
            days, hours, minutes, seconds), log_file)
    print('### Experiment: ' + args.experiment + ' Finished ###')
Пример #21
0
def main():
    """Create the model and start the training.

    Trains a ``Res_Deeplab`` segmentation network on a labelled source
    dataset (``GTA5DataSet`` when ``args.source == 'GTA5'``, otherwise
    ``SYNTHIADataSet``) and adversarially aligns its predictions on
    Cityscapes images, whose labels are not used, through a single
    ``FCDiscriminator``.

    Each iteration accumulates, in this order, the gradients of:

    1. the segmentation loss of both classifier heads on a source batch;
    2. the adversarial loss on a target batch (plain BCE up to
       ``PREHEAT_STEPS``, the weighted BCE afterwards);
    3. the weight-discrepancy loss between ``model.layer5`` and
       ``model.layer6`` (only when ``args.model == 'ResNet'``);
    4. the discriminator losses on detached source and target predictions;

    and then steps both optimizers.  Losses are printed and appended to
    ``loss.txt`` inside ``args.snapshot_dir``; snapshots of both networks are
    written every ``args.save_pred_every`` iterations and once more when
    ``args.num_steps_stop`` is reached, which also ends training.

    All configuration is read from the module-level ``args`` and from
    module-level constants (``NUM_STEPS``, ``PREHEAT_STEPS``, ``Lambda_adv``,
    ``Lambda_weight``, ``Lambda_local``, ``Epsilon``, ``IMG_MEAN``).
    """

    h, w = map(int, args.input_size_source.split(','))
    input_size_source = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True

    # Create Network
    model = Res_Deeplab(num_classes=args.num_classes)
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    if args.restore_from.startswith('./mo'):
        # Checkpoints under this prefix carry one extra leading component in
        # every key; strip it.  With 19 classes the stored 'layer5' weights
        # are skipped so that head keeps the model's own initialisation.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # Init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    # To resume a previous run, load the discriminator weights here:
    #     model_D.load_state_dict(torch.load(RESTORE_FROM_D))

    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Both source datasets take the same constructor arguments, so only the
    # dataset class depends on args.source.
    source_dataset_cls = (GTA5DataSet
                          if args.source == 'GTA5' else SYNTHIADataSet)
    trainloader = data.DataLoader(source_dataset_cls(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_source,
        scale=True,
        mirror=True,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=True,
        mirror=True,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    weighted_bce_loss = WeightedBCEWithLogitsLoss()

    interp_source = nn.Upsample(size=(input_size_source[1],
                                      input_size_source[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # Labels for Adversarial Training
    source_label = 0
    target_label = 1

    def domain_target(d_out, domain_label):
        """Return a tensor like ``d_out`` filled with ``domain_label``."""
        return torch.full_like(d_out, domain_label)

    loss_log_path = osp.join(args.snapshot_dir, 'loss.txt')

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Linearly decays the adversarial and weight losses over training.
        # NOTE(review): uses the module-level NUM_STEPS, not args.num_steps --
        # confirm the two are meant to agree.
        damping = (1 - i_iter / NUM_STEPS)

        #======================================================================================
        # train G
        #======================================================================================

        # Remove Grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # Train with Source
        _, batch = next(trainloader_iter)
        images_s, labels_s, _, _, _ = batch
        images_s = images_s.cuda(args.gpu)
        pred_source1, pred_source2 = model(images_s)
        pred_source1 = interp_source(pred_source1)
        pred_source2 = interp_source(pred_source2)

        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, args.gpu) +
                    loss_calc(pred_source2, labels_s, args.gpu))
        loss_seg.backward()

        # Train with Target
        _, batch = next(targetloader_iter)
        images_t, _, _, _ = batch
        images_t = images_t.cuda(args.gpu)

        pred_target1, pred_target2 = model(images_t)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        weight_map = weightmap(F.softmax(pred_target1, dim=1),
                               F.softmax(pred_target2, dim=1))

        D_out = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss: the generator is rewarded when D labels
        # target predictions as source.
        if (i_iter > PREHEAT_STEPS):
            loss_adv = weighted_bce_loss(D_out,
                                         domain_target(D_out, source_label),
                                         weight_map, Epsilon, Lambda_local)
        else:
            loss_adv = bce_loss(D_out, domain_target(D_out, source_label))

        loss_adv = loss_adv * Lambda_adv * damping
        loss_adv.backward()

        # Weight Discrepancy Loss: cosine similarity between the flattened
        # parameters of the two classifier heads.  Only defined for the ResNet
        # model; for any other model it is skipped and logged as 0.
        if args.model == 'ResNet':
            head_params = list(
                zip(model.layer5.parameters(), model.layer6.parameters()))
            W5 = torch.cat([w5.view(-1) for w5, _ in head_params], 0)
            W6 = torch.cat([w6.view(-1) for _, w6 in head_params], 0)

            loss_weight = (torch.matmul(W5, W6) /
                           (torch.norm(W5) * torch.norm(W6)) + 1
                           )  # +1 is for a positive loss
            loss_weight = loss_weight * Lambda_weight * damping * 2
            loss_weight.backward()
        else:
            loss_weight = 0.0

        #======================================================================================
        # train D
        #======================================================================================

        # Bring back Grads in D
        for param in model_D.parameters():
            param.requires_grad = True

        # Train with Source (detached, so no gradient reaches the generator)
        pred_source1 = pred_source1.detach()
        pred_source2 = pred_source2.detach()

        D_out_s = interp_source(
            model_D(F.softmax(pred_source1 + pred_source2, dim=1)))

        loss_D_s = bce_loss(D_out_s, domain_target(D_out_s, source_label))

        loss_D_s.backward()

        # Train with Target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        weight_map = weight_map.detach()

        D_out_t = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        if (i_iter > PREHEAT_STEPS):
            loss_D_t = weighted_bce_loss(D_out_t,
                                         domain_target(D_out_t, target_label),
                                         weight_map, Epsilon, Lambda_local)
        else:
            loss_D_t = bce_loss(D_out_t, domain_target(D_out_t, target_label))

        loss_D_t.backward()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_adv = {3:.4f}, loss_weight = {4:.4f}, loss_D_s = {5:.4f} loss_D_t = {6:.4f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv, loss_weight,
                    loss_D_s, loss_D_t))

        # Context manager so the handle is released even if the write fails.
        with open(loss_log_path, 'a') as f_loss:
            f_loss.write('{0:.4f} {1:.4f} {2:.4f} {3:.4f} {4:.4f}\n'.format(
                loss_seg, loss_adv, loss_weight, loss_D_s, loss_D_t))

        if i_iter >= args.num_steps_stop - 1:
            # Final snapshot is named after args.num_steps, not i_iter.
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
Пример #22
0
def main():
    """Create the model and start the training.

    Copies the ``common`` section of the YAML file named by ``args.config``
    onto the module-level ``args``, builds the encoder/decoder
    ``SegmentationModule`` and trains it with the two segmentation losses the
    model returns for labelled GTA5 batches.  A target batch is also pushed
    through the model each iteration, but only under ``torch.no_grad()`` and
    with the result discarded.

    Every ``args.print_freq`` iterations the running losses and learning
    rates are logged (text log and TensorBoard).  When such an iteration is
    also a non-zero multiple of ``args.save_pred_every``, the model is
    evaluated on the validation list, per-class IoU and mIoU are logged, and
    encoder/decoder checkpoints are saved via ``save_checkpoint``.
    """
    with open(args.config) as f:
        # safe_load: a config file never needs arbitrary Python objects, and
        # yaml.load() without an explicit Loader is rejected by PyYAML >= 6.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)
    mkdirs(osp.join("logs/"+args.exp_name))

    logger = create_logger('global_logger', "logs/" + args.exp_name + '/log.txt')
    logger.info('{}'.format(args))

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))
    logger.info("random_scale {}".format(args.random_scale))
    logger.info("is_training {}".format(args.is_training))

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)
    cudnn.enabled = True
    args.snapshot_dir = args.snapshot_dir + args.exp_name
    tb_logger = SummaryWriter("logs/"+args.exp_name)

    # validation data
    h, w = map(int, args.input_size_test.split(','))
    input_size_test = (h, w)
    h, w = map(int, args.com_size.split(','))
    com_size = (h, w)
    h, w = map(int, args.input_size_crop.split(','))
    input_size_crop = h, w
    h, w = map(int, args.input_size_target_crop.split(','))
    input_size_target_crop = h, w

    test_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    test_transform = transforms.Compose([
        transforms.Resize((input_size_test[1], input_size_test[0])),
        transforms.ToTensor(),
        test_normalize])

    valloader = data.DataLoader(cityscapesDataSet(
                                       args.data_dir_target,
                                       args.data_list_target_val,
                                       crop_size=input_size_test,
                                       set='train',
                                       transform=test_transform),
                                num_workers=args.num_workers,
                                batch_size=1, shuffle=False, pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # np.int / np.str were removed in NumPy 1.24; the builtins are equivalent.
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list_val = args.label_path_list_val
    label_path_list_test = args.label_path_list_test
    # NOTE(review): the configured test list is immediately overridden by this
    # hard-coded path -- confirm which one is intended.
    label_path_list_test = './dataset/cityscapes_list/label.txt'
    with open(label_path_list_val, 'r') as list_file:
        gt_imgs_val = list_file.read().splitlines()
    gt_imgs_val = [osp.join(args.data_dir_target_val, x) for x in gt_imgs_val]
    # The test loader and its ground-truth list are built here but are not
    # consumed anywhere in this function.
    testloader = data.DataLoader(cityscapesDataSet(
                                    args.data_dir_target,
                                    args.data_list_target_test,
                                    crop_size=input_size_test,
                                    set='val',
                                    transform=test_transform),
                                 num_workers=args.num_workers,
                                 batch_size=1,
                                 shuffle=False, pin_memory=True)

    with open(label_path_list_test, 'r') as list_file:
        gt_imgs_test = list_file.read().splitlines()
    gt_imgs_test = [osp.join(args.data_dir_target_test, x) for x in gt_imgs_test]

    name_classes = np.array(info['label'], dtype=str)
    interp_val = nn.Upsample(size=(com_size[1], com_size[0]), mode='bilinear', align_corners=True)

    ####
    # build model
    ####
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(
        arch=args.arch_encoder,
        fc_dim=args.fc_dim,
        weights=args.weights_encoder)
    net_decoder = builder.build_decoder(
        arch=args.arch_decoder,
        fc_dim=args.fc_dim,
        num_class=args.num_classes,
        weights=args.weights_decoder,
        use_aux=True)

    model = SegmentationModule(
        net_encoder, net_decoder, args.use_aux)

    if args.num_gpus > 1:
        model = torch.nn.DataParallel(model)
        patch_replication_callback(model)
    model.cuda()

    nets = (net_encoder, net_decoder, None, None)
    optimizers = create_optimizer(nets, args)
    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    source_normalize = transforms_seg.Normalize(mean=mean,
                                                std=std)

    # Crop padding value: the normalisation mean rescaled to the 0-255 range.
    mean_mapping = [0.485, 0.456, 0.406]
    mean_mapping = [item * 255 for item in mean_mapping]

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    source_transform = transforms_seg.Compose([
                             transforms_seg.Resize([input_size[1], input_size[0]]),
                             segtransforms.RandScale((args.scale_min, args.scale_max)),
                             #segtransforms.RandRotate((args.rotate_min, args.rotate_max), padding=mean_mapping, ignore_label=args.ignore_label),
                             #segtransforms.RandomGaussianBlur(),
                             segtransforms.RandomHorizontalFlip(),
                             segtransforms.Crop([input_size_crop[1], input_size_crop[0]], crop_type='rand', padding=mean_mapping, ignore_label=args.ignore_label),
                             transforms_seg.ToTensor(),
                             source_normalize])
    target_normalize = transforms_seg.Normalize(mean=mean,
                                                std=std)
    target_transform = transforms_seg.Compose([
                             transforms_seg.Resize([input_size_target[1], input_size_target[0]]),
                             segtransforms.RandScale((args.scale_min, args.scale_max)),
                             #segtransforms.RandRotate((args.rotate_min, args.rotate_max), padding=mean_mapping, ignore_label=args.ignore_label),
                             #segtransforms.RandomGaussianBlur(),
                             segtransforms.RandomHorizontalFlip(),
                             segtransforms.Crop([input_size_target_crop[1], input_size_target_crop[0]], crop_type='rand', padding=mean_mapping, ignore_label=args.ignore_label),
                             transforms_seg.ToTensor(),
                             target_normalize])
    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size, transform=source_transform),
        batch_size=args.batch_size, shuffle=True, num_workers=1, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(fake_cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                          max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                          crop_size=input_size_target,
                                                          set=args.set,
                                                          transform=target_transform),
                                   batch_size=args.batch_size, shuffle=True, num_workers=1,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # Only the encoder and decoder optimizers are used by this script.
    optimizer_encoder, optimizer_decoder, _, _ = optimizers
    batch_time = AverageMeter(10)
    loss_seg_value1 = AverageMeter(10)
    loss_seg_value2 = AverageMeter(10)
    is_best_test = True
    best_mIoUs = 0

    for i_iter in range(args.num_steps):
        end = time.time()

        # ---- supervised step on a source batch ----
        _, batch = next(trainloader_iter)
        images, labels, _ = batch
        # `async` is a reserved word since Python 3.7; `non_blocking` is the
        # supported name for the same option.
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        seg, aux_seg, loss_seg2, loss_seg1 = model(images, labels)

        loss_seg2 = torch.mean(loss_seg2)
        loss_seg1 = torch.mean(loss_seg1)
        loss = loss_seg2+args.lambda_seg*loss_seg1
        # Track both losses; previously only loss_seg2 was recorded, so the
        # logged loss_seg1 never reflected training.
        loss_seg_value1.update(loss_seg1.data.cpu().numpy())
        loss_seg_value2.update(loss_seg2.data.cpu().numpy())

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        del seg, loss_seg2

        # ---- forward-only pass on a target batch ----
        # No loss is computed and the output is dropped.
        # NOTE(review): presumably kept for its effect on layer statistics
        # (e.g. batch norm) -- confirm.
        _, batch = next(targetloader_iter)
        with torch.no_grad():
            images, labels, _ = batch
            images = images.cuda(non_blocking=True)
            result = model(images, None)
            del result

        batch_time.update(time.time() - end)

        adjust_learning_rate(optimizer_encoder, i_iter, args.lr_encoder, args)
        adjust_learning_rate(optimizer_decoder, i_iter, args.lr_decoder, args)
        if i_iter % args.print_freq == 0:
            # Estimated time left, formatted as HH:MM:SS.
            remain_iter = args.num_steps - i_iter
            remain_time = remain_iter * batch_time.avg
            t_m, t_s = divmod(remain_time, 60)
            t_h, t_m = divmod(t_m, 60)
            remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))

            lr_encoder = optimizer_encoder.param_groups[0]['lr']
            lr_decoder = optimizer_decoder.param_groups[0]['lr']
            logger.info('exp = {}'.format(args.snapshot_dir))
            logger.info('Iter = [{0}/{1}]\t'
                        'Time = {batch_time.avg:.3f}\t'
                        'loss_seg1 = {loss_seg1.avg:.4f}\t'
                        'loss_seg2 = {loss_seg2.avg:.4f}\t'
                        'lr_encoder = {lr_encoder:.8f} lr_decoder = {lr_decoder:.8f}'.format(
                         i_iter, args.num_steps, batch_time=batch_time,
                         loss_seg1=loss_seg_value1, loss_seg2=loss_seg_value2,
                         lr_encoder=lr_encoder,
                         lr_decoder=lr_decoder))

            logger.info("remain_time: {}".format(remain_time))
            if tb_logger is not None:
                tb_logger.add_scalar('loss_seg_value1', loss_seg_value1.avg, i_iter)
                tb_logger.add_scalar('loss_seg_value2', loss_seg_value2.avg, i_iter)
                tb_logger.add_scalar('lr', lr_encoder, i_iter)

            # Validation + checkpoint; only reachable on print_freq
            # iterations because it is nested in the block above.
            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                logger.info('taking snapshot ...')
                model.eval()

                # Confusion matrix accumulated over the validation images.
                hist = np.zeros((19, 19))
                for index, batch in tqdm(enumerate(valloader)):
                    with torch.no_grad():
                        image, name = batch
                        output2, _ = model(image.cuda(), None)
                        pred = interp_val(output2)
                        del output2
                        pred = pred.cpu().data[0].numpy()
                        pred = pred.transpose(1, 2, 0)
                        pred = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8)
                        label = np.array(Image.open(gt_imgs_val[index]))
                        label = label_mapping(label, mapping)
                        hist += fast_hist(label.flatten(), pred.flatten(), 19)
                mIoUs = per_class_iu(hist)
                for ind_class in range(args.num_classes):
                    logger.info('===>' + name_classes[ind_class] + ':\t' + str(round(mIoUs[ind_class] * 100, 2)))
                    tb_logger.add_scalar(name_classes[ind_class] + '_mIoU', mIoUs[ind_class], i_iter)

                mIoUs = round(np.nanmean(mIoUs) * 100, 2)
                if mIoUs >= best_mIoUs:
                    is_best_test = True
                    best_mIoUs = mIoUs
                else:
                    is_best_test = False

                logger.info("current mIoU {}".format(mIoUs))
                logger.info("best mIoU {}".format(best_mIoUs))
                tb_logger.add_scalar('val mIoU', mIoUs, i_iter)
                save_checkpoint(net_encoder, 'encoder', i_iter, args, is_best_test)
                save_checkpoint(net_decoder, 'decoder', i_iter, args, is_best_test)
            # Kept at print_freq granularity (not only after validation) to
            # preserve the original call pattern of model.train().
            model.train()
Пример #23
0
def main():
    """Create the model and start the training."""

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    h, w = map(int, args.input_size_target.split(','))
    input_size_target = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)
        # model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    # # load discriminator params
    # saved_state_dict_D1 = torch.load(D1_RESTORE_FROM)
    # saved_state_dict_D2 = torch.load(D2_RESTORE_FROM)
    # model_D1.load_state_dict(saved_state_dict_D1)
    # model_D2.load_state_dict(saved_state_dict_D2)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.7, 0.99))
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.7, 0.99))

    # opti_state_dict = torch.load(OPTI_RESTORE_FROM)
    # opti_state_dict_d1 = torch.load(OPTI_D1_RESTORE_FROM)
    # opti_state_dict_d2 = torch.load(OPTI_D2_RESTORE_FROM)
    # optimizer.load_state_dict(opti_state_dict)pt_original_crop_B0.7
    # optimizer_D1.load_state_dict(opti_state_dict_d1)
    # optimizer_D1.load_state_dict(opti_state_dict_d2)

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 1
    target_label = 0
    mIoUs = []
    i_iters = []

    for i_iter in range(args.num_steps):
        if i_iter <= iter_start:
            continue

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = trainloader_iter.__next__()
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            _, batch = targetloader_iter.__next__()
            images, _, name = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            if i_iter % 2 == 0:

                loss_adv_target1 = bce_loss(
                    D_out1,
                    Variable(
                        torch.FloatTensor(
                            D_out1.data.size()).fill_(source_label)).cuda(
                                args.gpu))

                loss_adv_target2 = bce_loss(
                    D_out2,
                    Variable(
                        torch.FloatTensor(
                            D_out2.data.size()).fill_(source_label)).cuda(
                                args.gpu))

                loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size
                loss.backward()
                loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
                ) / args.iter_size
                loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
                ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            weight_s = float(D_out2.mean().data.cpu().numpy())

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            weight_t = float(D_out2.mean().data.cpu().numpy())
            # if weight_b>0.5 and i_iter>500:
            #     confidence_map = interp(D_out2).cpu().data[0][0].numpy()
            #     name = name[0].split('/')[-1]
            #     confidence_map=255*confidence_map
            #     confidence_output=Image.fromarray(confidence_map.astype(np.uint8))
            #     confidence_output.save('./result/confid_map/%s.png' % (name.split('.')[0]))
            #     zq=1
            print(weight_s, weight_t)

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        if i_iter % 3 == 0:
            optimizer_D1.step()
            optimizer_D2.step()

        # print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
            torch.save(
                optimizer.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_optimizer.pth'))
            torch.save(
                optimizer_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_optimizer_D1.pth'))
            torch.save(
                optimizer_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(i_iter) + '_optimizer_D2.pth'))
            show_pred_sv_dir = pre_sv_dir.format(i_iter)
            mIoU = show_val(model.state_dict(), show_pred_sv_dir)
            mIoUs.append(str(round(np.nanmean(mIoU) * 100, 2)))
            i_iters.append(i_iter)
            print_i = 0
            for miou in mIoUs:
                print('i{0}: {1}'.format(i_iters[print_i], miou))
                print_i = print_i + 1
Пример #24
0
def main():
    """Create the model and start the training.

    Builds a DeepLab segmentation network and two ``FCDiscriminator``
    instances, then alternates segmentation-network and discriminator updates
    on batches from ``GTA5DataSet`` (source) and ``cityscapesDataSet``
    (target).  Every ``args.save_pred_every`` iterations the three networks
    are snapshotted, the segmentation network is scored on the validation
    loader, and the resulting mIoU is appended to ``args.results_dir``.

    All configuration is read from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """

    # The parsed pairs are used as (width, height): PIL's resize() below takes
    # com_size as-is, and the upsamplers (which expect (height, width)) are
    # built with the two entries swapped.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    w, h = map(int, args.com_size.split(','))
    com_size = (w, h)

    ############################
    # validation data
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir_target,
                                                   args.data_list_target_val,
                                                   crop_size=input_size,
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set_val),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    # ``np.int`` was only an alias of the builtin and is gone in NumPy >= 1.24.
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list = './dataset/cityscapes_list/label.txt'
    with open(label_path_list, 'r') as fp:
        gt_imgs = fp.read().splitlines()
    gt_imgs = [join('./data/Cityscapes/data/gtFine/val', x) for x in gt_imgs]

    # Equivalent to the deprecated nn.UpsamplingBilinear2d(size=...).
    interp_val = nn.Upsample(size=(com_size[1], com_size[0]),
                             mode='bilinear',
                             align_corners=True)

    ############################

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from.startswith('http'):
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # drop the leading component, and skip the layer5 weights when
            # training with 19 classes so they keep their fresh initialisation.
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        # Fail here instead of with a NameError on the first use of ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_loss(d_out, domain_label):
        """Return the BCE loss of ``d_out`` against a constant domain label."""
        # full_like yields a target with the dtype and device of ``d_out``.
        return bce_loss(d_out, torch.full_like(d_out, domain_label))

    def set_discriminators_trainable(flag):
        """Toggle ``requires_grad`` on every parameter of both discriminators."""
        for net in (model_D1, model_D2):
            for param in net.parameters():
                param.requires_grad = flag

    def save_snapshot(tag):
        """Save the weights of the three networks under ``args.snapshot_dir``."""
        prefix = osp.join(args.snapshot_dir, 'GTA5_' + str(tag))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D1.state_dict(), prefix + '_D1.pth')
        torch.save(model_D2.state_dict(), prefix + '_D2.pth')

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-batches before stepping.
        for _ in range(args.iter_size):

            # train G
            # don't accumulate grads in D
            set_discriminators_trainable(False)

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = images.cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is what the implicit-dim softmax picks for 4-D input;
            # spelling it out avoids the deprecation warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # Target predictions are scored against the *source* label so the
            # segmentation network is pushed to fool the discriminators.
            loss_adv_target1 = domain_loss(D_out1, source_label)
            loss_adv_target2 = domain_loss(D_out2, source_label)

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # train D

            # bring back requires_grad
            set_discriminators_trainable(True)

            # train with source (detached: no gradient reaches the segmenter)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            # Halved because each discriminator sees a source and a target
            # batch per sub-iteration.
            loss_D1 = domain_loss(D_out1, source_label) / args.iter_size / 2
            loss_D2 = domain_loss(D_out2, source_label) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = domain_loss(D_out1, target_label) / args.iter_size / 2
            loss_D2 = domain_loss(D_out2, target_label) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_snapshot(i_iter)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_snapshot(i_iter)

            hist = np.zeros((19, 19))

            # ``volatile=True`` has no effect on PyTorch >= 0.4, so inference
            # is wrapped in no_grad() to avoid building autograd graphs.  The
            # network is also put in eval mode so that any BatchNorm/Dropout
            # layers use inference behaviour instead of being driven (and
            # updated) by single validation images.
            model.eval()
            with torch.no_grad():
                for index, batch in enumerate(testloader):
                    print(index)
                    image, _, _ = batch
                    _, output2 = model(image.cuda(args.gpu))
                    pred = interp_val(output2)
                    pred = pred[0].permute(1, 2, 0)
                    print(pred.shape)
                    pred = torch.max(pred, 2)[1].byte()
                    pred = pred.cpu().numpy()
                    label = Image.open(gt_imgs[index])
                    label = np.array(label.resize(com_size, Image.NEAREST))
                    label = label_mapping(label, mapping)
                    hist += fast_hist(label.flatten(), pred.flatten(), 19)
            model.train()

            mIoUs = per_class_iu(hist)
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            print(mIoU)
            # Context manager guarantees the results file is closed.
            with open(args.results_dir, 'a') as f:
                f.write('i_iter:{:d},        miou:{:0.3f} \n'.format(
                    i_iter, mIoU))
Пример #25
0
def main():
    """Create the model and start the training.

    Builds a segmentation network (``DeeplabMulti`` for ``'ResNet'``,
    ``DeeplabVGG`` for ``'VGG'``) and a single ``FCDiscriminator`` that is fed
    the network's feature output.  Each iteration performs one
    segmentation-network update (source segmentation loss plus an adversarial
    loss on a target batch) and one discriminator update (domain loss on both
    batches plus a classification loss on the source batch).  Scalars are
    written to TensorBoard every 10 iterations when ``args.tensorboard`` is
    set, and weights are snapshotted every ``args.save_pred_every``
    iterations.

    All configuration is read from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """

    device = torch.device("cpu" if args.cpu else "cuda")

    # Sizes are "width,height"; the upsamplers below swap them to (h, w).
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from.startswith('http'):
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
            # drop the leading component, and skip the layer5 weights when
            # training with 19 classes.
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
        # The discriminator consumes the model's feature output (see
        # ``model_D(feature)`` below); presumably this is its channel count.
        discriminator_inputs = 2048
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)
        discriminator_inputs = 1024
    else:
        # Fail here instead of with a NameError on the first use of ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=discriminator_inputs).to(device)
    model_D.train()

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def domain_loss(d_out, domain_label):
        """Return the BCE loss of ``d_out`` against a constant domain label."""
        # full_like yields a target with the dtype and device of ``d_out``.
        return bce_loss(d_out, torch.full_like(d_out, domain_label))

    def set_discriminator_trainable(flag):
        """Toggle ``requires_grad`` on every discriminator parameter."""
        for param in model_D.parameters():
            param.requires_grad = flag

    def save_snapshot(tag):
        """Save both networks' weights under ``args.snapshot_dir``."""
        prefix = osp.join(args.snapshot_dir, 'GTA5_' + str(tag))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D.state_dict(), prefix + '_D.pth')

    # set up tensor board
    writer = None
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    try:
        for i_iter in range(args.num_steps):

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)

            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            # train G

            # don't accumulate grads in D
            set_discriminator_trainable(False)

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            feature, prediction = model(images)
            prediction = interp(prediction)
            loss = seg_loss(prediction, labels)
            loss.backward()
            loss_seg = loss.item()

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = images.to(device)

            feature_target, _ = model(images)
            _, D_out = model_D(feature_target)
            # Target features are scored against the *source* label so the
            # segmentation network is pushed to fool the discriminator.
            loss_adv_target = domain_loss(D_out, source_label)
            loss = args.lambda_adv_target * loss_adv_target
            loss.backward()
            loss_adv_target_value = loss_adv_target.item()

            # train D

            # bring back requires_grad
            set_discriminator_trainable(True)

            # train with source (detached: no gradient reaches the segmenter)
            feature = feature.detach()
            cla, D_out = model_D(feature)
            cla = interp(cla)
            loss_cla = seg_loss(cla, labels)

            # Halved because the discriminator sees a source and a target
            # batch per iteration.
            loss_D = domain_loss(D_out, source_label) / 2
            loss_Disc = args.lambda_s * loss_cla + loss_D
            loss_Disc.backward()

            loss_cla_value = loss_cla.item()
            loss_D_value = loss_D.item()

            # train with target
            feature_target = feature_target.detach()
            _, D_out = model_D(feature_target)
            loss_D = domain_loss(D_out, target_label) / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

            optimizer.step()
            optimizer_D.step()

            if writer is not None and i_iter % 10 == 0:
                scalar_info = {
                    'loss_seg': loss_seg,
                    'loss_cla': loss_cla_value,
                    'loss_adv_target': loss_adv_target_value,
                    'loss_D': loss_D_value,
                }
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

            print(
                'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
                .format(i_iter, args.num_steps, loss_seg,
                        loss_adv_target_value, loss_D_value, loss_cla_value))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                save_snapshot(args.num_steps_stop)
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                save_snapshot(i_iter)
    finally:
        # Close the event file even when training aborts with an exception.
        if writer is not None:
            writer.close()
Пример #26
0
def main():
    """Create the model and start the training.

    Wraps the segmentation task in an ``MDD`` model and optimises the loss
    returned by ``model.get_loss`` on concatenated source/target batches.
    Every ``args.save_pred_every`` iterations the weights of ``model.c_net``
    and a JSON dump of the current arguments are written to
    ``args.snapshot_dir``; when the module-level ``RESTART`` flag is set, the
    arguments are restored from such a dump before anything else is built.

    All configuration is read from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """
    import json

    if RESTART:
        args.snapshot_dir = RESTART_FROM
    else:
        args.snapshot_dir = generate_snapshot_name(args)

    # vars() returns the namespace's own dict, so writes to ``args_dict``
    # below also update ``args``.
    args_dict = vars(args)

    ###### load args for restart ######
    if RESTART:
        # osp.join keeps this in step with the path used when the dump is
        # written further down, whether or not RESTART_FROM ends with a
        # separator (plain concatenation silently produced a wrong path).
        args_dict_file = osp.join(args.snapshot_dir,
                                  'args_dict_{}.json'.format(RESTART_ITER))
        with open(args_dict_file) as f:
            args_dict_last = json.load(f)
        for arg in args_dict:
            args_dict[arg] = args_dict_last[arg]

    ###### load args for restart ######

    device = torch.device("cpu" if args.cpu else "cuda")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    cudnn.benchmark = True

    if args.model == 'DeepLab':
        width = 1024
        srcweight = 3
        model = MDD(width=width,
                    use_bottleneck=False,
                    use_gpu=not args.cpu,
                    class_num=args.num_classes,
                    srcweight=srcweight,
                    args=args)
        model.set_train(True)
    else:
        # Fail here instead of with a NameError on the first use of ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    model.c_net.to(device)

    #### From here, code should not be related to model reload ####
    # but we would need hyperparameters: n_iter,
    # [lr, momentum, weight_decay, betas](these are all in args)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.get_parameter_list(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    lr_scheduler = INVScheduler(gamma=0.001, decay_rate=0.75, init_lr=0.004)

    # set up tensor board
    os.makedirs(args.log_dir, exist_ok=True)
    writer = SummaryWriter(args.log_dir)

    try:
        for i_iter in range(args.start_steps, args.num_steps):

            param_groups = model.get_parameter_list()
            group_ratios = [group['lr'] for group in param_groups]
            # NOTE(review): ``i_iter / 5`` is true division, so the scheduler
            # is handed a fractional step - confirm that is intended.
            optimizer = lr_scheduler.next_optimizer(group_ratios, optimizer,
                                                    i_iter / 5)
            optimizer.zero_grad()

            # Gradients are accumulated over iter_size sub-batches.
            total_loss = 0
            for _ in range(args.iter_size):

                # load src
                _, batch = next(trainloader_iter)
                src_images, src_labels, _, _ = batch
                src_images = src_images.to(device)
                src_labels = src_labels.long().to(device)

                # load target
                _, batch = next(targetloader_iter)
                tgt_images, _, _ = batch
                tgt_images = tgt_images.to(device)

                # Source and target images go through the model as one batch,
                # source samples first.
                inputs = torch.cat((src_images, tgt_images), dim=0)

                loss = model.get_loss(inputs, src_labels)
                loss = loss / args.iter_size
                loss.backward()
                total_loss += loss.item()

            optimizer.step()

            if i_iter % 10 == 0:
                writer.add_scalar('loss', total_loss, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print('iter = {0:8d}/{1:8d}, loss = {2:.3f}'.format(
                i_iter, args.num_steps, total_loss))

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(
                    model.c_net.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.c_net.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))

                ###### also record latest saved iteration #######
                args_dict['learning_rate'] = optimizer.param_groups[0]['lr']
                args_dict['start_steps'] = i_iter

                args_dict_file = osp.join(
                    args.snapshot_dir, 'args_dict_{}.json'.format(i_iter))
                with open(args_dict_file, 'w') as f:
                    json.dump(args_dict, f)
    finally:
        # Close the event file even when training aborts with an exception.
        writer.close()
def main():
    """Create the model and start the training.

    Builds the segmentation network selected by ``args.model`` ('DeepLab' or
    'FCN8S') with an SGD optimizer, two ``FCDiscriminator`` networks with
    Adam optimizers, and the source/target data loaders. Every outer
    iteration accumulates gradients over ``args.iter_size`` sub-iterations
    and then steps all three optimizers.

    Only the ``*2`` prediction/discriminator branch is trained here: the
    ``*1`` loss counters stay at zero and ``model_D1`` never receives
    gradients, although it is still created, stepped and checkpointed.

    Every ``args.save_pred_every`` iterations the model is evaluated with
    ``proceed_test`` and a checkpoint is written to the logger directory; it
    is additionally copied to ``model_best.pth.tar`` when the mIoU improved.

    Raises:
        ValueError: if ``args.model`` or ``SOURCE_DATA`` is not supported.
    """

    def _loss_value(loss):
        # ``.numpy()[0]`` only works while a loss is a 1-element tensor;
        # ``ravel()`` also accepts the 0-dim tensors of newer PyTorch.
        return loss.data.cpu().numpy().ravel()[0]

    def _filled_like(d_out, value):
        # Constant label map with the discriminator output's size, on GPU.
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(value)).cuda()

    # The sizes are given as "<first>,<second>"; the first value is used as
    # the width and the second as the height by the upsamplers below.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        logger.info("adopting Deeplabv2 base model..")
        model = Res_Deeplab(num_classes=args.num_classes, multi_scale=False)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy the pretrained weights with the first name component dropped
        # (e.g. "Scale.layer5.conv2d_list.3.weight"); when training with 19
        # classes the 'layer5' entries keep their fresh initialisation.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)

        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.model == "FCN8S":
        logger.info("adopting FCN8S base model..")
        from pytorchgo.model.MyFCN8s import MyFCN8s
        model = MyFCN8s(n_class=NUM_CLASSES)
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    else:
        raise ValueError('unsupported model: {}'.format(args.model))

    model.train()
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda()

    model_D2.train()
    model_D2.cuda()

    # Number of samples requested from each dataset for the whole run.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    if SOURCE_DATA == "GTA5":
        source_dataset = GTA5DataSet(args.data_dir,
                                     args.data_list,
                                     max_iters=max_iters,
                                     crop_size=input_size,
                                     scale=args.random_scale,
                                     mirror=args.random_mirror,
                                     mean=IMG_MEAN)
    elif SOURCE_DATA == "SYNTHIA":
        source_dataset = SynthiaDataSet(args.data_dir,
                                        args.data_list,
                                        LABEL_LIST_PATH,
                                        max_iters=max_iters,
                                        crop_size=input_size,
                                        scale=args.random_scale,
                                        mirror=args.random_mirror,
                                        mean=IMG_MEAN)
    else:
        raise ValueError('unsupported source data: {}'.format(SOURCE_DATA))

    trainloader = data.DataLoader(source_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    best_mIoU = 0

    model_summary([model, model_D1, model_D2])
    optimizer_summary([optimizer, optimizer_D1, optimizer_D2])

    for i_iter in tqdm(range(args.num_steps_stop),
                       total=args.num_steps_stop,
                       desc="training"):

        # The *1 counters are never updated in this single-branch variant;
        # they are kept because the log line below reports them.
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        lr_D1 = adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            ######################### train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            # built-in next() instead of the Python-2-only ``.next()`` method
            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda()

            pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = loss_calc(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value2 += _loss_value(loss_seg2) / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _, _ = batch
            images = Variable(images).cuda()

            pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is what the implicit choice resolved to for 4-D input
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when D labels target output as source.
            loss_adv_target2 = bce_loss(D_out2,
                                        _filled_like(D_out2, source_label))

            loss = args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value2 += _loss_value(
                loss_adv_target2) / args.iter_size

            ################################## train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so no gradient reaches the model)
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D2 = bce_loss(D_out2, _filled_like(D_out2, source_label))

            loss_D2 = loss_D2 / args.iter_size / 2
            loss_D2.backward()

            loss_D_value2 += _loss_value(loss_D2)

            # train with target
            pred_target2 = pred_target2.detach()

            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D2 = bce_loss(D_out2, _filled_like(D_out2, target_label))

            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D2.backward()

            loss_D_value2 += _loss_value(loss_D2)

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % 100 == 0:
            logger.info(
                'iter = {}/{},loss_seg1 = {:.3f} loss_seg2 = {:.3f} loss_adv1 = {:.3f}, loss_adv2 = {:.3f} loss_D1 = {:.3f} loss_D2 = {:.3f}, lr={:.7f}, lr_D={:.7f}, best miou16= {:.5f}'
                .format(i_iter, args.num_steps_stop, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2,
                        lr, lr_D1, best_mIoU))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            logger.info("saving snapshot.....")
            cur_miou16 = proceed_test(model, input_size)
            is_best = best_mIoU < cur_miou16
            if is_best:
                best_mIoU = cur_miou16
            checkpoint_path = osp.join(logger.get_logger_dir(),
                                       'checkpoint.pth.tar')
            torch.save(
                {
                    'iteration': i_iter,
                    'optim_state_dict': optimizer.state_dict(),
                    'optim_D1_state_dict': optimizer_D1.state_dict(),
                    'optim_D2_state_dict': optimizer_D2.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'model_D1_state_dict': model_D1.state_dict(),
                    'model_D2_state_dict': model_D2.state_dict(),
                    # NOTE(review): this stores the mIoU of the current
                    # evaluation, not the running best -- confirm intent.
                    'best_mean_iu': cur_miou16,
                }, checkpoint_path)
            if is_best:
                import shutil
                shutil.copy(
                    checkpoint_path,
                    osp.join(logger.get_logger_dir(), 'model_best.pth.tar'))

        if i_iter >= args.num_steps_stop - 1:
            break
def main():
    """Create the model and start the training.

    Trains a ``Res_Deeplab`` model that returns three outputs
    (``pred_a, pred1, pred2``) against three ``FCDiscriminator`` networks:
    ``model_D_a`` on the first output and ``model_D1`` / ``model_D2`` on the
    softmax of the two upsampled segmentation outputs. Gradients are
    accumulated over ``args.iter_size`` sub-iterations before all four
    optimizers are stepped.

    The adversarial term on the first output is scaled by
    ``(80000 - i_iter) / 80000`` and skipped once that weight is no longer
    positive.

    Snapshots of the model and of ``model_D1`` / ``model_D2`` are written to
    ``args.snapshot_dir.format(LAMBDA_ADV_TARGET_A)`` every
    ``args.save_pred_every`` iterations and once more when
    ``args.num_steps_stop`` is reached; ``show_val`` is called after each
    save and the mIoU history collected so far is printed.

    Raises:
        ValueError: if ``args.model`` is not 'DeepLab'.
    """

    def _filled_like(d_out, value):
        # Constant label map with the discriminator output's size, on GPU.
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(value)).cuda(args.gpu)

    # The sizes are given as "<first>,<second>"; the first value is used as
    # the width and the second as the height when resizing predictions below.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy the pretrained weights with the first name component dropped
        # (e.g. "Scale.layer5.conv2d_list.3.weight"); when training with 19
        # classes the 'layer5' entries keep their fresh initialisation.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        # Previously this surfaced as an unbound-local error on ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D_a = FCDiscriminator(num_classes=256)  # need to check
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D_a.train()
    model_D_a.cuda(args.gpu)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    # The snapshot directory template is filled in once and reused below.
    snapshot_dir = args.snapshot_dir.format(LAMBDA_ADV_TARGET_A)
    os.makedirs(snapshot_dir, exist_ok=True)

    # Number of samples requested from each dataset for the whole run.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=max_iters,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror,
                    mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target,
                          max_iters=max_iters,
                          crop_size=input_size_target,
                          scale=False, mirror=args.random_mirror,
                          mean=IMG_MEAN,
                          set=args.set),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D_a = optim.Adam(model_D_a.parameters(),
                               lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D_a.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_label = 0
    target_label = 1

    source_size = (input_size[1], input_size[0])
    target_size = (input_size_target[1], input_size_target[0])

    # History of snapshot mIoUs (as formatted strings). It must live outside
    # the loop: resetting it per iteration left at most one entry in it.
    mIoUs = []

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D_a.zero_grad()
        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D_a, i_iter)
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Weight of the adversarial term on the first model output; it only
        # depends on the outer iteration, so compute it once per step.
        lambda_weight = (80000 - i_iter) / 80000
        use_adv_a = lambda_weight > 0

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D_a.parameters():
                param.requires_grad = False

            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_a, pred1, pred2 = model(images)
            pred1 = nn.functional.interpolate(pred1, size=source_size,
                                              mode='bilinear',
                                              align_corners=True)
            pred2 = nn.functional.interpolate(pred2, size=source_size,
                                              mode='bilinear',
                                              align_corners=True)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            if use_adv_a:
                # Separate forward pass for the first output, exactly as in
                # the original training schedule.
                pred_target_a, _, _ = model(images)

                D_out_a = model_D_a(pred_target_a)

                loss_adv_target_a = bce_loss(
                    D_out_a, _filled_like(D_out_a, source_label))

                loss_adv_target_a = (lambda_weight * LAMBDA_ADV_TARGET_A *
                                     loss_adv_target_a)
                loss_adv_target_a = loss_adv_target_a / args.iter_size
                loss_adv_target_a.backward()

            _, pred_target1, pred_target2 = model(images)

            pred_target1 = nn.functional.interpolate(pred_target1,
                                                     size=target_size,
                                                     mode='bilinear',
                                                     align_corners=True)
            pred_target2 = nn.functional.interpolate(pred_target2,
                                                     size=target_size,
                                                     mode='bilinear',
                                                     align_corners=True)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when D labels target output as source.
            loss_adv_target1 = bce_loss(D_out1,
                                        _filled_like(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2,
                                        _filled_like(D_out2, source_label))

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += (loss_adv_target1.data.cpu().numpy() /
                                       args.iter_size)
            loss_adv_target_value2 += (loss_adv_target2.data.cpu().numpy() /
                                       args.iter_size)

            # train D

            # bring back requires_grad
            for param in model_D_a.parameters():
                param.requires_grad = True

            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so no gradient reaches the model)
            if use_adv_a:
                pred_a = pred_a.detach()
                D_out_a = model_D_a(pred_a)
                loss_D_a = bce_loss(D_out_a,
                                    _filled_like(D_out_a, source_label))
                loss_D_a = loss_D_a / args.iter_size / 2
                loss_D_a.backward()

            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(D_out1, _filled_like(D_out1, source_label))
            loss_D2 = bce_loss(D_out2, _filled_like(D_out2, source_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            if use_adv_a:
                pred_target_a = pred_target_a.detach()
                D_out_a = model_D_a(pred_target_a)
                loss_D_a = bce_loss(D_out_a,
                                    _filled_like(D_out_a, target_label))
                loss_D_a = loss_D_a / args.iter_size / 2
                loss_D_a.backward()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(D_out1, _filled_like(D_out1, target_label))
            loss_D2 = bce_loss(D_out2, _filled_like(D_out2, target_label))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D_a.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                loss_adv_target_value1, loss_adv_target_value2,
                loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            # NOTE(review): the final files are named after args.num_steps
            # although the stop condition uses args.num_steps_stop.
            final_prefix = 'GTA5_' + str(args.num_steps)
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, final_prefix + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(snapshot_dir, final_prefix + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(snapshot_dir, final_prefix + '_D2.pth'))
            show_val(model.state_dict(), LAMBDA_ADV_TARGET_A, i_iter)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            snapshot_prefix = 'GTA5_' + str(i_iter)
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, snapshot_prefix + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(snapshot_dir, snapshot_prefix + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(snapshot_dir, snapshot_prefix + '_D2.pth'))
            mIoU = show_val(model.state_dict(), LAMBDA_ADV_TARGET_A, i_iter)
            mIoUs.append(str(round(np.nanmean(mIoU) * 100, 2)))
            for miou in mIoUs:
                print(miou)
def main():
    """Create the model and start the training.

    Trains a ``DeeplabMulti`` model, which returns two predictions, with a
    segmentation loss on GTA5 batches and an adversarial loss on cityscapes
    batches against two ``FCDiscriminator`` networks. The adversarial
    criterion is ``BCEWithLogitsLoss`` for ``args.gan == 'Vanilla'`` and
    ``MSELoss`` for ``args.gan == 'LS'``. Gradients are accumulated over
    ``args.iter_size`` sub-iterations before the three optimizers step.

    Model and discriminator weights are saved to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and once more, followed by a break,
    when ``args.num_steps_stop`` is reached.

    Raises:
        ValueError: if ``args.model`` or ``args.gan`` is not supported.
    """

    def _filled_like(d_out, value):
        # Constant label map with the discriminator output's size, on GPU.
        return Variable(
            torch.FloatTensor(d_out.data.size()).fill_(value)).cuda(args.gpu)

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy the pretrained weights with the first name component dropped
        # (e.g. "Scale.layer5.conv2d_list.3.weight"); when training with 19
        # classes the 'layer5' entries keep their fresh initialisation.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        # Previously this surfaced as an unbound-local error on ``model``.
        raise ValueError('unsupported model: {}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Number of samples requested from each dataset for the whole run.
    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # Discriminators, one per model output.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    # Optimizers: SGD for the segmentation model, Adam for the discriminators.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Adversarial criterion; it is MSE rather than BCE in the 'LS' case.
    if args.gan == 'Vanilla':
        gan_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        gan_loss = torch.nn.MSELoss()
    else:
        # Previously this surfaced as an unbound-local error at first use.
        raise ValueError('unsupported GAN loss: {}'.format(args.gan))

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):
            # ---- train G: freeze the discriminators ----
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is what the implicit choice resolved to for 4-D input
            D1_out = model_D1(F.softmax(pred_target1, dim=1))
            D2_out = model_D2(F.softmax(pred_target2, dim=1))

            # The generator is rewarded when D labels target output as source.
            loss_adv_target1 = gan_loss(D1_out,
                                        _filled_like(D1_out, source_label))
            loss_adv_target2 = gan_loss(D2_out,
                                        _filled_like(D2_out, source_label))

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size

            loss.backward()

            loss_adv_target_value1 += (loss_adv_target1.data.cpu().numpy() /
                                       args.iter_size)
            loss_adv_target_value2 += (loss_adv_target2.data.cpu().numpy() /
                                       args.iter_size)

            # ---- train D: unfreeze the discriminators ----
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source (detached so no gradient reaches the model)
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D1_out = model_D1(F.softmax(pred1, dim=1))
            D2_out = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = gan_loss(D1_out, _filled_like(
                D1_out, source_label)) / args.iter_size / 2
            loss_D2 = gan_loss(D2_out, _filled_like(
                D2_out, source_label)) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D1_out = model_D1(F.softmax(pred_target1, dim=1))
            D2_out = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = gan_loss(D1_out, _filled_like(
                D1_out, target_label)) / args.iter_size / 2
            loss_D2 = gan_loss(D2_out, _filled_like(
                D2_out, target_label)) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # A bare ``print`` followed by a string on the next line is two
            # no-op expressions; call print so the message is really shown.
            print('save model ...')
            final_prefix = 'GTA5_' + str(args.num_steps_stop)
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, final_prefix + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, final_prefix + '_D1.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, final_prefix + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            snapshot_prefix = 'GTA5_' + str(i_iter)
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_prefix + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, snapshot_prefix + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, snapshot_prefix + '_D2.pth'))
Пример #30
0
def main():
    """Create the model, run it over the dataset and save the predictions.

    The network selected by ``args.model`` is built and its weights are
    restored from ``args.restore_from`` (an ``http`` URL or a local path).
    Every batch produced by the loader is passed through the model, the
    class scores are upsampled to 640x1280, and for each image two files
    are written to the output directory: the raw label map (same file name
    as the input) and a colorized copy (``<stem>_color.png``).

    The output directory is ``args.save`` with the name of the checkpoint's
    parent directory appended to it.

    Returns:
        The path of the directory the predictions were written to.

    Raises:
        ValueError: if ``args.model`` is not one of ``'DeeplabMulti'``,
            ``'Oracle'`` or ``'DeeplabVGG'``.
    """

    args = get_arguments()
    batchsize = args.batchsize

    # Name the output directory after the directory holding the checkpoint.
    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes,
                             train_bn=False,
                             norm_style='in')
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        # Swap in the model-specific checkpoint only when the user left the
        # default restore path untouched.
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError('Unsupported model: %r' % (args.model,))

    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: presumably the checkpoint was saved from a
        # DataParallel-wrapped model, so retry with the same wrapper.
        # (KeyError is what older PyTorch releases raised here.)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda()

    testloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                             args.data_list,
                                             crop_size=(640, 1280),
                                             resize_size=(1280, 640),
                                             mean=IMG_MEAN,
                                             scale=False,
                                             mirror=False),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True)

    # `align_corners` is only accepted by torch >= 0.4.0.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(640, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(640, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, batch in enumerate(testloader):
        processed = index * batchsize
        if processed % 100 == 0:
            print('%d processd' % processed)
        image, _, _, name = batch
        print(image.shape)

        inputs = Variable(image).cuda()
        if args.model == 'DeeplabMulti':
            # Fuse the two prediction heads before upsampling.
            output1, output2 = model(inputs)
            scores = interp(sm(0.5 * output1 + output2))
        else:
            # 'DeeplabVGG' and 'Oracle' return a single output.
            scores = interp(model(inputs))
        scores = scores.cpu().data.numpy()

        # Per-pixel class index; axis 1 is the class/channel axis (the
        # softmax above is applied over dim=1 as well).
        predictions = np.asarray(np.argmax(scores, axis=1), dtype=np.uint8)

        for mask, path in zip(predictions, name):
            file_name = os.path.basename(path)
            # splitext keeps any extra dots in the stem, so two inputs that
            # differ only after the first dot do not overwrite each other.
            stem = os.path.splitext(file_name)[0]

            Image.fromarray(mask).save(os.path.join(args.save, file_name))
            colorize_mask(mask).save(
                os.path.join(args.save, stem + '_color.png'))

    return args.save