def get_dataset(opts):
    """ Dataset And Augmentation
    """
    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            #et.ExtResize(size=opts.crop_size),
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        else:
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download, transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False, transform=val_transform)

    if opts.dataset == 'cityscapes':
        train_transform = et.ExtCompose([
            #et.ExtResize( 512 ),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])

        val_transform = et.ExtCompose([
            et.ExtResize(256),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])

        train_dst = Cityscapes(root=opts.data_root,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val', transform=val_transform)
    return train_dst, val_dst
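A minimal usage sketch for the get_dataset helper above, assuming an argparse-style opts namespace; the field values below are hypothetical:

from argparse import Namespace
from torch.utils.data import DataLoader

# Hypothetical options; the field names mirror the attributes get_dataset reads.
opts = Namespace(dataset='voc', data_root='./datasets/data', year='2012',
                 crop_size=513, crop_val=False, download=False)

train_dst, val_dst = get_dataset(opts)
train_loader = DataLoader(train_dst, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dst, batch_size=1, shuffle=False)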
Example #2
def get_dataset(opts):
    """ Dataset And Augmentation
    """
    if opts.dataset == 'voc':
        train_transform = ExtCompose([
            ExtRandomScale((0.5, 2.0)),
            ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            ExtRandomHorizontalFlip(),
            ExtToTensor(),
            ExtNormalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
        ])

        if opts.crop_val:
            val_transform = ExtCompose([
                ExtResize(size=opts.crop_size),
                ExtCenterCrop(size=opts.crop_size),
                ExtToTensor(),
                ExtNormalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
            ])
        else:
            # no crop, batch size = 1
            val_transform = ExtCompose([
                ExtToTensor(),
                ExtNormalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
            ])
    
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='train', download=opts.download, transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year, image_set='val', download=False, transform=val_transform)
        
    if opts.dataset == 'cityscapes':
        train_transform = ExtCompose([
            ExtScale(0.5),
            ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            ExtRandomHorizontalFlip(),
            ExtToTensor(),
            ExtNormalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
        ])

        val_transform = ExtCompose([
            ExtScale(0.5),
            ExtToTensor(),
            ExtNormalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
        ])

        train_dst = Cityscapes(root=opts.data_root, split='train', download=opts.download, target_type='semantic',  transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root, split='test', target_type='semantic', download=False, transform=val_transform)
    return train_dst, val_dst
Example #3
    def func_per_iteration(self, data, device):
        img = data['data']
        label = data['label']
        name = data['fn']
        # label = label - 1

        pred = self.sliding_eval(img, config.eval_crop_size,
                                 config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes,
                                                       pred, label)
        results_dict = {
            'hist': hist_tmp,
            'labeled': labeled_tmp,
            'correct': correct_tmp
        }

        if self.save_path is not None:
            fn = name + '.png'
            pred, fn = Cityscapes.transform_label(pred, fn)
            cv2.imwrite(os.path.join(self.save_path, fn), pred)
            logger.info('Save the image ' + fn)

        if self.show_image:
            colors = self.dataset.get_class_colors()
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean, label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            cv2.waitKey(0)

        return results_dict
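The per-iteration dict above is typically reduced once the whole split has been processed. A sketch, assuming 'hist' is a num_classes x num_classes confusion matrix and 'labeled'/'correct' are pixel counts (not the repository's exact code):

import numpy as np

def reduce_results(results):
    # Sum confusion matrices and pixel counts over all iterations.
    hist = sum(r['hist'] for r in results)
    labeled = sum(r['labeled'] for r in results)
    correct = sum(r['correct'] for r in results)
    # Per-class IoU = TP / (TP + FP + FN); the epsilon guards empty classes.
    iou = np.diag(hist) / (hist.sum(0) + hist.sum(1) - np.diag(hist) + 1e-10)
    return np.nanmean(iou), correct / max(labeled, 1)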
Example #4
def main():
    model = torch.load('unet.pth').to(device)

    # dataset = Cityscapes(opt.root, split='val', resize=opt.resize, crop=opt.crop)
    dataset = Cityscapes(opt.root, resize=opt.resize, crop=opt.crop)

    inputs, labels = random.choice(dataset)

    inputs = inputs.unsqueeze(0)
    labels = labels.unsqueeze(0)

    inputs = inputs.to(device)
    labels = labels.to(device)

    outputs = model(inputs)
    outputs = outputs.detach()

    gt = classes_to_rgb(labels[0], dataset)
    seg = classes_to_rgb(outputs[0], dataset)

    fig, ax = plt.subplots(1, 3)

    ax[0].imshow(inputs[0].cpu().permute(1, 2, 0))
    ax[1].imshow(gt.cpu().permute(1, 2, 0))
    ax[2].imshow(seg.cpu().permute(1, 2, 0))

    plt.show()
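classes_to_rgb is not shown in this snippet; here is a plausible stand-in (name, signature, and palette handling are assumptions) that accepts either raw logits or an integer label map:

import torch

def classes_to_rgb_sketch(tensor, palette):
    # Map (C, H, W) logits or an (H, W) label map to a (3, H, W) uint8
    # RGB image using a per-class color palette.
    if tensor.dim() == 3:
        tensor = tensor.argmax(dim=0)  # logits -> class indices
    tensor = tensor.cpu()
    rgb = torch.zeros(3, *tensor.shape, dtype=torch.uint8)
    for cls_idx, color in enumerate(palette):
        mask = tensor == cls_idx
        for ch in range(3):
            rgb[ch][mask] = color[ch]
    return rgb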
Example #5
def inference(test_img_dir, test_gt_dir):
    net = ResUnet('resnet34')
    net.load_state_dict(torch.load('./resunet.pth'))
    net.eval()
    test_file_list = load_sem_seg(test_img_dir, test_gt_dir)
    test_dataset = Cityscapes(test_file_list)
    testloader = DataLoader(test_dataset, batch_size=1,
                            shuffle=False, num_workers=8, pin_memory=True)
    img, label = next(iter(testloader))
    pred = net(img)
    pred = pred.squeeze().detach().numpy()

    print(pred)
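The printed scores still carry a class dimension; the usual post-processing (an assumption, since the snippet stops at the print) is an argmax over that axis:

import numpy as np

label_map = np.argmax(pred, axis=0)  # (num_classes, H, W) -> (H, W) class ids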
Example #6
def main():
    if opt.is_continue:
        model = torch.load('unet.pth').to(device)
    else:
        model = Unet(19).to(device)

    dataset = Cityscapes(opt.root, resize=opt.resize, crop=opt.crop)
    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=1)

    criterion = BCELoss().to(device)
    optimizer = Adam(model.parameters(), lr=0.001)

    t_now = time.time()

    for epoch in range(opt.n_epochs):
        print('epoch {}'.format(epoch))
        for i, batch in enumerate(dataloader):
            inputs, labels = batch

            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(loss.item())
                print('time:', time.time() - t_now)
                t_now = time.time()

        print(loss.item())

        torch.save(model, 'unet.pth')
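One caveat worth flagging: BCELoss expects sigmoid probabilities and float targets with the same shape as the output, so the labels here must be one-hot float maps. For integer label maps of shape (N, H, W), the common alternative is CrossEntropyLoss over raw logits, as in this sketch:

import torch
import torch.nn as nn

logits = torch.randn(2, 19, 64, 64)          # (N, C, H, W) raw scores
labels = torch.randint(0, 19, (2, 64, 64))   # (N, H, W) integer class ids
loss = nn.CrossEntropyLoss()(logits, labels)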
Example #7
def main():
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)
    distill_criterion = nn.KLDivLoss()

    # data loader ###########################
    if config.is_test:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_eval_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    else:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }

    train_loader = get_train_loader(config, Cityscapes, test=config.is_test)

    # Model #######################################
    models = []
    evaluators = []
    testers = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None, state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ))
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        if arch_idx == 0 and len(config.arch_idx) > 1:
            partial = torch.load(
                os.path.join(config.teacher_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)
        elif config.is_eval:
            partial = torch.load(
                os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
            state = model.state_dict()
            pretrained_dict = {k: v for k, v in partial.items() if k in state}
            state.update(pretrained_dict)
            model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=None,
                                 show_image=False)
        evaluators.append(evaluator)
        tester = SegTester(Cityscapes(data_setting, 'test', None),
                           config.num_classes,
                           config.image_mean,
                           config.image_std,
                           model,
                           config.eval_scale_array,
                           config.eval_flip,
                           0,
                           out_idx=0,
                           config=config,
                           verbose=False,
                           save_path=None,
                           show_image=False)
        testers.append(tester)

        # Optimizer ###################################
        base_lr = config.lr
        if arch_idx == 1 or len(config.arch_idx) == 1:
            # optimize teacher solo OR student (w. distill from teacher)
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)
        models.append(model)

    # Cityscapes ###########################################
    if config.is_eval:
        logging.info(config.load_path)
        logging.info(config.eval_path)
        logging.info(config.save)
        # validation
        print("[validation...]")
        with torch.no_grad():
            valid_mIoUs = infer(models, evaluators, logger)
            for idx, arch_idx in enumerate(config.arch_idx):
                if arch_idx == 0:
                    logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx], 0)
                    logging.info("teacher's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
                else:
                    logger.add_scalar("mIoU/val_student", valid_mIoUs[idx], 0)
                    logging.info("student's valid_mIoU %.3f" %
                                 (valid_mIoUs[idx]))
        exit(0)

    tbar = tqdm(range(config.nepochs), ncols=80)
    for epoch in tbar:
        logging.info(config.load_path)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))
        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train_mIoUs = train(train_loader, models, ohem_criterion,
                            distill_criterion, optimizer, logger, epoch)
        torch.cuda.empty_cache()
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logger.add_scalar("mIoU/train_teacher", train_mIoUs[idx],
                                  epoch)
                logging.info("teacher's train_mIoU %.3f" % (train_mIoUs[idx]))
            else:
                logger.add_scalar("mIoU/train_student", train_mIoUs[idx],
                                  epoch)
                logging.info("student's train_mIoU %.3f" % (train_mIoUs[idx]))
        adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1,
                             config.nepochs)

        # validation
        if not config.is_test and ((epoch + 1) % 10 == 0 or epoch == 0):
            tbar.set_description("[Epoch %d/%d][validation...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                valid_mIoUs = infer(models, evaluators, logger)
                for idx, arch_idx in enumerate(config.arch_idx):
                    if arch_idx == 0:
                        logger.add_scalar("mIoU/val_teacher", valid_mIoUs[idx],
                                          epoch)
                        logging.info("teacher's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    else:
                        logger.add_scalar("mIoU/val_student", valid_mIoUs[idx],
                                          epoch)
                        logging.info("student's valid_mIoU %.3f" %
                                     (valid_mIoUs[idx]))
                    save(models[idx],
                         os.path.join(config.save, "weights%d.pt" % arch_idx))
        # test
        if config.is_test and (epoch + 1) >= 250 and (epoch + 1) % 10 == 0:
            tbar.set_description("[Epoch %d/%d][test...]" %
                                 (epoch + 1, config.nepochs))
            with torch.no_grad():
                test(epoch, models, testers, logger)

        for idx, arch_idx in enumerate(config.arch_idx):
            save(models[idx],
                 os.path.join(config.save, "weights%d.pt" % arch_idx))
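adjust_learning_rate is called above as adjust_learning_rate(base_lr, 0.992, optimizer, epoch + 1, config.nepochs) but its body is not shown; a plausible sketch, assuming a per-epoch exponential decay:

def adjust_learning_rate_sketch(base_lr, gamma, optimizer, epoch, total_epochs):
    # Assumed behavior: decay the base LR by gamma every epoch.
    # total_epochs is kept for signature parity; this sketch does not use it.
    lr = base_lr * (gamma ** epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr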
Example #8
                        default=False,
                        action='store_true')
    parser.add_argument('--save_path', '-p', default=None)

    args = parser.parse_args()
    all_dev = parse_devices(args.devices)

    if config.is_test:
        eval_source = config.test_source
    else:
        eval_source = config.eval_source

    mp_ctx = mp.get_context('spawn')
    network = CPNet(config.num_classes, criterion=None)
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': eval_source
    }
    dataset = Cityscapes(data_setting, 'val', None)

    with torch.no_grad():
        segmentor = SegEvaluator(dataset, config.num_classes,
                                 config.image_mean, config.image_std, network,
                                 config.eval_scale_array, config.eval_flip,
                                 all_dev, args.verbose, args.save_path,
                                 args.show_image)
        segmentor.run(config.snapshot_dir, args.epochs, config.val_log_file,
                      config.link_val_log_file)
Example #9
def train(network,
          train_img_dir,
          train_gt_dir,
          val_img_dir,
          val_gt_dir,
          lr,
          epochs,
          lbl_conversion=False):

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    net = ResUnet(network)
    net.to(device).train()

    optimizer = optim.SGD(net.head.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    if lbl_conversion:
        label_conversion(train_gt_dir)
        label_conversion(val_gt_dir)

    train_file_list = load_sem_seg(train_img_dir, train_gt_dir)
    train_dataset = Cityscapes(train_file_list)
    trainloader = DataLoader(train_dataset,
                             batch_size=25,
                             shuffle=True,
                             num_workers=8,
                             pin_memory=True)

    val_file_list = load_sem_seg(val_img_dir, val_gt_dir)
    val_dataset = Cityscapes(val_file_list)
    valloader = DataLoader(val_dataset,
                           batch_size=1,
                           shuffle=False,
                           num_workers=8,
                           pin_memory=True)
    check_point_path = './checkpoints/'

    print('Start of training')
    for epoch in range(epochs):
        running_loss = 0.0
        for idx, data in enumerate(trainloader):
            optimizer.zero_grad()
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            print(epoch, idx, loss.item())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if idx % 20 == 19:  # print every 20 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, idx + 1, running_loss / 20))
                running_loss = 0.0

    print('End of training')

    PATH = './resunet.pth'
    torch.save(net.state_dict(), PATH)
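valloader is constructed above but never consumed in the snippet; a minimal evaluation pass (an assumption, not the author's code) would look like:

net.eval()
with torch.no_grad():
    val_loss = 0.0
    for inputs, labels in valloader:
        outputs = net(inputs.to(device))
        val_loss += criterion(outputs, labels.to(device)).item()
print('val loss: %.3f' % (val_loss / len(valloader)))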
Example #10
def main():
    args, args_text = _parse_args()

    # dist init
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='tcp://127.0.0.1:26442',
                                         world_size=1,
                                         rank=0)
    config.device = 'cuda:%d' % args.local_rank
    torch.cuda.set_device(args.local_rank)
    args.world_size = torch.distributed.get_world_size()
    args.local_rank = torch.distributed.get_rank()
    logging.info("rank: {} world_size: {}".format(args.local_rank,
                                                  args.world_size))

    if args.local_rank == 0:
        create_exp_dir(config.save,
                       scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
        logger = SummaryWriter(config.save)
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout,
                            level=logging.INFO,
                            format=log_format,
                            datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", str(config))
    else:
        logger = None

    # preparation ################
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # data loader ###########################
    if config.is_test:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_eval_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }
    else:
        data_setting = {
            'img_root': config.img_root_folder,
            'gt_root': config.gt_root_folder,
            'train_source': config.train_source,
            'eval_source': config.eval_source,
            'test_source': config.test_source,
            'down_sampling': config.down_sampling
        }

    with open(config.json_file, 'r') as f:
        model_dict = json.loads(f.read())

    model = Network(model_dict["ops"],
                    model_dict["paths"],
                    model_dict["downs"],
                    model_dict["widths"],
                    model_dict["lasts"],
                    num_classes=config.num_classes,
                    layers=config.layers,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    stem_head_width=config.stem_head_width)

    if args.local_rank == 0:
        logging.info("net: " + str(model))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ),
                                verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        with open(os.path.join(config.save, 'args.yaml'), 'w') as f:
            f.write(args_text)

    model = model.cuda()
    init_weight(model,
                nn.init.kaiming_normal_,
                torch.nn.BatchNorm2d,
                config.bn_eps,
                config.bn_momentum,
                mode='fan_in',
                nonlinearity='relu')

    model = load_pretrain(model, config.model_path)

    # partial = torch.load(config.model_path)
    # state = model.state_dict()
    # pretrained_dict = {k: v for k, v in partial.items() if k in state}
    # state.update(pretrained_dict)
    # model.load_state_dict(state)

    eval_model = model
    evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                             config.num_classes,
                             config.image_mean,
                             config.image_std,
                             eval_model,
                             config.eval_scale_array,
                             config.eval_flip,
                             0,
                             out_idx=0,
                             config=config,
                             verbose=False,
                             save_path=None,
                             show_image=False,
                             show_prediction=False)
    tester = SegTester(Cityscapes(data_setting, 'test', None),
                       config.num_classes,
                       config.image_mean,
                       config.image_std,
                       eval_model,
                       config.eval_scale_array,
                       config.eval_flip,
                       0,
                       out_idx=0,
                       config=config,
                       verbose=False,
                       save_path=None,
                       show_prediction=False)

    # Cityscapes ###########################################
    logging.info(config.model_path)
    logging.info(config.save)
    with torch.no_grad():
        if config.is_test:
            # test
            print("[test...]")
            with torch.no_grad():
                test(0, model, tester, logger)
        else:
            # validation
            print("[validation...]")
            valid_mIoU = infer(model, evaluator, logger)
            logger.add_scalar("mIoU/val", valid_mIoU, 0)
            logging.info("Model valid_mIoU %.3f" % (valid_mIoU))
Example #11
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    tau = torch.ones(1) * args.tau
    tau = tau.cuda(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params, False)
    elif args.model == 'DeepLabVGG':
        model = DeeplabVGG(pretrained=True, num_classes=args.num_classes)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    weak_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Resize(1024),
        transforms.ToTensor(),
        #         transforms.Normalize(mean, std),
        #         RandomCrop(768)
    ])

    target_transform = transforms.Compose([
        #         transforms.RandomCrop(32, 4),
        #         transforms.RandomRotation(30),
        #         transforms.Normalize(mean, std)
        #         transforms.Resize(1024),
        #         transforms.ToTensor(),
        #         RandomCrop(768)
    ])

    label_set = GTA5(
        root=args.data_dir,
        num_cls=19,
        split='all',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size,
        #              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )
    unlabel_set = Cityscapes(
        root=args.data_dir_target,
        split=args.set,
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                              crop_transform=RandomCrop(int(768*(args.scale/1024))),
    )

    test_set = Cityscapes(
        root=args.data_dir_target,
        split='val',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
        #                       crop_transform=RandomCrop(768)
    )

    label_loader = data.DataLoader(label_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=False)

    unlabel_loader = data.DataLoader(unlabel_set,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=False)

    test_loader = data.DataLoader(test_set,
                                  batch_size=2,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=False)

    # implement model.optim_parameters(args) to handle different models' lr setting
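    # A sketch of what optim_parameters conventionally returns (an assumption,
    # not this repository's code): parameter groups with distinct learning rates,
    #   def optim_parameters(self, args):
    #       return [{'params': self.get_1x_lr_params(), 'lr': args.learning_rate},
    #               {'params': self.get_10x_lr_params(), 'lr': 10 * args.learning_rate}]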

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    [model, model_D1,
     model_D2], [optimizer, optimizer_D1, optimizer_D2
                 ] = amp.initialize([model, model_D1, model_D2],
                                    [optimizer, optimizer_D1, optimizer_D2],
                                    opt_level="O1",
                                    num_losses=7)

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = Interpolate(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = Interpolate(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_test = Interpolate(size=(input_size_target[1],
                                    input_size_target[0]),
                              mode='bilinear',
                              align_corners=True)
    #     interp_test = Interpolate(size=(1024, 2048), mode='bilinear', align_corners=True)

    normalize_transform = transforms.Compose([
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    # labels for adversarial training
    source_label = 0
    target_label = 1

    max_mIoU = 0

    total_loss_seg_value1 = []
    total_loss_adv_target_value1 = []
    total_loss_D_value1 = []
    total_loss_con_value1 = []

    total_loss_seg_value2 = []
    total_loss_adv_target_value2 = []
    total_loss_D_value2 = []
    total_loss_con_value2 = []

    hist = np.zeros((num_cls, num_cls))

    #     for i_iter in range(args.num_steps):
    for i_iter, (batch, batch_un) in enumerate(
            zip(roundrobin_infinite(label_loader),
                roundrobin_infinite(unlabel_loader))):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_con_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_con_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D1.parameters():
            param.requires_grad = False

        for param in model_D2.parameters():
            param.requires_grad = False

        # train with source

        images, labels = batch
        images_orig = images
        images = transform_batch(images, normalize_transform)
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model(images)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization
        loss = loss / args.iter_size

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

#         loss.backward()
        loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
        loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

        # train with target

        images_tar, labels_tar = batch_un
        images_tar_orig = images_tar
        images_tar = transform_batch(images_tar, normalize_transform)
        images_tar = Variable(images_tar).cuda(args.gpu)

        pred_target1, pred_target2 = model(images_tar)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_adv_target1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_adv_target2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()
        loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
        ) / args.iter_size
        loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
        ) / args.iter_size

        # train with consistency loss
        # unsupervise phase
        policies = RandAugment().get_batch_policy(args.batch_size)
        rand_p1 = np.random.random(size=args.batch_size)
        rand_p2 = np.random.random(size=args.batch_size)
        random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2])

        images_aug = aug_batch_tensor(images_tar_orig, policies, rand_p1,
                                      rand_p2, random_dir)

        images_aug_orig = images_aug
        images_aug = transform_batch(images_aug, normalize_transform)
        images_aug = Variable(images_aug).cuda(args.gpu)

        pred_target_aug1, pred_target_aug2 = model(images_aug)
        pred_target_aug1 = interp_target(pred_target_aug1)
        pred_target_aug2 = interp_target(pred_target_aug2)

        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        max_pred1, psuedo_label1 = torch.max(F.softmax(pred_target1, dim=1), 1)
        max_pred2, psuedo_label2 = torch.max(F.softmax(pred_target2, dim=1), 1)

        psuedo_label1 = psuedo_label1.cpu().numpy().astype(np.float32)
        psuedo_label1_thre = psuedo_label1.copy()
        psuedo_label1_thre[(max_pred1 < tau).cpu().numpy().astype(
            bool)] = 255  # below-threshold pixels get the 255 ignore label
        psuedo_label1_thre = aug_batch_numpy(psuedo_label1_thre, policies,
                                             rand_p1, rand_p2, random_dir)
        psuedo_label2 = psuedo_label2.cpu().numpy().astype(np.float32)
        psuedo_label2_thre = psuedo_label2.copy()
        psuedo_label2_thre[(max_pred2 < tau).cpu().numpy().astype(
            bool)] = 255  # below-threshold pixels get the 255 ignore label
        psuedo_label2_thre = aug_batch_numpy(psuedo_label2_thre, policies,
                                             rand_p1, rand_p2, random_dir)

        psuedo_label1_thre = Variable(psuedo_label1_thre).cuda(args.gpu)
        psuedo_label2_thre = Variable(psuedo_label2_thre).cuda(args.gpu)

        if (psuedo_label1_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con1 = loss_calc(pred_target_aug1, psuedo_label1_thre,
                                  args.gpu)
            loss_con_value1 += loss_con1.data.cpu().numpy() / args.iter_size
        else:
            loss_con1 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        if (psuedo_label2_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con2 = loss_calc(pred_target_aug2, psuedo_label2_thre,
                                  args.gpu)
            loss_con_value2 += loss_con2.data.cpu().numpy() / args.iter_size
        else:
            loss_con2 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2
        # proper normalization
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss:
            scaled_loss.backward()
#         loss.backward()

        # train D

        # bring back requires_grad
        for param in model_D1.parameters():
            param.requires_grad = True

        for param in model_D2.parameters():
            param.requires_grad = True

        # train with source
        pred1 = pred1.detach()
        pred2 = pred2.detach()

        D_out1 = model_D1(F.softmax(pred1, dim=1))
        D_out2 = model_D2(F.softmax(pred2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=3) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=4) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        # train with target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_D1 = bce_loss(
            D_out1,
            Variable(
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D2 = bce_loss(
            D_out2,
            Variable(
                torch.FloatTensor(
                    D_out2.data.size()).fill_(target_label)).cuda(args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=5) as scaled_loss:
            scaled_loss.backward()
#         loss_D1.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=6) as scaled_loss:
            scaled_loss.backward()
#         loss_D2.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_con_value1,
                    loss_con_value2))

        total_loss_seg_value1.append(loss_seg_value1)
        total_loss_adv_target_value1.append(loss_adv_target_value1)
        total_loss_D_value1.append(loss_D_value1)
        total_loss_con_value1.append(loss_con_value1)

        total_loss_seg_value2.append(loss_seg_value2)
        total_loss_adv_target_value2.append(loss_adv_target_value2)
        total_loss_D_value2.append(loss_D_value2)
        total_loss_con_value2.append(loss_con_value2)

        hist += fast_hist(
            labels.cpu().numpy().flatten().astype(int),
            torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int),
            num_cls)

        if i_iter % 10 == 0:
            print('({}/{})'.format(i_iter + 1, int(args.num_steps)))
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(np.mean(iu)))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_train_acc = acc_overall
            avg_train_loss_seg1 = np.mean(total_loss_seg_value1)
            avg_train_loss_adv1 = np.mean(total_loss_adv_target_value1)
            avg_train_loss_dis1 = np.mean(total_loss_D_value1)
            avg_train_loss_con1 = np.mean(total_loss_con_value1)
            avg_train_loss_seg2 = np.mean(total_loss_seg_value2)
            avg_train_loss_adv2 = np.mean(total_loss_adv_target_value2)
            avg_train_loss_dis2 = np.mean(total_loss_D_value2)
            avg_train_loss_con2 = np.mean(total_loss_con_value2)

            print('avg_train_acc      :', avg_train_acc)
            print('avg_train_loss_seg1 :', avg_train_loss_seg1)
            print('avg_train_loss_adv1 :', avg_train_loss_adv1)
            print('avg_train_loss_dis1 :', avg_train_loss_dis1)
            print('avg_train_loss_con1 :', avg_train_loss_con1)
            print('avg_train_loss_seg2 :', avg_train_loss_seg2)
            print('avg_train_loss_adv2 :', avg_train_loss_adv2)
            print('avg_train_loss_dis2 :', avg_train_loss_dis2)
            print('avg_train_loss_con2 :', avg_train_loss_con2)

            writer['train'].add_scalar('log/mIoU', mIoU, i_iter)
            writer['train'].add_scalar('log/acc', avg_train_acc, i_iter)
            writer['train'].add_scalar('log1/loss_seg', avg_train_loss_seg1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_adv', avg_train_loss_adv1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_dis', avg_train_loss_dis1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_con', avg_train_loss_con1,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_seg', avg_train_loss_seg2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_adv', avg_train_loss_adv2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_dis', avg_train_loss_dis2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_con', avg_train_loss_con2,
                                       i_iter)

            hist = np.zeros((num_cls, num_cls))
            total_loss_seg_value1 = []
            total_loss_adv_target_value1 = []
            total_loss_D_value1 = []
            total_loss_con_value1 = []
            total_loss_seg_value2 = []
            total_loss_adv_target_value2 = []
            total_loss_D_value2 = []
            total_loss_con_value2 = []

            fig = plt.figure(figsize=(15, 15))

            labels = labels[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(331)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(337)
            images = images_orig[0].cpu().numpy().transpose((1, 2, 0))
            #             images += IMG_MEAN
            ax.imshow(images)
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(334)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            labels_tar = labels_tar[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(332)
            ax.imshow(print_palette(Image.fromarray(labels_tar).convert('L')))
            ax.axis("off")
            ax.set_title('tar_labels')

            ax = fig.add_subplot(338)
            ax.imshow(images_tar_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('tar_datas')

            _, pred_target2 = torch.max(pred_target2, dim=1)
            pred_target2 = pred_target2[0].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(335)
            ax.imshow(print_palette(
                Image.fromarray(pred_target2).convert('L')))
            ax.axis("off")
            ax.set_title('tar_predicts')

            print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0],
                  'random_dir', random_dir[0])

            psuedo_label2_thre = psuedo_label2_thre[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(333)
            ax.imshow(
                print_palette(
                    Image.fromarray(psuedo_label2_thre).convert('L')))
            ax.axis("off")
            ax.set_title('psuedo_labels')

            ax = fig.add_subplot(339)
            ax.imshow(images_aug_orig[0].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('aug_datas')

            _, pred_target_aug2 = torch.max(pred_target_aug2, dim=1)
            pred_target_aug2 = pred_target_aug2[0].cpu().numpy().astype(
                np.float32)
            ax = fig.add_subplot(336)
            ax.imshow(
                print_palette(Image.fromarray(pred_target_aug2).convert('L')))
            ax.axis("off")
            ax.set_title('aug_predicts')

            #             plt.show()
            writer['train'].add_figure('image/',
                                       fig,
                                       global_step=i_iter,
                                       close=True)

        if i_iter % 500 == 0:
            loss1 = []
            loss2 = []
            for test_i, batch in enumerate(test_loader):

                images, labels = batch
                images_orig = images
                images = transform_batch(images, normalize_transform)
                images = Variable(images).cuda(args.gpu)

                pred1, pred2 = model(images)
                pred1 = interp_test(pred1)
                pred1 = pred1.detach()
                pred2 = interp_test(pred2)
                pred2 = pred2.detach()

                loss_seg1 = loss_calc(pred1, labels, args.gpu)
                loss_seg2 = loss_calc(pred2, labels, args.gpu)
                loss1.append(loss_seg1.item())
                loss2.append(loss_seg2.item())

                hist += fast_hist(
                    labels.cpu().numpy().flatten().astype(int),
                    torch.argmax(pred2,
                                 dim=1).cpu().numpy().flatten().astype(int),
                    num_cls)

            print('test')
            fig = plt.figure(figsize=(15, 15))
            labels = labels[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(311)
            ax.imshow(print_palette(Image.fromarray(labels).convert('L')))
            ax.axis("off")
            ax.set_title('labels')

            ax = fig.add_subplot(313)
            ax.imshow(images_orig[-1].cpu().numpy().transpose((1, 2, 0)))
            ax.axis("off")
            ax.set_title('datas')

            _, pred2 = torch.max(pred2, dim=1)
            pred2 = pred2[-1].cpu().numpy().astype(np.float32)
            ax = fig.add_subplot(312)
            ax.imshow(print_palette(Image.fromarray(pred2).convert('L')))
            ax.axis("off")
            ax.set_title('predicts')

            #             plt.show()

            writer['test'].add_figure('test_image/',
                                      fig,
                                      global_step=i_iter,
                                      close=True)

            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            mIoU = np.mean(iu)
            per_class = [[classes[i], acc] for i, acc in list(enumerate(iu))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls IoU :' + ('\n{:>14s} : {}') * 19).format(*per_class))
            print('mIoU : {:0.2f}'.format(mIoU))
            print('fwIoU : {:0.2f}'.format(fwIU))
            print('pixel acc : {:0.2f}'.format(acc_overall))
            per_class = [[classes[i], acc]
                         for i, acc in list(enumerate(acc_percls))]
            per_class = np.array(per_class).flatten()
            print(
                ('per cls acc :' + ('\n{:>14s} : {}') * 19).format(*per_class))

            avg_test_loss1 = np.mean(loss1)
            avg_test_loss2 = np.mean(loss2)
            avg_test_acc = acc_overall
            print('avg_test_loss1 :', avg_test_loss1)
            print('avg_test_loss2 :', avg_test_loss2)
            print('avg_test_acc   :', avg_test_acc)
            writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter)
            writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter)
            writer['test'].add_scalar('log/acc', avg_test_acc, i_iter)
            writer['test'].add_scalar('log/mIoU', mIoU, i_iter)

            hist = np.zeros((num_cls, num_cls))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if max_mIoU < mIoU:
            max_mIoU = mIoU
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D2.pth'))
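fast_hist, used above to accumulate the confusion matrix, is not shown; a conventional formulation (a sketch, not necessarily this repository's exact code):

import numpy as np

def fast_hist_sketch(label_true, label_pred, n_cls):
    # Accumulate an n_cls x n_cls confusion matrix, ignoring out-of-range
    # ids such as the 255 'don't care' label.
    mask = (label_true >= 0) & (label_true < n_cls)
    return np.bincount(
        n_cls * label_true[mask].astype(int) + label_pred[mask].astype(int),
        minlength=n_cls ** 2,
    ).reshape(n_cls, n_cls)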
Example #12
def main(pretrain=True):
    config.save = 'search-{}-{}'.format(config.save,
                                        time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))
    logger = SummaryWriter(config.save)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(config.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    assert isinstance(pretrain, (bool, str))
    update_arch = True
    if pretrain is True:
        update_arch = False
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
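    # OHEM: pixels whose predicted ground-truth-class probability falls below
    # thresh are treated as hard examples, and at least min_kept of them
    # always contribute to the loss.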
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))
    ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255,
                                            thresh=0.7,
                                            min_kept=min_kept,
                                            use_weight=False)

    # Model #######################################
    model = Network(config.num_classes,
                    config.layers,
                    ohem_criterion,
                    Fch=config.Fch,
                    width_mult_list=config.width_mult_list,
                    prun_modes=config.prun_modes,
                    stem_head_width=config.stem_head_width)
    flops, params = profile(model,
                            inputs=(torch.randn(1, 3, 1024, 2048), ),
                            verbose=False)
    logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
    model = model.cuda()
    if isinstance(pretrain, str):
        partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0')
        state = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in partial.items()
            if k in state and state[k].size() == partial[k].size()
        }
        state.update(pretrained_dict)
        model.load_state_dict(state)
    else:
        init_weight(model,
                    nn.init.kaiming_normal_,
                    nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    architect = Architect(model, config)

    # Optimizer ###################################
    base_lr = config.lr
    parameters = []
    parameters += list(model.stem.parameters())
    parameters += list(model.cells.parameters())
    parameters += list(model.refine32.parameters())
    parameters += list(model.refine16.parameters())
    parameters += list(model.head0.parameters())
    parameters += list(model.head1.parameters())
    parameters += list(model.head2.parameters())
    parameters += list(model.head02.parameters())
    parameters += list(model.head12.parameters())
    optimizer = torch.optim.SGD(parameters,
                                lr=base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    # lr policy ##############################
    lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978)
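    # Multiplies the learning rate by 0.978 after every epoch
    # (lr_policy.step() is called once per epoch in the training loop below).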

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }
    train_loader_model = get_train_loader(config,
                                          Cityscapes,
                                          portion=config.train_portion)
    train_loader_arch = get_train_loader(config,
                                         Cityscapes,
                                         portion=config.train_portion - 1)
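    # train_portion splits the training images between weight updates and
    # architecture updates; the negative value (train_portion - 1) presumably
    # selects the complementary slice for the architect.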

    evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                             config.num_classes,
                             config.image_mean,
                             config.image_std,
                             model,
                             config.eval_scale_array,
                             config.eval_flip,
                             0,
                             config=config,
                             verbose=False,
                             save_path=None,
                             show_image=False)

    if update_arch:
        for idx in range(len(config.latency_weight)):
            logger.add_scalar("arch/latency_weight%d" % idx,
                              config.latency_weight[idx], 0)
            logging.info("arch_latency_weight%d = " % idx +
                         str(config.latency_weight[idx]))

    tbar = tqdm(range(config.nepochs), ncols=80)
    valid_mIoU_history = []
    FPSs_history = []
    latency_supernet_history = []
    latency_weight_history = []
    valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"]
    arch_names = {0: "teacher", 1: "student"}
    for epoch in tbar:
        logging.info(pretrain)
        logging.info(config.save)
        logging.info("lr: " + str(optimizer.param_groups[0]['lr']))

        logging.info("update arch: " + str(update_arch))

        # training
        tbar.set_description("[Epoch %d/%d][train...]" %
                             (epoch + 1, config.nepochs))
        train(pretrain,
              train_loader_model,
              train_loader_arch,
              model,
              architect,
              ohem_criterion,
              optimizer,
              lr_policy,
              logger,
              epoch,
              update_arch=update_arch)
        torch.cuda.empty_cache()
        lr_policy.step()

        # validation
        tbar.set_description("[Epoch %d/%d][validation...]" %
                             (epoch + 1, config.nepochs))
        with torch.no_grad():
            if pretrain is True:
                model.prun_mode = "min"
                valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False)
                for i in range(5):
                    logger.add_scalar('mIoU/val_min_%s' % valid_names[i],
                                      valid_mIoUs[i], epoch)
                    logging.info("Epoch %d: valid_mIoU_min_%s %.3f" %
                                 (epoch, valid_names[i], valid_mIoUs[i]))
                if len(model._width_mult_list) > 1:
                    model.prun_mode = "max"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar('mIoU/val_max_%s' % valid_names[i],
                                          valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_max_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
                    model.prun_mode = "random"
                    valid_mIoUs = infer(epoch,
                                        model,
                                        evaluator,
                                        logger,
                                        FPS=False)
                    for i in range(5):
                        logger.add_scalar(
                            'mIoU/val_random_%s' % valid_names[i],
                            valid_mIoUs[i], epoch)
                        logging.info("Epoch %d: valid_mIoU_random_%s %.3f" %
                                     (epoch, valid_names[i], valid_mIoUs[i]))
            else:
                valid_mIoUss = []
                FPSs = []
                model.prun_mode = None
                for idx in range(len(model._arch_names)):
                    # arch_idx
                    model.arch_idx = idx
                    valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator,
                                                    logger)
                    valid_mIoUss.append(valid_mIoUs)
                    FPSs.append([fps0, fps1])
                    for i in range(5):
                        # preds
                        logger.add_scalar(
                            'mIoU/val_%s_%s' %
                            (arch_names[idx], valid_names[i]), valid_mIoUs[i],
                            epoch)
                        logging.info("Epoch %d: valid_mIoU_%s_%s %.3f" %
                                     (epoch, arch_names[idx], valid_names[i],
                                      valid_mIoUs[i]))
                    if config.latency_weight[idx] > 0:
                        logger.add_scalar(
                            'Objective/val_%s_8s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[3], 1000. / fps0),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_8s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[3], 1000. / fps0)))
                        logger.add_scalar(
                            'Objective/val_%s_16s_32s' % arch_names[idx],
                            objective_acc_lat(valid_mIoUs[4], 1000. / fps1),
                            epoch)
                        logging.info(
                            "Epoch %d: Objective_%s_16s_32s %.3f" %
                            (epoch, arch_names[idx],
                             objective_acc_lat(valid_mIoUs[4], 1000. / fps1)))
                valid_mIoU_history.append(valid_mIoUss)
                FPSs_history.append(FPSs)
                if update_arch:
                    latency_supernet_history.append(architect.latency_supernet)
                latency_weight_history.append(architect.latency_weight)

        save(model, os.path.join(config.save, 'weights.pt'))
        if isinstance(pretrain, str):
            # contains arch_param names: {"alphas": alphas, "betas": betas, "gammas": gammas, "ratios": ratios}
            for idx, arch_name in enumerate(model._arch_names):
                state = {}
                for name in arch_name['alphas']:
                    state[name] = getattr(model, name)
                for name in arch_name['betas']:
                    state[name] = getattr(model, name)
                for name in arch_name['ratios']:
                    state[name] = getattr(model, name)
                state["mIoU02"] = valid_mIoUs[3]
                state["mIoU12"] = valid_mIoUs[4]
                if pretrain is not True:
                    state["latency02"] = 1000. / fps0
                    state["latency12"] = 1000. / fps1
                torch.save(
                    state,
                    os.path.join(config.save, "arch_%d_%d.pt" % (idx, epoch)))
                torch.save(state,
                           os.path.join(config.save, "arch_%d.pt" % (idx)))

        if update_arch:
            for idx in range(len(config.latency_weight)):
                if config.latency_weight[idx] > 0:
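                    # Bang-bang control of the latency penalty: halve the
                    # weight once either head reaches its FPS target, double
                    # it when either head drops below the minimum.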
                    if (int(FPSs[idx][0] >= config.FPS_max[idx]) +
                            int(FPSs[idx][1] >= config.FPS_max[idx])) >= 1:
                        architect.latency_weight[idx] /= 2
                    elif (int(FPSs[idx][0] <= config.FPS_min[idx]) +
                          int(FPSs[idx][1] <= config.FPS_min[idx])) > 0:
                        architect.latency_weight[idx] *= 2
                    logger.add_scalar(
                        "arch/latency_weight_%s" % arch_names[idx],
                        architect.latency_weight[idx], epoch + 1)
                    logging.info("arch_latency_weight_%s = " %
                                 arch_names[idx] +
                                 str(architect.latency_weight[idx]))
Example No. 13
def main(model_name):
    # TODO: parse args.
    n_classes = 19
    #batch_size = 2
    batch_size = 1  #24
    n_workers = 12
    n_semantic_pretrain = 0  # 500 # First train only on semantics.
    n_epochs = 500
    validation_step = 15
    # TODO: implement resize as pil_transform
    resize = None  # (256, 512)
    cityscapes_directory = "/home/<someuser>/cityscapes"
    output_directory = "tmp/"
    # Comment out the next line once you have set all directories.
    raise ValueError("Please set the input/output directories.")
    checkpoint = None
    #checkpoint = (
    #        "weights/...pth",
    #        <fill in epoch>)

    # --- Setup loss functions.
    classification_loss = nn.CrossEntropyLoss(ignore_index=255)
    # 'elementwise_mean' was renamed to 'mean' in PyTorch 1.0.
    regression_loss = nn.MSELoss(reduction='mean')

    print("--- Load model.")
    if model_name == 'DRNRegressionDownsampled':
        classification_loss = None
        regression_loss = nn.MSELoss(reduction='mean')
        dataset_kwargs = {
            'pil_transforms':
            None,
            'gt_pil_transforms': [ModeDownsample(8)],
            'fit_gt_pil_transforms':
            [transforms.Resize(size=(784 // 8, 1792 // 8), interpolation=2)],
            'input_transforms': [
                transforms.Normalize(mean=[0.290101, 0.328081, 0.286964],
                                     std=[0.182954, 0.186566, 0.184475])
            ],
            'tensor_transforms':
            None
        }
        model = DRNRegressionDownsampled(
            model_name='drn_d_22',
            classes=n_classes,
            pretrained_dict=torch.load('./weights/drn_d_22_cityscapes.pth'))
        model.cuda()
        parameters = model.parameters()
    else:
        raise ValueError("Model \"{}\" not found!".format(model_name))

    optimizer = optim.Adam(parameters)
    start_epoch = 0

    if checkpoint is not None:
        print("Loading from checkpoint {}".format(checkpoint))
        model.load_state_dict(torch.load(checkpoint[0]))
        optimizer.load_state_dict(torch.load(checkpoint[1]))
        start_epoch = checkpoint[2] + 1

    print("--- Setup dataset and dataloaders.")
    train_set = Cityscapes(data_split='subtrain',
                           cityscapes_directory=cityscapes_directory,
                           **dataset_kwargs)
    train_loader = data.DataLoader(train_set,
                                   batch_size=batch_size,
                                   num_workers=n_workers,
                                   shuffle=True)

    val_set = Cityscapes(data_split='subtrainval',
                         cityscapes_directory=cityscapes_directory,
                         **dataset_kwargs)
    val_loader = data.DataLoader(val_set,
                                 batch_size=batch_size,
                                 num_workers=n_workers)

    # Sample 10 validation indices for visualization.
    #validation_idxs = np.random.choice(np.arange(len(val_set)),
    #                                   size=min(9, len(val_set)),
    #                                   replace=False)
    # Nah, let's pick them ourselves for now.
    validation_idxs = [17, 241, 287, 304, 123, 458, 1, 14, 139, 388]

    if True:
        print("--- Setup visual validation.")
        # Save them for comparison.
        check_mkdir('{}/validationimgs'.format(output_directory))
        check_mkdir('{}/offsets_gt'.format(output_directory))
        check_mkdir('{}/semantic_gt'.format(output_directory))
        for validation_idx in validation_idxs:
            img_pil, _, _ = val_set.load_fit_gt_PIL_images(validation_idx)
            img, semantic_gt, offset_gt = val_set[validation_idx]
            img_pil.save("{}/validationimgs/id{:03}.png".format(
                output_directory, validation_idx))
            visualize_semantics(
                img_pil, semantic_gt,
                "{}/semantic_gt/id{:03}".format(output_directory,
                                                validation_idx))
            visualize_positionplusoffset(offset_gt,
                                         "{}/offsets_gt/id{:03}_mean".format(
                                             output_directory, validation_idx),
                                         groundtruth=offset_gt)
            visualize_offsethsv(
                offset_gt,
                "{}/offsets_gt/id{:03}".format(output_directory,
                                               validation_idx))

    print("--- Training.")
    rlosses = []
    closses = []
    for epoch in range(start_epoch, n_epochs):
        model.train()
        total_rloss = 0
        total_closs = 0
        for batch_idx, batch_data in enumerate(train_loader):
            img = batch_data[0].cuda()
            semantic_gt = batch_data[1].cuda()
            instance_offset_gt = batch_data[2].cuda()
            del batch_data

            optimizer.zero_grad()
            outputs = model(img)

            batch_rloss = 0
            batch_closs = 0
            loss = 0
            closs = 0
            rloss = 0
            if regression_loss is not None:
                predicted_offset = outputs[:, -2:]

                rloss = regression_loss(predicted_offset, instance_offset_gt)

                # float(), not int(): int() would truncate losses below 1 to 0.
                batch_rloss += float(rloss.detach().cpu())
                total_rloss += batch_rloss

                loss += rloss

            if classification_loss is not None:
                closs = classification_loss(outputs[:, :n_classes],
                                            semantic_gt)

                batch_closs += float(closs.detach().cpu())
                total_closs += batch_closs

                loss += closs

            loss.backward()
            optimizer.step()

            if batch_idx % 30 == 0 and batch_idx != 0:
                print('\t[batch {}/{}], [batch mean - closs {:5}, rloss {:5}]'.
                      format(batch_idx, len(train_loader),
                             batch_closs / img.size(0),
                             batch_rloss / img.size(0)))
        del img, semantic_gt, instance_offset_gt, outputs, rloss, closs, loss
        total_closs /= len(train_set)
        total_rloss /= len(train_set)

        print('[epoch {}], [mean train - closs {:5}, rloss {:5}]'.format(
            epoch, total_closs, total_rloss))
        rlosses.append(total_rloss)
        closses.append(total_closs)
        plt.plot(np.arange(start_epoch, epoch + 1), rlosses)
        plt.savefig('{}/rlosses.svg'.format(output_directory))
        plt.close('all')
        plt.plot(np.arange(start_epoch, epoch + 1), closses)
        plt.savefig('{}/closses.svg'.format(output_directory))
        plt.close('all')
        plt.plot(np.arange(start_epoch, epoch + 1), np.add(rlosses, closses))
        plt.savefig('{}/losses.svg'.format(output_directory))
        plt.close('all')

        # --- Visual validation.
        if (epoch % validation_step) == 0:
            # Save model parameters.
            check_mkdir('{}/models'.format(output_directory))
            torch.save(
                model.state_dict(),
                '{}/models/Net_epoch{}.pth'.format(output_directory, epoch))
            torch.save(
                optimizer.state_dict(),
                '{}/models/Adam_epoch{}.pth'.format(output_directory, epoch))

            # Visualize validation imgs.
            check_mkdir('{}/offsets'.format(output_directory))
            check_mkdir('{}/offsets/means'.format(output_directory))
            check_mkdir('{}/semantics'.format(output_directory))
            check_mkdir('{}/semantics/overlay'.format(output_directory))
            model.eval()
            for validation_idx in validation_idxs:
                img_pil, _, _ = val_set.load_PIL_images(validation_idx)
                img, _, offset_gt = val_set[validation_idx]
                img = img.unsqueeze(0).cuda()
                with torch.no_grad():
                    outputs = model(img)
                epoch_filename = 'id{:03}_epoch{:05}'\
                                 .format(validation_idx, epoch)
                if classification_loss is not None:
                    visualize_semantics(
                        img_pil, outputs,
                        "{}/semantics/{}".format(output_directory,
                                                 epoch_filename),
                        "{}/semantics/overlay/{}".format(
                            output_directory, epoch_filename))
                if regression_loss is not None:
                    visualize_offsethsv(
                        outputs.detach(),
                        "{}/offsets/{}".format(output_directory,
                                               epoch_filename))
                    visualize_positionplusoffset(outputs,
                                                 "{}/offsets/means/{}".format(
                                                     output_directory,
                                                     epoch_filename),
                                                 groundtruth=offset_gt)
Example No. 14
def get_dataset(opts):
    """ Dataset And Augmentation
    """
    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            #et.ExtResize(size=opts.crop_size),
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size),
                             pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        else:
            val_transform = et.ExtCompose([
                et.ExtToTensor(),
                et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
        train_dst = VOCSegmentation(root=opts.data_root,
                                    year=opts.year,
                                    image_set='train',
                                    download=opts.download,
                                    transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root,
                                  year=opts.year,
                                  image_set='val',
                                  download=False,
                                  transform=val_transform)

    if opts.dataset == 'cityscapes':
        train_transform = et.ExtCompose([
            #et.ExtResize( 512 ),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])

        val_transform = et.ExtCompose([
            #et.ExtResize( 512 ),
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ])

        train_dst = Cityscapes(root=opts.data_root,
                               split='train',
                               transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val',
                             transform=val_transform)

    if opts.dataset == 'weedcluster':
        train_dst = WeedClusterDataset(root=opts.data_root, split='train')
        val_dst = WeedClusterDataset(root=opts.data_root, split='val')

    if opts.dataset == 'cloudshadow':
        train_dst = CloudShadowDataset(root=opts.data_root, split='train')
        val_dst = CloudShadowDataset(root=opts.data_root, split='val')

    if opts.dataset == 'doubleplant':
        train_dst = DoublePlantDataset(root=opts.data_root, split='train')
        val_dst = DoublePlantDataset(root=opts.data_root, split='val')

    if opts.dataset == 'planterskip':
        train_dst = PlanterSkipDataset(root=opts.data_root, split='train')
        val_dst = PlanterSkipDataset(root=opts.data_root, split='val')

    if opts.dataset == 'standingwater':
        train_dst = StandingWaterDataset(root=opts.data_root, split='train')
        val_dst = StandingWaterDataset(root=opts.data_root, split='val')

    if opts.dataset == 'waterway':
        train_dst = WaterwayDataset(root=opts.data_root, split='train')
        val_dst = WaterwayDataset(root=opts.data_root, split='val')

    return train_dst, val_dst
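
A minimal usage sketch (not part of the original listing): the `opts` namespace
below is an illustrative assumption, setting only the attributes that
get_dataset reads, and the data_root path is a placeholder.

from argparse import Namespace

# Hypothetical options object; any argparse.Namespace with these fields works.
opts = Namespace(dataset='cityscapes', data_root='./datasets/data',
                 crop_size=768, crop_val=False, year='2012', download=False)
train_dst, val_dst = get_dataset(opts)
print('train: %d samples, val: %d samples' % (len(train_dst), len(val_dst)))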
Example No. 15
def main():
    create_exp_dir(config.save,
                   scripts_to_save=glob.glob('*.py') + glob.glob('*.sh'))

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    logging.info("args = %s", str(config))
    # preparation ################
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # config network and criterion ################
    min_kept = int(config.batch_size * config.image_height *
                   config.image_width // (16 * config.gt_down_sampling**2))

    # data loader ###########################
    data_setting = {
        'img_root': config.img_root_folder,
        'gt_root': config.gt_root_folder,
        'train_source': config.train_source,
        'eval_source': config.eval_source,
        'down_sampling': config.down_sampling
    }

    # Model #######################################
    models = []
    evaluators = []
    lasts = []
    for idx, arch_idx in enumerate(config.arch_idx):
        if config.load_epoch == "last":
            state = torch.load(
                os.path.join(config.load_path, "arch_%d.pt" % arch_idx))
        else:
            state = torch.load(
                os.path.join(
                    config.load_path,
                    "arch_%d_%d.pt" % (arch_idx, int(config.load_epoch))))

        model = Network([
            state["alpha_%d_0" % arch_idx].detach(),
            state["alpha_%d_1" % arch_idx].detach(),
            state["alpha_%d_2" % arch_idx].detach()
        ], [
            None, state["beta_%d_1" % arch_idx].detach(),
            state["beta_%d_2" % arch_idx].detach()
        ], [
            state["ratio_%d_0" % arch_idx].detach(),
            state["ratio_%d_1" % arch_idx].detach(),
            state["ratio_%d_2" % arch_idx].detach()
        ],
                        num_classes=config.num_classes,
                        layers=config.layers,
                        Fch=config.Fch,
                        width_mult_list=config.width_mult_list,
                        stem_head_width=config.stem_head_width[idx],
                        ignore_skip=arch_idx == 0)

        mIoU02 = state["mIoU02"]
        latency02 = state["latency02"]
        obj02 = objective_acc_lat(mIoU02, latency02)
        mIoU12 = state["mIoU12"]
        latency12 = state["latency12"]
        obj12 = objective_acc_lat(mIoU12, latency12)
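        # Keep whichever branch pair (8s+32s vs. 16s+32s) scores better on
        # the accuracy-latency objective.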
        if obj02 > obj12:
            last = [2, 0]
        else:
            last = [2, 1]
        lasts.append(last)
        model.build_structure(last)
        # logging.info("net: " + str(model))
        for b in last:
            if len(config.width_mult_list) > 1:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        width=getattr(model, "widths%d" % b),
                        head_width=config.stem_head_width[idx][1],
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
            else:
                plot_op(getattr(model, "ops%d" % b),
                        getattr(model, "path%d" % b),
                        F_base=config.Fch).savefig(os.path.join(
                            config.save, "ops_%d_%d.png" % (arch_idx, b)),
                                                   bbox_inches="tight")
        plot_path_width(model.lasts, model.paths, model.widths).savefig(
            os.path.join(config.save, "path_width%d.png" % arch_idx))
        plot_path_width([2, 1, 0], [model.path2, model.path1, model.path0],
                        [model.widths2, model.widths1, model.widths0]).savefig(
                            os.path.join(config.save,
                                         "path_width_all%d.png" % arch_idx))
        flops, params = profile(model,
                                inputs=(torch.randn(1, 3, 1024, 2048), ),
                                verbose=False)
        logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9)
        logging.info("ops:" + str(model.ops))
        logging.info("path:" + str(model.paths))
        logging.info("last:" + str(model.lasts))
        model = model.cuda()
        init_weight(model,
                    nn.init.kaiming_normal_,
                    torch.nn.BatchNorm2d,
                    config.bn_eps,
                    config.bn_momentum,
                    mode='fan_in',
                    nonlinearity='relu')

        partial = torch.load(
            os.path.join(config.eval_path, "weights%d.pt" % arch_idx))
        state = model.state_dict()
        pretrained_dict = {k: v for k, v in partial.items() if k in state}
        state.update(pretrained_dict)
        model.load_state_dict(state)

        evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None),
                                 config.num_classes,
                                 config.image_mean,
                                 config.image_std,
                                 model,
                                 config.eval_scale_array,
                                 config.eval_flip,
                                 0,
                                 out_idx=0,
                                 config=config,
                                 verbose=False,
                                 save_path=os.path.join(
                                     config.save, 'predictions'),
                                 show_image=True,
                                 show_prediction=True)
        evaluators.append(evaluator)
        models.append(model)

    # Cityscapes ###########################################
    logging.info(config.load_path)
    logging.info(config.eval_path)
    logging.info(config.save)
    with torch.no_grad():
        # validation
        print("[validation...]")
        valid_mIoUs = infer(models, evaluators, logger=None)
        for idx, arch_idx in enumerate(config.arch_idx):
            if arch_idx == 0:
                logging.info("teacher's valid_mIoU %.3f" % (valid_mIoUs[idx]))
            else:
                logging.info("student's valid_mIoU %.3f" % (valid_mIoUs[idx]))
Example No. 16
def main(model_name, initial_validation):
    # TODO: parse args.
    # --- Tunables.
    # 32GB DRNDSOffsetDisparity, cropped -> 18
    # 12GB DRNDSOffsetDisparity, cropped -> 6
    # 12GB DRNOffsetDisparity, cropped -> 4
    # 12GB DRNOffsetDisparity, original -> 3
    # 12GB DRNDSOffsetDisparity, original -> not supported yet:
    # resize is based on resolution 1792x784
    batch_size = 6  # 6
    n_workers = 21
    n_semantic_pretrain = 0  # 500 # First train only on semantics.
    n_epochs = 500
    validation_step = 5
    train_split = 'subtrain'  # 'train'
    val_split = 'subtrainval'  # 'val'
    validate_on_train = False  # Note: this doesn't include semantic performance.
    train_set_length = 24  # 24 # None
    #cityscapes_directory = "/home/thehn/cityscapes/original"
    cityscapes_directory = "/home/thehn/cityscapes/cropped_cityscapes"
    #cityscapes_directory = "/data/Cityscapes"
    drn_name = 'drn_d_22'  # 'drn_d_22' 'drn_d_38'
    weights = None
    if 'SL' in model_name:
        weights = {
            'offset_mean_weight': 1e-5,  # 1e-3
            'offset_variance_weight': 1e-4,  # 1e-3
            'disparity_mean_weight': 1e-7,  # 1e-3
            'disparity_variance_weight': 1e-4,  # 1e-3
        }
    output_directory = "tmp/train/{}".format(model_name)
    #output_directory = "tmp/train/{}_{}"\
    #                    .format(model_name, time.strftime('%m%d-%H%M'))
    #output_directory = "tmp/train_test"
    #output_directory = "tmp/train_combined"
    #raise ValueError("Please set the input/output directories.")
    print("batch_size =", batch_size)
    print("train_split =", train_split)
    print("val_split =", val_split)
    print(locals())

    check_mkdir(output_directory)

    checkpoint = None

    #checkpoint = (
    #        "/home/thomashehn/Code/box2pix/tmp/train/models/Net_epoch6.pth",
    #        "/home/thomashehn/Code/box2pix/tmp/train/models/Adam_epoch6.pth",
    #        6)

    n_classes = 19
    mdl = ModelWrapper(model_name, n_classes, weights, drn_name)

    #for param in parameters:
    #    param.require_grad = False
    #parameters = []
    # weight_decay=1e-6 seems to work so far, but would need more finetuning
    optimizer = optim.Adam(mdl.parameters, weight_decay=1e-6)
    start_epoch = 1

    if checkpoint is not None:
        print("Loading from checkpoint {}".format(checkpoint))
        mdl.NN.load_state_dict(torch.load(checkpoint[0]))
        optimizer.load_state_dict(torch.load(checkpoint[1]))
        start_epoch = checkpoint[2] + 1

    mdl.NN, optimizer = amp.initialize(mdl.NN, optimizer, opt_level="O1")
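    # apex AMP opt levels: O0 = fp32 baseline, O1 = mixed precision with
    # automatic casting, O2 = "almost fp16", O3 = pure fp16. The measurements
    # below motivated the choice of O1.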
    # O0, DRNDSDoubleSegSL, bs 6, cropped, 2 epochs -> 11949MB memory, time real 19m34.788s
    # O1, DRNDSDoubleSegSL, bs 6, cropped, 2 epochs -> 7339MB memory, time real 10m32.431s

    # O0, DRNDSOffsetDisparity, bs 6, cropped, 2 epochs -> 11875MB memory, time real 18m13.491s
    # O1, DRNDSOffsetDisparity, bs 6, cropped, 2 epochs -> 7259MB memory, time real 8m51.849s
    # O0, DRNDSOffsetDisparity, bs 7, cropped, 2 epochs -> memory error
    # O1, DRNDSOffsetDisparity, bs 7, cropped, 2 epochs -> 8701MB memory, time real 9m13.947s
    # O2, DRNDSOffsetDisparity, bs 7, cropped, 2 epochs -> 8721MB memory, time real 9m8.563s
    # O3, DRNDSOffsetDisparity, bs 7, cropped, 2 epochs -> 8693MB memory, time real 9m7.476s

    print("--- Setup dataset and dataloaders.")
    mdl.train_set =\
            Cityscapes(mdl.types,
                       data_split=train_split,
                       length=train_set_length,
                       cityscapes_directory=cityscapes_directory,
                       **mdl.dataset_kwargs)
    element = mdl.train_set[0]
    mdl.train_loader = data.DataLoader(mdl.train_set,
                                       batch_size=batch_size,
                                       pin_memory=True,
                                       num_workers=n_workers,
                                       shuffle=True)

    if not validate_on_train:
        mdl.val_set = Cityscapes(mdl.types,
                                 data_split=val_split,
                                 cityscapes_directory=cityscapes_directory,
                                 **mdl.val_dataset_kwargs)
    else:
        mdl.val_set =\
                Cityscapes(mdl.types,
                           data_split=train_split,
                           length=train_set_length,
                           cityscapes_directory=cityscapes_directory,
                           **mdl.val_dataset_kwargs)
    mdl.val_loader = data.DataLoader(mdl.val_set,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=n_workers)

    # Sample 10 validation indices for visualization.
    #validation_idxs = np.random.choice(np.arange(len(val_set)),
    #                                   size=min(9, len(val_set)),
    #                                   replace=False)
    # Nah, let's pick them ourselves for now.
    #validation_idxs = [ 17, 241, 287, 304, 123,
    #                   458,   1,  14, 139, 388]
    validation_idxs = [17, 1, 14]
    #validation_idxs = [ 53,  11,  77]

    metrics = {
        'train': {
            'classification': [],
            'regression': [],
            'epochs': []
        },
        'validation': {
            'classification': [],
            'regression': [],
            'semantic': [],
            'epochs': []
        },
        'memory': {
            # max_memory_cached() is max_memory_reserved() in newer PyTorch.
            'max_cached': [torch.cuda.max_memory_cached()],
            'max_alloc': [torch.cuda.max_memory_allocated()]
        }
    }

    if initial_validation:
        print("--- Setup visual validation.")
        model_file = mdl.save_model(output_directory, suffix="e0000")
        mdl.validation_visual(validation_idxs, output_directory, epoch=0)
        semantic_score =\
                mdl.validation_snapshot(model_file,
                                        path.join(output_directory,
                                                  'last_prediction'),
                                        cityscapes_directory,
                                        batch_size, val_split)
        train_losses = mdl.compute_loss(mdl.train_loader)
        val_losses = mdl.compute_loss(mdl.val_loader, separate=True)
        print('Training loss: {:5} (c) + {:5} (r) = {:5}'.format(
            train_losses[0], train_losses[1], sum(train_losses)))
        val_dict = {}  # Stays empty when the model reports no separate SL losses.
        if len(val_losses) > 5:
            val_dict = {
                'offset_mean_loss': [val_losses[2]],
                'offset_variance_loss': [val_losses[3]],
                'disparity_mean_loss': [val_losses[4]],
                'disparity_variance_loss': [val_losses[5]]
            }
        print('Validation loss: {:5} (c) + {:5} (r) = {:5}'.format(
            val_losses[0], val_losses[1], sum(val_losses[:2])))
        metrics = {
            'train': {
                'classification': [train_losses[0]],
                'regression': [train_losses[1]],
                'epochs': [start_epoch - 1]
            },
            'validation': {
                'classification': [val_losses[0]],
                'regression': [val_losses[1]],
                'semantic': [semantic_score],
                'epochs': [start_epoch - 1],
                **val_dict
            },
            'memory': {
                'max_cached': [torch.cuda.max_memory_cached()],
                'max_alloc': [torch.cuda.max_memory_allocated()]
            }
        }

    print("--- Training.")
    # First train semantic loss for a while.
    #~regression_loss_stash = None
    #~if n_semantic_pretrain > 0 and regression_loss is not None:
    #~    regression_loss_stash = regression_loss
    #~    regression_loss = None
    #upscale = lambda x: nn.functional.interpolate(x,
    #                        scale_factor=2,
    #                        mode='bilinear',
    #                        align_corners=True)
    for epoch in range(start_epoch, n_epochs + 1):
        #~if epoch >= n_semantic_pretrain and regression_loss_stash is not None:
        #~    regression_loss = regression_loss_stash
        #~    regression_loss_stash = None

        #~if epoch == 10 and False:
        #~    parameters = model.parameters()
        #~    model.train_all = True
        #~    optimizer = optim.Adam(parameters)

        mdl.NN.train()
        total_rloss = 0
        total_closs = 0

        t_sum_batch = 0
        t_sum_opt = 0
        for batch_idx, batch_data in enumerate(mdl.train_loader):
            optimizer.zero_grad()

            batch_losses = mdl.batch_loss(batch_data)

            loss = sum(batch_losses)
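            # Scale the loss before backward() so fp16 gradients do not
            # underflow; apex unscales them again before optimizer.step().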
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            #loss.backward()
            optimizer.step()

            if batch_idx % 30 == 0 and batch_idx != 0:
                print('\t[batch {}/{}], [batch mean - closs {:5}, rloss {:5}]'.
                      format(batch_idx, len(mdl.train_loader),
                             float(batch_losses[0]) / batch_data[0].size(0),
                             float(batch_losses[1]) / batch_data[0].size(0)))

            total_closs += float(batch_losses[0])
            total_rloss += float(batch_losses[1])
        del loss, batch_data, batch_losses

        total_closs /= len(mdl.train_set)
        total_rloss /= len(mdl.train_set)

        print('[epoch {}], [mean train - closs {:5}, rloss {:5}]'.format(
            epoch, total_closs, total_rloss))
        metrics['train']['classification'].append(total_closs)
        metrics['train']['regression'].append(total_rloss)
        metrics['train']['epochs'].append(epoch)

        metrics['memory']['max_cached'].append(torch.cuda.max_memory_cached())
        metrics['memory']['max_alloc'].append(
            torch.cuda.max_memory_allocated())

        # --- Visual validation.
        if (epoch % validation_step) == 0:
            print("--- Validation.")
            mdl.validation_visual(validation_idxs, output_directory, epoch)

            model_file = mdl.save_model(output_directory,
                                        suffix="{:04}".format(epoch))
            metrics['validation']['semantic'].append(
                mdl.validation_snapshot(
                    model_file, path.join(output_directory, 'last_prediction'),
                    cityscapes_directory, batch_size, val_split))

            val_losses = mdl.compute_loss(mdl.val_loader, separate=True)
            if len(val_losses) > 5:
                if 'offset_mean_loss' not in metrics['validation'].keys():
                    val_dict = {
                        'offset_mean_loss': [val_losses[2]],
                        'offset_variance_loss': [val_losses[3]],
                        'disparity_mean_loss': [val_losses[4]],
                        'disparity_variance_loss': [val_losses[5]]
                    }
                    metrics['validation'] = {
                        **metrics['validation'],
                        **val_dict
                    }
                else:
                    metrics['validation']['offset_mean_loss']\
                            .append(val_losses[2])
                    metrics['validation']['offset_variance_loss']\
                            .append(val_losses[3])
                    metrics['validation']['disparity_mean_loss']\
                            .append(val_losses[4])
                    metrics['validation']['disparity_variance_loss']\
                            .append(val_losses[5])
                print('Separate validation losses: {:5}, {:5}, {:5}, {:5}'.
                      format(*val_losses[2:]))

            metrics['validation']['classification'].append(val_losses[0])
            metrics['validation']['regression'].append(val_losses[1])
            metrics['validation']['epochs'].append(epoch)
            print('Validation loss: {:5} (c) + {:5} (r) = {:5}'.format(
                val_losses[0], val_losses[1], sum(val_losses[:2])))

        # --- Write losses to disk.
        with open(path.join(output_directory, "metrics.json"), 'w') as outfile:
            json.dump(metrics, outfile)
        for key in metrics.keys():
            data_set = key
            set_metrics = metrics[data_set]
            plot_losses(set_metrics, "{}/{}".format(output_directory,
                                                    data_set))
Example No. 17

        sin_outputs = F.softmax(model(img.to(device)),
                                dim=1).detach().cpu().numpy()

        # Fuse the single-class and multi-class predictions.
        alpha = 0.5
        preds = np.concatenate(preds, 1)
        if dataset.lower() != 'cityscapes':
            # Extract the background probability predicted by the single-class model.
            background_p = np.expand_dims(sin_outputs[:, 0, :, :], axis=1)
            # Fuse the single-class and multi-class model scores.
            preds = np.concatenate((background_p, preds), 1)
        final_preds = alpha * preds + sin_outputs
        preds = np.argmax(final_preds, axis=1)
        if dataset == 'voc':
            pred = voc_cmap()[preds.squeeze(axis=0)].astype(np.uint8)
        else:
            pred = Cityscapes.decode_target(preds.squeeze(axis=0)).astype(
                np.uint8)
        Image.fromarray(pred).save(result_path)
        print("Prediction is saved in %s" % result_path)

    else:
        # Load the model.
        model = model_map[model_name](num_classes=sin_model_class,
                                      output_stride=16)
        weights = torch.load(ckpt_path)["model_state"]
        model.load_state_dict(weights)
        model.to(device)
        model.eval()

        with torch.no_grad():
            print(img.shape)
            img = img.to(device)