Example No. 1
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    train_logger = make_logger("Train.log", args)
    test_logger = make_logger("Test.log", args)

    model = load_models(
        mode="cls",
        device=device,
        args=args,
    )

    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
    )
    optimizer.zero_grad()

    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    if (args.train or args.run_semi) and args.test:
        print("===================================")
        print("====== Loading Training Data ======")
        print("===================================")

        sample_gt_list = np.load(args.gt_sample_list)

        trainset_gt = ModelNetDatasetGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
        )

        trainloader_gt = torch.utils.data.DataLoader(
            trainset_gt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

        print("===================================")
        print("====== Loading Test Data ======")
        print("===================================")
        testset = ModelNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

        args.iter_per_epoch = len(trainset_gt) // args.batch_size
        args.total_data = len(trainset_gt)
        args.total_iterations = (args.num_epochs * args.total_data //
                                 args.batch_size)
        args.iter_save_epoch = args.save_per_epoch * (args.total_data //
                                                      args.batch_size)
        args.iter_test_epoch = args.test_epoch * (args.total_data //
                                                  args.batch_size)

    if args.train and args.test:
        model.train()
        model.to(args.device)

        cls_loss = torch.nn.CrossEntropyLoss().to(device)

        trainloader_gt_iter = enumerate(trainloader_gt)

        run_training_pointnet_cls(
            trainloader_gt=trainloader_gt,
            trainloader_gt_iter=trainloader_gt_iter,
            testloader=testloader,
            model=model,
            cls_loss=cls_loss,
            optimizer=optimizer,
            writer=writer,
            train_logger=train_logger,
            test_logger=test_logger,
            args=args,
        )

    if args.test:
        print("===================================")
        print("====== Loading Testing Data =======")
        print("===================================")
        testset = ModelNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )
        criterion = torch.nn.CrossEntropyLoss().to(device)

        run_testing(
            dataloader=testloader,
            model=model,
            criterion=criterion,
            logger=test_logger,
            test_iter=100000000,
            writer=None,
            args=args,
        )
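
The main() entries in these examples all consume a flat args namespace built elsewhere in the repository. As a point of reference, here is a minimal argparse sketch that would drive Example No. 1; it is not taken from the original code, the flag names simply mirror the attributes accessed above, and the defaults are illustrative only.

# Hypothetical argparse driver for Example No. 1 (assumption, not original code).
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="PointNet classification")
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--run_semi", action="store_true")
    parser.add_argument("--test", action="store_true")
    parser.add_argument("--tensorboard", action="store_true")
    parser.add_argument("--train_file", type=str, default="")
    parser.add_argument("--test_file", type=str, default="")
    parser.add_argument("--gt_sample_list", type=str, default=None)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=200)
    parser.add_argument("--save_per_epoch", type=int, default=10)
    parser.add_argument("--test_epoch", type=int, default=1)
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())

Example No. 2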
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(os.path.join(RES_DIR, args.name),
                                "evaluation_trinity_test")
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)
    print("EXP PATH: {}".format(args.exp_dir))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    assert args.checkpoint_seg is not None, "Need trained .pth!"
    model_seg = load_models(
        mode="segmentation",
        device=device,
        args=args,
    )
    assert args.checkpoint_disc is not None, "Need trained .pth!"
    model_disc = load_models(
        mode="single_discriminator",
        device=device,
        args=args,
    )

    transforms_shape_target = dual_transforms.Compose([
        dual_transforms.CenterCrop((400, 400)),
        dual_transforms.Scale(args.image_size[0]),
    ])
    photo_transformer = PhotometricTransform(photometric_transform_config)

    traindata = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "train"),
        image_size=args.image_size,
        data_to_train="",
        shape_transforms=transforms_shape_target,
        photo_transforms=photo_transformer,
        train_bool=False,
    )
    train_loader = torch.utils.data.DataLoader(
        traindata,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    trainloader_iter = enumerate(train_loader)
    args.total_iterations = len(traindata) // args.batch_size

    model_seg.eval()
    model_disc.eval()

    confidence_map = []
    confidence_mean_score = []
    confidence_cnt = []
    imagename = []
    imageindex = []

    dst_folder = os.path.join(args.exp_dir, "visualization")
    if not os.path.isdir(dst_folder):
        os.mkdir(dst_folder)

    image_confidence = np.empty((len(traindata), 224, 224, 2))
    # pred_all = np.empty((traindata.__len__(),4,224,224))
    # with open("dataloaders/eye/trinity_top_200.pkl", "rb") as f:
    #     top_200_lists = pickle.load(f)

    with open("dataloaders/eye/top_1_adv.pkl", "rb") as f:
        top_1_lists = pickle.load(f)

    with open("dataloaders/eye/top_2_adv.pkl", "rb") as f:
        top_2_lists = pickle.load(f)

    for i_iter in range(args.total_iterations):
        if i_iter % 1000 == 0:
            print("Processing {} ..........".format(i_iter))

        # if not (traindata.train_data_list[i_iter] in top_200_lists):
        #     continue

        # if traindata.train_data_list[i_iter] in top_1_lists:
        #     continue

        if traindata.train_data_list[i_iter] in top_2_lists:
            continue

        imageindex.append(i_iter)
        imagename.append(traindata.train_data_list[i_iter])
        _, batch = next(trainloader_iter)

        images, labels = batch
        images = Variable(images).to(args.device)
        labels = Variable(labels.long()).to(args.device)

        pred = model_seg(images)
        pred_softmax = F.softmax(pred, dim=1)
        D_out = model_disc(pred_softmax)
        D_out = torch.sigmoid(D_out)
        D_out = D_out[0, 0, :, :].detach().cpu().numpy()

        pred = np.argmax(pred.detach().cpu().numpy(), axis=1)[0, :, :]
        # fig = plt.figure()
        # ax = fig.add_subplot(231)
        # ax.imshow(images[0,0,:,:].detach().cpu().numpy(), cmap="gray")
        # ax.set_xticks([])
        # ax.set_yticks([])
        #
        # ax = fig.add_subplot(232)
        # ax.imshow(labels[0, :, :].detach().cpu().numpy())
        # ax.set_xticks([])
        # ax.set_yticks([])
        #
        # ax = fig.add_subplot(233)
        # ax.imshow(pred, cmap="gray")
        # ax.set_xticks([])
        # ax.set_yticks([])
        #
        # ax = fig.add_subplot(234)
        # ax.imshow(D_out, cmap="gray")
        # ax.set_xticks([])
        # ax.set_yticks([])

        D_out_mean = D_out.mean()
        D_out_mean_map = (D_out > D_out_mean) * 1

        # labels = labels[0,:,:].detach().cpu().numpy()
        # semi_ignore_mask = (D_out < D_out_mean)
        # # pseudo_gt = labels.copy()
        # # pseudo_gt[semi_ignore_mask] = 4
        # # pseudo_gt = pseudo_gt.astype(np.uint8)
        # filename = traindata.train_data_list[i_iter].replace("/images/", "/masks/")
        # filename = filename.replace("/train/", "/train_pseudo/")
        # filename = filename.replace(".png", ".npy")
        # np.save(filename, semi_ignore_mask)
        # # print(D_out_mean_map.shape)

        # ax = fig.add_subplot(235)
        # ax.imshow(D_out_mean_map)
        # ax.set_xticks([])
        # ax.set_yticks([])
        # plt.tight_layout()
        # filename = ntpath.basename(traindata.train_data_list[i_iter])
        # filename = os.path.join(dst_folder, filename)
        # plt.savefig(filename)

        # im_filename = traindata.train_data_list[i_iter].replace("/train/", "/train_pseudo/")
        # os.system("cp %s %s" % (traindata.train_data_list[i_iter], im_filename))

        confidence_mean_score.append(D_out_mean)
        confidence_cnt.append(D_out_mean_map.sum())

        ### generate confidence map ###
        # confidence_map.append(D_out_mean)
        # imagename.append(ntpath.basename(traindata.train_data_list[i_iter]))
        # image_confidence[i_iter,:,:,0] = images[0,0,:,:].detach().cpu().numpy()
        # image_confidence[i_iter,:,:,1] = D_out
        # pred_all[i_iter, ...] = pred_softmax[0,...].detach().cpu().numpy()
        ### generate confidence map ###

    # with open("%s/confidence_map_top1_adv.pkl" % (args.exp_dir), "wb") as f:
    #     pickle.dump([imageindex, imagename, confidence_mean_score, confidence_cnt], f)

    with open("%s/confidence_map_top2_adv.pkl" % (args.exp_dir), "wb") as f:
        pickle.dump(
            [imageindex, imagename, confidence_mean_score, confidence_cnt], f)
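
The loop above scores every training image with the mean sigmoid output of the discriminator and pickles [imageindex, imagename, confidence_mean_score, confidence_cnt]. A minimal sketch (an assumption, not part of the original code) of reading that file back and listing the least-confident images:

# Reload the confidence pickle written above; assumes the same four-list layout.
import pickle

import numpy as np

with open("confidence_map_top2_adv.pkl", "rb") as f:
    imageindex, imagename, mean_score, cnt = pickle.load(f)

# Ascending order of mean discriminator confidence: least confident first.
order = np.argsort(np.array(mean_score))
for rank in order[:10]:
    print(imageindex[rank], imagename[rank], mean_score[rank])

Example No. 3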
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)
    print("EXP PATH: {}".format(args.exp_dir))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logger = make_logger(filename="TrainVal.log", args=args)
    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    print("===================================")
    print("====== Loading Training Data ======")
    print("===================================")
    shape_transformer = dual_transforms.Compose([
        dual_transforms.CenterCrop((400, 400)),
        dual_transforms.Scale(args.image_size[0]),
    ])

    photo_transformer = PhotometricTransform(photometric_transform_config)

    source_data = UnityDataset(
        root=args.source_root,
        image_size=args.image_size,
        shape_transforms=shape_transformer,
        photo_transforms=photo_transformer,
        train_bool=False,
    )

    args.tot_source = len(source_data)
    args.total_iterations = args.num_epochs * args.tot_source // args.batch_size
    args.iters_to_eval = args.epoch_to_eval * args.tot_source // args.batch_size

    train_loader = torch.utils.data.DataLoader(
        source_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    model_seg = load_models(
        mode="segmentation",
        device=device,
        args=args,
    )
    optimizer_seg = optim.Adam(
        model_seg.parameters(),
        lr=args.lr_seg,
        betas=(args.beta1, 0.999),
    )
    optimizer_seg.zero_grad()

    seg_loss_target = torch.nn.CrossEntropyLoss().to(device)

    trainloader_iter = enumerate(train_loader)

    loss_seg_min = float("inf")
    miou_max = float("-inf")

    for i_iter in range(args.total_iterations):
        loss_seg_value = 0
        model_seg.train()
        optimizer_seg.zero_grad()

        adjust_learning_rate(
            optimizer=optimizer_seg,
            learning_rate=args.lr_seg,
            i_iter=i_iter,
            max_steps=args.total_iterations,
            power=0.9,
        )

        try:
            _, batch = next(trainloader_iter)
        except StopIteration:
            trainloader_iter = enumerate(train_loader)
            _, batch = next(trainloader_iter)

        images, labels = batch
        images = Variable(images).to(args.device)
        labels = Variable(labels.long()).to(args.device)

        pred = model_seg(images)
        loss_seg = seg_loss_target(pred, labels)

        current_loss_seg = loss_seg.item()
        loss_seg_value += current_loss_seg

        loss_seg.backward()
        optimizer_seg.step()

        pred_img = pred.argmax(dim=1, keepdim=True)
        flat_pred = pred_img.detach().cpu().numpy().flatten()
        flat_gt = labels.detach().cpu().numpy().flatten()
        miou, _ = compute_mean_iou(flat_pred=flat_pred, flat_label=flat_gt)

        logger.info('iter = {0:8d}/{1:8d} '
                    'loss_seg = {2:.3f} '
                    'mIoU = {3:.3f} '.format(
                        i_iter,
                        args.total_iterations,
                        loss_seg_value,
                        miou,
                    ))

        if args.tensorboard and (writer is not None):
            writer.add_scalar('Train/Cross_Entropy', current_loss_seg, i_iter)
            writer.add_scalar('Train/mIoU', miou, i_iter)

        if i_iter % args.iters_to_eval == 0:
            filename = os.path.join(
                args.exp_dir, "Target_img_trainiter_{}.png".format(i_iter))
            target_img = images.float()
            gen_target_img = torchvision.utils.make_grid(target_img,
                                                         padding=2,
                                                         normalize=True)
            torchvision.utils.save_image(gen_target_img, filename)

            filename = os.path.join(
                args.exp_dir, "Unity_pred_trainiter_{}.png".format(i_iter))
            pred_img = pred_img.float()
            gen_img = torchvision.utils.make_grid(pred_img,
                                                  padding=2,
                                                  normalize=True)
            torchvision.utils.save_image(gen_img, filename)

        is_better_ss = current_loss_seg < loss_seg_min
        if is_better_ss:
            loss_seg_min = current_loss_seg
            torch.save(model_seg.state_dict(),
                       os.path.join(args.exp_dir, "model_train_best.pth"))
        if miou > miou_max:
            miou_max = miou
            torch.save(model_seg.state_dict(),
                       os.path.join(args.exp_dir, "model_train_best_miou.pth"))

    logger.info("==========================================")
    logger.info("Training DONE!")

    if args.tensorboard and (writer is not None):
        writer.close()
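
adjust_learning_rate is called once per iteration with power=0.9. A minimal sketch of the polynomial decay it is assumed to implement (the lr * (1 - iter / max_steps) ** power schedule commonly paired with segmentation training; the real helper may differ):

# Poly learning-rate decay (assumed implementation of adjust_learning_rate).
def adjust_learning_rate(optimizer, learning_rate, i_iter, max_steps, power):
    lr = learning_rate * ((1.0 - float(i_iter) / max_steps) ** power)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr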
Example No. 4
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    train_logger = make_logger("Train.log", args)
    test_logger = make_logger("Test.log", args)

    model = load_models(
        mode="seg",
        device=device,
        args=args,
    )

    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
    )
    optimizer.zero_grad()

    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    if args.train and args.test:
        print("===================================")
        print("====== Loading Training Data ======")
        print("===================================")

        if args.gt_sample_list is not None:
            sample_gt_list = np.load(args.gt_sample_list)
        else:
            sample_gt_list = None

        trainset_gt = ShapeNetDatasetGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )

        trainloader_gt = torch.utils.data.DataLoader(
            trainset_gt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

        print("===================================")
        print("====== Loading Test Data ======")
        print("===================================")
        testset = ShapeNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
            num_classes=16,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

        args.iter_per_epoch = len(trainset_gt) // args.batch_size
        args.total_data = len(trainset_gt)
        args.total_iterations = (args.num_epochs * args.total_data //
                                 args.batch_size)
        args.iter_save_epoch = args.save_per_epoch * (args.total_data //
                                                      args.batch_size)
        args.iter_test_epoch = args.test_epoch * (args.total_data //
                                                  args.batch_size)

    if args.train and args.test:
        model.train()
        model.to(args.device)

        seg_loss = torch.nn.CrossEntropyLoss().to(device)

        trainloader_gt_iter = enumerate(trainloader_gt)

        run_training_pointnet_seg(
            trainloader_gt=trainloader_gt,
            trainloader_gt_iter=trainloader_gt_iter,
            testloader=testloader,
            testdataset=testset,
            model=model,
            seg_loss=seg_loss,
            optimizer=optimizer,
            writer=writer,
            train_logger=train_logger,
            test_logger=test_logger,
            args=args,
        )

    if args.test:
        print("===================================")
        print("====== Loading Testing Data =======")
        print("===================================")
        testset = ShapeNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
            num_classes=16,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )
        criterion = torch.nn.CrossEntropyLoss().to(device)

        run_testing_seg(
            dataloader=testloader,
            dataset=testset,
            model=model,
            criterion=criterion,
            logger=test_logger,
            test_iter=100000000,
            writer=None,
            args=args,
        )

    if args.tsne:
        from utils.metric import batch_get_iou, object_names
        from torch.autograd import Variable

        args.batch_size = 1

        labels = []
        objects = []

        if args.gt_sample_list is not None:
            sample_gt_list = np.load(args.gt_sample_list)
        else:
            sample_gt_list = None

        trainset_gt = ShapeNetDatasetGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )

        trainset_nogt = ShapeNetDataset_noGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )

        trainloader_gt = torch.utils.data.DataLoader(
            trainset_gt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )
        trainloader_nogt = torch.utils.data.DataLoader(
            trainset_nogt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

        # print("===================================")
        # print("====== Loading Test Data ======")
        # print("===================================")
        # testset = ShapeNetDatasetGT(
        #     root_list=args.test_file,
        #     sample_list=None,
        #     num_classes=16,
        # )
        # testloader = torch.utils.data.DataLoader(
        #     testset,
        #     batch_size=args.batch_size,
        #     shuffle=False,
        #     num_workers=args.workers,
        #     pin_memory=True,
        # )
        #
        # print("===================================")
        # print("====== Loading Testing Data =======")
        # print("===================================")
        # testset = ShapeNetDatasetGT(
        #     root_list=args.test_file,
        #     sample_list=None,
        #     num_classes=16,
        # )
        # testloader = torch.utils.data.DataLoader(
        #     testset,
        #     batch_size=args.batch_size,
        #     shuffle=False,
        #     num_workers=args.workers,
        #     pin_memory=True,
        # )

        model.eval()

        shape_ious = np.empty(len(object_names), dtype=object)  # np.object was removed from NumPy
        for i in range(shape_ious.shape[0]):
            shape_ious[i] = []

        all_shapes_train_gt = np.empty((len(trainset_gt), 2048))
        for batch_idx, data in enumerate(trainloader_gt):
            if batch_idx % 1000 == 0:
                print("Processing {} ...".format(batch_idx))

            pts, cls, seg = data
            pts, cls, seg = Variable(pts).float(), \
                            Variable(cls), Variable(seg).type(torch.LongTensor)
            pts, cls, seg = pts.to(args.device), cls.to(
                args.device), seg.long().to(args.device)

            labels.append(1)
            objects.append(int(cls.argmax(axis=2).squeeze().cpu().numpy()))

            with torch.set_grad_enabled(False):
                pred, global_shape = model(pts, cls)

            all_shapes_train_gt[
                batch_idx, :] = global_shape.squeeze().detach().cpu().numpy()

        all_shapes_train_nogt = np.empty((len(trainset_nogt), 2048))
        for batch_idx, data in enumerate(trainloader_nogt):
            if batch_idx % 1000 == 0:
                print("Processing {} ...".format(batch_idx))

            pts, cls = data
            pts, cls = Variable(pts).float(), Variable(cls)
            pts, cls = pts.to(args.device), cls.to(args.device)

            with torch.set_grad_enabled(False):
                pred, global_shape = model(pts, cls)

            all_shapes_train_nogt[
                batch_idx, :] = global_shape.squeeze().detach().cpu().numpy()

            labels.append(0)
            objects.append(int(cls.argmax(axis=2).squeeze().cpu().numpy()))

        all_shapes = np.concatenate(
            (all_shapes_train_gt, all_shapes_train_nogt), axis=0)

        shape_info = {
            "shapes": all_shapes,
            "labels": labels,
            "objects": objects
        }

        import pickle
        try:
            o = open("global_shape_{}.pkl".format(len(trainset_gt)), "wb")
            pickle.dump(shape_info, o, protocol=2)
            o.close()
        except FileNotFoundError as e:
            print(e)
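
The --tsne branch above stores the per-shape global features in global_shape_<N>.pkl, where N is len(trainset_gt). A minimal sketch (assumption) of embedding those saved features with scikit-learn's t-SNE afterwards; the filename is illustrative:

# Load the pickled global-shape features and embed them in 2-D.
import pickle

import numpy as np
from sklearn.manifold import TSNE

with open("global_shape_100.pkl", "rb") as f:  # name depends on len(trainset_gt)
    shape_info = pickle.load(f)

embedding = TSNE(n_components=2, init="pca",
                 random_state=0).fit_transform(shape_info["shapes"])
print(embedding.shape, np.unique(shape_info["labels"]))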
Example No. 5
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(os.path.join(RES_DIR, args.name),
                                "evaluation_trinity_test")
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)
    print("EXP PATH: {}".format(args.exp_dir))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    uncertainty_scores = np.load("uncertainty_scores.npy")
    with open(
            "results/unity_to_trinity_UDA_semi_10/evaluation_trinity_test/confidence_map_UDA.pkl",
            "rb") as f:
        adv_scores = pickle.load(f)
    index = np.argsort(adv_scores[1])[::-1]
    ascore = np.empty((2, 8916))
    ascore[0] = index
    ascore[1] = adv_scores[1][index]

    assert args.checkpoint_seg is not None, "Need trained .pth!"
    model_seg = load_models(
        mode="segmentation",
        device=device,
        args=args,
    )

    # assert args.checkpoint_disc is not None, "Need trained .pth!"
    # model_disc = load_models(
    #     mode="single_discriminator",
    #     device=device,
    #     args=args,
    # )

    class FeatureExtractor(torch.nn.Module):
        def __init__(self, submodule, extracted_layers):
            super(FeatureExtractor, self).__init__()
            self.submodule = submodule
            self.extracted_layers = extracted_layers

        def forward(self, x):
            for name, module in self.submodule._modules.items():
                x = module(x)
                print(name)
                if name in self.extracted_layers:
                    return x['x5']

    exact_list = ["pretrained_net"]
    featExactor = FeatureExtractor(model_seg, exact_list)
    # a = torch.randn(1, 3, 224, 224)
    # a = Variable(a).to(args.device)
    # x = myexactor(a)
    # print(x)

    transforms_shape_target = dual_transforms.Compose([
        dual_transforms.CenterCrop((400, 400)),
        dual_transforms.Scale(args.image_size[0]),
    ])
    photo_transformer = PhotometricTransform(photometric_transform_config)

    traindata = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "train"),
        image_size=args.image_size,
        data_to_train="",
        shape_transforms=transforms_shape_target,
        photo_transforms=photo_transformer,
        train_bool=False,
    )
    train_loader = torch.utils.data.DataLoader(
        traindata,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )
    trainloader_iter = enumerate(train_loader)
    args.total_iterations = len(traindata) // args.batch_size

    model_seg.eval()
    feature_maps = np.empty((len(traindata), 25088))

    count = 0
    for i_iter in range(args.total_iterations):
        if i_iter % 1000 == 0:
            print("Processing {} ..........".format(i_iter))

        # if uncertainty_scores[1][i_iter] == uncertainty_scores[1].max():
        #     print("uncertain", i_iter, traindata.train_data_list[i_iter], uncertainty_scores[1][i_iter], adv_scores[1][i_iter])
        #     count += 1
        # if adv_scores[1][i_iter] == adv_scores[1].max():
        #     print("adv", i_iter, traindata.train_data_list[i_iter], uncertainty_scores[1][i_iter], adv_scores[1][i_iter])
        #     count += 1
        #
        # if count==2:
        #     break

        # u_idx = np.where(uncertainty_scores[0]==i_iter)[0]
        a_idx = np.where(ascore[0] == i_iter)[0]
        _, batch = next(trainloader_iter)
        images, labels = batch
        images = Variable(images).to(args.device)
        # labels = Variable(labels.long()).to(args.device)

        feat = featExactor(images)
        # feature_maps[i_iter] = feat.view(1,-1)[0].detach().cpu().numpy()
        feature_maps[a_idx] = feat.view(1, -1)[0].detach().cpu().numpy()

    # print("Saving feature maps .......................")
    # np.save("feature_maps_UDA_original_order.npy", feature_maps)

    # print("Saving feature maps .......................")
    # np.save("feature_maps_UDA_ascore.npy", feature_maps)

    from sklearn.metrics.pairwise import cosine_similarity
    dist = cosine_similarity(feature_maps)

    #
    # A = np.matmul(feature_maps.transpose(), feature_maps)
    # D = A.diagonal()
    # distance_map = np.power(D, 0.5) * A * np.power(D, -0.5)
    #
    print("Saving distance maps .......................")
    np.save("distance_maps_UDA_ascore.npy", dist)
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)
    print("EXP PATH: {}".format(args.exp_dir))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logger = make_logger(filename="TrainVal.log", args=args)
    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    print("===================================")
    print("====== Loading Training Data ======")
    print("===================================")
    shape_transformer = dual_transforms.Compose([
        dual_transforms.CenterCrop((400, 400)),
        dual_transforms.Scale(args.image_size[0]),
    ])

    photo_transformer = PhotometricTransform(photometric_transform_config)

    joint_data = JointDataset(
        root_source=args.source_root,
        image_size=args.image_size,
        data_to_train="dataloaders/eye/trinity_train_200.pkl",
        shape_transforms=shape_transformer,
        photo_transforms=photo_transformer,
        train_bool=False,
    )

    args.tot_data = len(joint_data)
    args.total_iterations = args.num_epochs * args.tot_data // args.batch_size
    args.iters_to_eval = args.epoch_to_eval * args.tot_data // args.batch_size

    print("===================================")
    print("========= Loading Val Data ========")
    print("===================================")
    val_target_data = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "validation"),
        image_size=args.image_size,
        data_to_train="",
        shape_transforms=shape_transformer,
        photo_transforms=None,
        train_bool=False,
    )

    train_loader = torch.utils.data.DataLoader(
        joint_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    val_loader = torch.utils.data.DataLoader(
        val_target_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    model_seg = load_models(
        mode="segmentation",
        device=device,
        args=args,
    )
    optimizer_seg = optim.Adam(
        model_seg.parameters(),
        lr=args.lr_seg,
        betas=(args.beta1, 0.999),
    )
    optimizer_seg.zero_grad()

    # class_weight_target = 1.0 / train_target_data.get_class_probability().to(device)
    seg_loss_target = torch.nn.CrossEntropyLoss().to(device)

    trainloader_iter = enumerate(train_loader)

    val_loss, val_miou = [], []
    val_loss_f = float("inf")
    val_miou_f = float("-inf")
    loss_seg_min = float("inf")

    for i_iter in range(args.total_iterations):
        loss_seg_value = 0
        model_seg.train()
        optimizer_seg.zero_grad()

        adjust_learning_rate(
            optimizer=optimizer_seg,
            learning_rate=args.lr_seg,
            i_iter=i_iter,
            max_steps=args.total_iterations,
            power=0.9,
        )

        try:
            _, batch = next(trainloader_iter)
        except StopIteration:
            trainloader_iter = enumerate(train_loader)
            _, batch = next(trainloader_iter)

        images, labels = batch
        images = Variable(images).to(args.device)
        labels = Variable(labels.long()).to(args.device)

        pred = model_seg(images)
        loss_seg = seg_loss_target(pred, labels)

        current_loss_seg = loss_seg.item()
        loss_seg_value += current_loss_seg

        loss_seg.backward()
        optimizer_seg.step()

        logger.info('iter = {0:8d}/{1:8d} '
                    'loss_seg = {2:.3f} '.format(
                        i_iter,
                        args.total_iterations,
                        loss_seg_value,
                    ))

        current_epoch = i_iter * args.batch_size // args.tot_data
        if i_iter % args.iters_to_eval == 0:
            val_loss_f, val_miou_f = validate_baseline(
                i_iter=i_iter,
                val_loader=val_loader,
                model=model_seg,
                epoch=current_epoch,
                logger=logger,
                writer=writer,
                val_loss=val_loss_f,
                val_iou=val_miou_f,
                args=args,
            )

            val_loss.append(val_loss_f)
            val_loss_f = np.min(np.array(val_loss))
            val_miou.append(val_miou_f)
            val_miou_f = np.max(np.array(val_miou))

            if args.tensorboard and (writer is not None):
                writer.add_scalar('Val/Cross_Entropy_Target', val_loss_f,
                                  i_iter)
                writer.add_scalar('Val/mIoU_Target', val_miou_f, i_iter)

        is_better_ss = current_loss_seg < loss_seg_min
        if is_better_ss:
            loss_seg_min = current_loss_seg
            torch.save(model_seg.state_dict(),
                       os.path.join(args.exp_dir, "model_train_best.pth"))

    logger.info("==========================================")
    logger.info("Training DONE!")

    if args.tensorboard and (writer is not None):
        writer.close()

    with open("%s/train_performance.pkl" % args.exp_dir, "wb") as f:
        pickle.dump([val_loss, val_miou], f)
    logger.info("==========================================")
    logger.info("Evaluating on test data ...")

    testdata = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "test"),
        image_size=args.image_size,
        data_to_train="",
        shape_transforms=shape_transformer,
        photo_transforms=None,
        train_bool=False,
    )
    test_loader = torch.utils.data.DataLoader(
        testdata,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    pm = run_testing(
        dataset=testdata,
        test_loader=test_loader,
        model=model_seg,
        args=args,
    )

    logger.info('Global Mean Accuracy: {:.3f}'.format(np.array(pm.GA).mean()))
    logger.info('Mean IOU: {:.3f}'.format(np.array(pm.IOU).mean()))
    logger.info('Mean Recall: {:.3f}'.format(np.array(pm.Recall).mean()))
    logger.info('Mean Precision: {:.3f}'.format(np.array(pm.Precision).mean()))
    logger.info('Mean F1: {:.3f}'.format(np.array(pm.F1).mean()))

    IOU_ALL = np.array(pm.Iou_all)
    logger.info(
        "Back: {:.4f}, Sclera: {:.4f}, Iris: {:.4f}, Pupil: {:.4f}".format(
            IOU_ALL[:, 0].mean(),
            IOU_ALL[:, 1].mean(),
            IOU_ALL[:, 2].mean(),
            IOU_ALL[:, 3].mean(),
        ))
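
Every example builds its loggers with make_logger. A minimal sketch of what such a helper typically looks like, assuming it writes a log file under args.exp_dir and echoes to the console (the original implementation may differ):

# Assumed shape of the make_logger helper used throughout these examples.
import logging
import os


def make_logger(filename, args):
    logger = logging.getLogger(filename)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    file_handler = logging.FileHandler(os.path.join(args.exp_dir, filename))
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    return logger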
Example No. 7
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    train_logger = make_logger("Train.log", args)
    test_logger = make_logger("Test.log", args)

    model = load_models(
        mode="cls",
        device=device,
        args=args,
    )
    model_D = load_models(
        mode="disc",
        device=device,
        args=args,
    )

    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
    )
    optimizer.zero_grad()

    optimizer_D = optim.Adam(
        model_D.parameters(),
        lr=args.lr_D,
        betas=(0.9, 0.999),
    )
    optimizer_D.zero_grad()

    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    if (args.train or args.run_semi) and args.test:
        print("===================================")
        print("====== Loading Training Data ======")
        print("===================================")
        # idx = np.arange(9840)
        # np.random.shuffle(idx)
        # sample_gt_list = idx[0:args.num_samples]
        # sample_nogt_list = idx[args.num_samples:]
        # filename = "gt_sample_{}.npy".format(args.name)
        # np.save(filename, sample_gt_list)

        sample_gt_list = np.load(args.gt_sample_list)

        trainset_gt = ModelNetDatasetGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
        )
        trainset_nogt = ModelNetDataset_noGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
        )

        trainloader_gt = torch.utils.data.DataLoader(
            trainset_gt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )
        trainloader_nogt = torch.utils.data.DataLoader(
            trainset_nogt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

        print("===================================")
        print("====== Loading Test Data ======")
        print("===================================")
        testset = ModelNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

        args.iter_per_epoch = len(trainset_gt) // args.batch_size
        args.total_data = len(trainset_gt) + len(trainset_nogt)
        args.total_iterations = (args.num_epochs * args.total_data //
                                 args.batch_size)
        args.iter_save_epoch = args.save_per_epoch * (args.total_data //
                                                      args.batch_size)
        args.iter_test_epoch = args.test_epoch * (args.total_data //
                                                  args.batch_size)
        args.semi_start = (args.semi_start_epoch * args.total_data //
                           args.batch_size)

    if args.train and args.test:
        model.train()
        model_D.train()

        model.to(args.device)
        model_D.to(args.device)

        #class_weight = 1.0 / trainset_gt.get_class_probability().to(device)

        cls_loss = torch.nn.CrossEntropyLoss().to(device)
        gan_loss = torch.nn.BCEWithLogitsLoss().to(device)
        semi_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

        history_pool_gt = ImagePool(args.pool_size)
        history_pool_nogt = ImagePool(args.pool_size)

        trainloader_gt_iter = enumerate(trainloader_gt)
        targetloader_nogt_iter = enumerate(trainloader_nogt)


        if args.semi_start_epoch == 0:
            run_training(
                trainloader_gt=trainloader_gt,
                trainloader_nogt=trainloader_nogt,
                trainloader_gt_iter=trainloader_gt_iter,
                targetloader_nogt_iter=targetloader_nogt_iter,
                testloader=testloader,
                model=model,
                model_D=model_D,
                gan_loss=gan_loss,
                cls_loss=cls_loss,
                optimizer=optimizer,
                optimizer_D=optimizer_D,
                history_pool_gt=history_pool_gt,
                history_pool_nogt=history_pool_nogt,
                writer=writer,
                train_logger=train_logger,
                test_logger=test_logger,
                args=args,
            )
        else:
            run_training_semi(
                trainloader_gt=trainloader_gt,
                trainloader_nogt=trainloader_nogt,
                trainloader_gt_iter=trainloader_gt_iter,
                targetloader_nogt_iter=targetloader_nogt_iter,
                testloader=testloader,
                model=model,
                model_D=model_D,
                gan_loss=gan_loss,
                cls_loss=cls_loss,
                semi_loss=semi_loss,
                optimizer=optimizer,
                optimizer_D=optimizer_D,
                history_pool_gt=history_pool_gt,
                history_pool_nogt=history_pool_nogt,
                writer=writer,
                train_logger=train_logger,
                test_logger=test_logger,
                args=args,
            )

    if args.test:
        print("===================================")
        print("====== Loading Testing Data =======")
        print("===================================")
        testset = ModelNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )
        criterion = torch.nn.CrossEntropyLoss().to(device)

        run_testing(
            dataloader=testloader,
            model=model,
            criterion=criterion,
            logger=test_logger,
            test_iter=100000000,
            writer=None,
            args=args,
        )
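
Example No. 7 feeds the discriminator through history_pool_gt and history_pool_nogt, two ImagePool buffers. A minimal sketch of the CycleGAN-style history pool this class is assumed to be (the method name query and the 50% replacement rule are assumptions about the original helper):

# Assumed CycleGAN-style history buffer backing ImagePool.
import random

import torch


class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        """Return a batch mixing current predictions with stored ones."""
        if self.pool_size == 0:
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # Buffer not full yet: store and pass the current prediction through.
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:
                # Swap in an old prediction and keep the new one for later.
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:
                out.append(image)
        return torch.cat(out, dim=0)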
Example No. 8
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)
    print("EXP PATH: {}".format(args.exp_dir))

    assert args.trinity_data_train_with_labels != "", "Indicate trained data in Trinity!"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    logger = make_logger(filename="TrainVal.log", args=args)
    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    print("===================================")
    print("====== Loading Training Data ======")
    print("===================================")

    transforms_shape_source = dual_transforms.Compose([
        dual_transforms.CenterCrop((400, 400)),
        dual_transforms.Scale(args.image_size[0]),
    ])

    transforms_shape_target = dual_transforms.Compose([
        dual_transforms.CenterCrop((400, 400)),
        dual_transforms.Scale(args.image_size[0]),
    ])

    photo_transformer = PhotometricTransform(photometric_transform_config)

    source_data = UnityDataset(
        root=args.source_root,
        image_size=args.image_size,
        shape_transforms=transforms_shape_source,
        photo_transforms=photo_transformer,
        train_bool=False,
    )

    target_data = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "train"),
        image_size=args.image_size,
        data_to_train=args.trinity_data_train_with_labels,
        shape_transforms=transforms_shape_target,
        photo_transforms=photo_transformer,
        train_bool=False,
    )

    args.tot_source = len(source_data)
    args.total_iterations = args.num_epochs * args.tot_source // args.batch_size
    args.iters_to_eval = args.epoch_to_eval * args.tot_source // args.batch_size
    args.iter_source_to_eval = (args.epoch_to_eval_source * args.tot_source //
                                args.batch_size)
    args.iter_semi_start = (args.epoch_semi_start * args.tot_source //
                            args.batch_size)

    print("===================================")
    print("========= Loading Val Data ========")
    print("===================================")
    val_target_data = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "validation"),
        image_size=args.image_size,
        data_to_train="",
        shape_transforms=transforms_shape_target,
        photo_transforms=None,
        train_bool=False,
    )
    # class_weight_source = 1.0 / source_data.get_class_probability().to(device)

    source_loader = torch.utils.data.DataLoader(
        source_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    target_loader = torch.utils.data.DataLoader(
        target_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_target_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    model_seg = load_models(
        mode="segmentation",
        device=device,
        args=args,
    )

    model_disc = load_models(
        mode="single_discriminator",
        device=device,
        args=args,
    )

    optimizer_seg = optim.Adam(
        model_seg.parameters(),
        lr=args.lr_seg,
        betas=(args.beta1, 0.999),
    )
    optimizer_seg.zero_grad()

    optimizer_disc = torch.optim.Adam(
        model_disc.parameters(),
        lr=args.lr_disc,
        betas=(args.beta1, 0.999),
    )
    optimizer_disc.zero_grad()

    seg_loss_source = torch.nn.CrossEntropyLoss().to(device)
    gan_loss = torch.nn.BCEWithLogitsLoss().to(device)
    semi_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

    history_true_mask = ImagePool(args.pool_size)
    history_fake_mask = ImagePool(args.pool_size)

    trainloader_iter = enumerate(source_loader)
    targetloader_iter = enumerate(target_loader)

    val_loss, val_miou = run_training_SSDA(
        trainloader_source=source_loader,
        trainloader_target=target_loader,
        trainloader_iter=trainloader_iter,
        targetloader_iter=targetloader_iter,
        val_loader=val_loader,
        model_seg=model_seg,
        model_disc=model_disc,
        gan_loss=gan_loss,
        seg_loss=seg_loss_source,
        semi_loss_criterion=semi_loss,
        optimizer_seg=optimizer_seg,
        optimizer_disc=optimizer_disc,
        history_pool_true=history_true_mask,
        history_pool_fake=history_fake_mask,
        logger=logger,
        writer=writer,
        args=args,
    )

    with open("%s/train_performance.pkl" % args.exp_dir, "wb") as f:
        pickle.dump([val_loss, val_miou], f)

    logger.info("==========================================")
    logger.info("Evaluating on test data ...")

    testdata = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "test"),
        image_size=args.image_size,
        data_to_train="",
        shape_transforms=transforms_shape_target,
        photo_transforms=None,
        train_bool=False,
    )
    test_loader = torch.utils.data.DataLoader(
        testdata,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    pm = run_testing(
        dataset=testdata,
        test_loader=test_loader,
        model=model_seg,
        args=args,
    )

    logger.info('Global Mean Accuracy: {:.3f}'.format(np.array(pm.GA).mean()))
    logger.info('Mean IOU: {:.3f}'.format(np.array(pm.IOU).mean()))
    logger.info('Mean Recall: {:.3f}'.format(np.array(pm.Recall).mean()))
    logger.info('Mean Precision: {:.3f}'.format(np.array(pm.Precision).mean()))
    logger.info('Mean F1: {:.3f}'.format(np.array(pm.F1).mean()))

    IOU_ALL = np.array(pm.Iou_all)
    logger.info(
        "Back: {:.4f}, Sclera: {:.4f}, Iris: {:.4f}, Pupil: {:.4f}".format(
            IOU_ALL[:, 0].mean(),
            IOU_ALL[:, 1].mean(),
            IOU_ALL[:, 2].mean(),
            IOU_ALL[:, 3].mean(),
        ))
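
The mIoU reported here (and in the training loop of Example No. 3 via compute_mean_iou) is a per-class intersection-over-union averaged over the classes that actually appear. A minimal sketch of such a helper, an assumption about the real implementation:

# Assumed shape of compute_mean_iou over flattened prediction/label arrays.
import numpy as np


def compute_mean_iou(flat_pred, flat_label):
    classes = np.unique(np.concatenate((flat_label, flat_pred)))
    ious = {}
    for c in classes:
        pred_c = flat_pred == c
        label_c = flat_label == c
        intersection = np.logical_and(pred_c, label_c).sum()
        union = np.logical_or(pred_c, label_c).sum()
        ious[int(c)] = intersection / float(union)
    miou = float(np.mean(list(ious.values())))
    return miou, ious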
Example No. 9
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    test_logger = make_logger("Test.log", args)

    model = load_models(
        mode="cls",
        device=device,
        args=args,
    )
    # model_D = load_models(
    #     mode="disc",
    #     device=device,
    #     args=args,
    # )

    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
    )
    optimizer.zero_grad()

    #optimizer_D = optim.Adam(
    #     model_D.parameters(),
    #     lr=args.lr_D,
    #     betas=(0.9, 0.999),
    # )
    # optimizer_D.zero_grad()

    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    print("===================================")
    print("====== Loading Testing Data =======")
    print("===================================")
    testset = ModelNetDatasetGT(
        root_list=args.test_file,
        sample_list=None,
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    criterion = torch.nn.CrossEntropyLoss().to(device)

    train_pth = glob.glob(args.exp_dir + "/model_train_epoch_*.pth")
    print("Total #models: {}".format(len(train_pth)))

    max_accu = float("-inf")
    for i in range(len(train_pth)):
        print("Processing {} .....".format(i))
        print("model {} .........".format(train_pth[i]))
        model.load_state_dict(torch.load(train_pth[i]))

        curr_accu, curr_loss = run_testing(
            dataloader=testloader,
            model=model,
            criterion=criterion,
            logger=test_logger,
            test_iter=100000000,
            writer=None,
            args=args,
        )

        if curr_accu > max_accu:
            max_accu = curr_accu
            max_pth = train_pth[i]

    print("Max accuracy: {:.4f}".format(max_accu))
    print("Trained model: {}".format(max_pth))
Example No. 10
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name)
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    train_logger = make_logger("Train.log", args)
    test_logger = make_logger("Test.log", args)

    if args.train or args.train_stack:
        if args.train_regu:
            model = load_models(
                mode="seg_regu",
                device=device,
                args=args,
            )
        else:
            model = load_models(
                mode="seg",
                device=device,
                args=args,
            )

        model_D = load_models(
            mode="disc_stack",
            device=device,
            args=args,
        )

        optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.999),
        )
        optimizer.zero_grad()

        optimizer_D = optim.Adam(
            model_D.parameters(),
            lr=args.lr_D_point,
            betas=(0.9, 0.999),
        )
        optimizer_D.zero_grad()

    if args.tensorboard:
        writer = SummaryWriter(args.exp_dir)
    else:
        writer = None

    if (args.train or args.train_stack) and args.test:
        print("===================================")
        print("====== Loading Training Data ======")
        print("===================================")

        if args.gt_sample_list is not None:
            sample_gt_list = np.load(args.gt_sample_list)
        else:
            sample_gt_list = None

        trainset_gt = ShapeNetDatasetGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )
        trainset_nogt = ShapeNetDataset_noGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )

        trainloader_gt = torch.utils.data.DataLoader(
            trainset_gt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )
        trainloader_nogt = torch.utils.data.DataLoader(
            trainset_nogt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

        print("===================================")
        print("====== Loading Test Data ======")
        print("===================================")
        testset = ShapeNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
            num_classes=16,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

        args.iter_per_epoch = len(trainset_gt) // args.batch_size
        args.total_data = len(trainset_gt) + len(trainset_nogt)
        args.total_iterations = (args.num_epochs * args.total_data //
                                 args.batch_size)
        args.iter_save_epoch = args.save_per_epoch * (args.total_data //
                                                      args.batch_size)
        args.iter_test_epoch = args.test_epoch * (args.total_data //
                                                  args.batch_size)
        args.semi_start = (args.semi_start_epoch * args.total_data //
                           args.batch_size)

    if (args.train or args.train_stack) and args.test:
        model.train()
        model_D.train()

        model.to(args.device)
        model_D.to(args.device)

        #class_weight = 1.0 / trainset_gt.get_class_probability().to(device)

        seg_loss = torch.nn.CrossEntropyLoss().to(device)
        gan_loss = torch.nn.BCEWithLogitsLoss().to(device)
        semi_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

        history_pool_gt = ImagePool(args.pool_size)
        history_pool_nogt = ImagePool(args.pool_size)

        trainloader_gt_iter = enumerate(trainloader_gt)
        targetloader_nogt_iter = enumerate(trainloader_nogt)


        if args.train and (args.semi_start_epoch == 0):
            run_training_seg(
                trainloader_gt=trainloader_gt,
                trainloader_nogt=trainloader_nogt,
                trainloader_gt_iter=trainloader_gt_iter,
                targetloader_nogt_iter=targetloader_nogt_iter,
                testloader=testloader,
                testdataset=testset,
                model=model,
                model_D=model_D,
                gan_loss=gan_loss,
                seg_loss=seg_loss,
                optimizer=optimizer,
                optimizer_D=optimizer_D,
                history_pool_gt=history_pool_gt,
                history_pool_nogt=history_pool_nogt,
                writer=writer,
                train_logger=train_logger,
                test_logger=test_logger,
                args=args,
            )
        elif args.train and (args.semi_start_epoch > 0):
            run_training_seg_semi(
                trainloader_gt=trainloader_gt,
                trainloader_nogt=trainloader_nogt,
                trainloader_gt_iter=trainloader_gt_iter,
                targetloader_nogt_iter=targetloader_nogt_iter,
                testloader=testloader,
                testdataset=testset,
                model=model,
                model_D=model_D,
                gan_loss=gan_loss,
                seg_loss=seg_loss,
                semi_loss=semi_loss,
                optimizer=optimizer,
                optimizer_D=optimizer_D,
                history_pool_gt=history_pool_gt,
                history_pool_nogt=history_pool_nogt,
                writer=writer,
                train_logger=train_logger,
                test_logger=test_logger,
                args=args,
            )
        elif args.train_stack:
            gan_loss = torch.nn.BCELoss().to(device)
            shape_criterion = torch.nn.CrossEntropyLoss().to(device)
            regu_loss = torch.nn.MSELoss().to(device)

            if args.semi_start_epoch == 0:
                if args.train_regu:
                    run_training_seg_stack_regulization(
                        trainloader_gt=trainloader_gt,
                        trainloader_nogt=trainloader_nogt,
                        trainloader_gt_iter=trainloader_gt_iter,
                        targetloader_nogt_iter=targetloader_nogt_iter,
                        testloader=testloader,
                        testdataset=testset,
                        model=model,
                        model_D=model_D,
                        gan_loss=gan_loss,
                        seg_loss=seg_loss,
                        regu_loss=regu_loss,
                        shape_criterion=shape_criterion,
                        optimizer=optimizer,
                        optimizer_D=optimizer_D,
                        history_pool_gt=history_pool_gt,
                        history_pool_nogt=history_pool_nogt,
                        writer=writer,
                        train_logger=train_logger,
                        test_logger=test_logger,
                        args=args,
                    )
                else:
                    run_training_seg_stack(
                        trainloader_gt=trainloader_gt,
                        trainloader_nogt=trainloader_nogt,
                        trainloader_gt_iter=trainloader_gt_iter,
                        targetloader_nogt_iter=targetloader_nogt_iter,
                        testloader=testloader,
                        testdataset=testset,
                        model=model,
                        model_D=model_D,
                        gan_loss=gan_loss,
                        seg_loss=seg_loss,
                        shape_criterion=shape_criterion,
                        optimizer=optimizer,
                        optimizer_D=optimizer_D,
                        history_pool_gt=history_pool_gt,
                        history_pool_nogt=history_pool_nogt,
                        writer=writer,
                        train_logger=train_logger,
                        test_logger=test_logger,
                        args=args,
                    )
            else:
                if args.train_regu:
                    run_training_seg_stack_regulization_semi(
                        trainloader_gt=trainloader_gt,
                        trainloader_nogt=trainloader_nogt,
                        trainloader_gt_iter=trainloader_gt_iter,
                        targetloader_nogt_iter=targetloader_nogt_iter,
                        testloader=testloader,
                        testdataset=testset,
                        model=model,
                        model_D=model_D,
                        gan_loss=gan_loss,
                        seg_loss=seg_loss,
                        regu_loss=regu_loss,
                        semi_loss=semi_loss,
                        shape_criterion=shape_criterion,
                        optimizer=optimizer,
                        optimizer_D=optimizer_D,
                        history_pool_gt=history_pool_gt,
                        history_pool_nogt=history_pool_nogt,
                        writer=writer,
                        train_logger=train_logger,
                        test_logger=test_logger,
                        args=args,
                    )
                else:
                    run_training_seg_stack_semi(
                        trainloader_gt=trainloader_gt,
                        trainloader_nogt=trainloader_nogt,
                        trainloader_gt_iter=trainloader_gt_iter,
                        targetloader_nogt_iter=targetloader_nogt_iter,
                        testloader=testloader,
                        testdataset=testset,
                        model=model,
                        model_D=model_D,
                        gan_loss=gan_loss,
                        seg_loss=seg_loss,
                        semi_loss=semi_loss,
                        shape_criterion=shape_criterion,
                        optimizer=optimizer,
                        optimizer_D=optimizer_D,
                        history_pool_gt=history_pool_gt,
                        history_pool_nogt=history_pool_nogt,
                        writer=writer,
                        train_logger=train_logger,
                        test_logger=test_logger,
                        args=args,
                    )

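    # Dual-discriminator variant: the segmentation network is trained against a
    # shared discriminator trunk with two heads -- a shape-level critic and a
    # point-level critic (naming inferred from the model identifiers below).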
    if args.dual:
        model = load_models(
            mode="seg",
            device=device,
            args=args,
        )
        optimizer = optim.Adam(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.999),
        )
        optimizer.zero_grad()

        sharedDisc, shapeDisc, pointDisc = load_models(
            mode="disc_dual",
            device=device,
            args=args,
        )
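        # The shared trunk's parameters are included in both parameter groups below,
        # so it receives gradient from the shape-level and the point-level
        # adversarial losses alike.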
        D_params_shape = list(shapeDisc.parameters()) + list(sharedDisc.parameters())
        D_params_point = list(pointDisc.parameters()) + list(sharedDisc.parameters())
        optimizer_D_shape = optim.SGD(
            D_params_shape,
            lr=args.lr_D_shape,
            # lr=args.lr_D*0.2,
        )
        optimizer_D_shape.zero_grad()
        # optimizer_D_point = optim.Adam(
        #     D_params_point,
        #     lr=args.lr_D,
        #     betas=(0.9, 0.999),
        # )
        optimizer_D_point = optim.SGD(
            D_params_point,
            lr=args.lr_D_point,
        )
        optimizer_D_point.zero_grad()

        print("===================================")
        print("====== Loading Training Data ======")
        print("===================================")

        if args.gt_sample_list is not None:
            sample_gt_list = np.load(args.gt_sample_list)
        else:
            sample_gt_list = None

        trainset_gt = ShapeNetDatasetGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )
        trainset_nogt = ShapeNetDataset_noGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )

        trainloader_gt = torch.utils.data.DataLoader(
            trainset_gt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )
        trainloader_nogt = torch.utils.data.DataLoader(
            trainset_nogt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

        print("===================================")
        print("====== Loading Test Data ======")
        print("===================================")
        testset = ShapeNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
            num_classes=16,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

        # Iteration bookkeeping: epoch counts are converted into iteration counts
        # over the combined labelled + unlabelled pool.
        args.iter_per_epoch = int(len(trainset_gt) / args.batch_size)
        args.total_data = len(trainset_gt) + len(trainset_nogt)
        args.total_iterations = int(args.num_epochs * args.total_data / args.batch_size)
        args.iter_save_epoch = args.save_per_epoch * int(args.total_data / args.batch_size)
        args.iter_test_epoch = args.test_epoch * int(args.total_data / args.batch_size)
        args.semi_start = int(args.semi_start_epoch * args.total_data / args.batch_size)

        model.train()
        sharedDisc.train()
        shapeDisc.train()
        pointDisc.train()

        model.to(args.device)
        sharedDisc.to(args.device)
        shapeDisc.to(args.device)
        pointDisc.to(args.device)

        seg_loss = torch.nn.CrossEntropyLoss().to(device)
        gan_point_loss = torch.nn.BCEWithLogitsLoss().to(device)
        gan_shape_loss = torch.nn.CrossEntropyLoss().to(device)

        history_pool_gt = ImagePool(args.pool_size)
        history_pool_nogt = ImagePool(args.pool_size)

        trainloader_gt_iter = enumerate(trainloader_gt)
        targetloader_nogt_iter = enumerate(trainloader_nogt)

        run_training_seg_dual(
            trainloader_gt=trainloader_gt,
            trainloader_nogt=trainloader_nogt,
            trainloader_gt_iter=trainloader_gt_iter,
            targetloader_nogt_iter=targetloader_nogt_iter,
            testloader=testloader,
            testdataset=testset,
            model=model,
            sharedDisc=sharedDisc,
            shapeDisc=shapeDisc,
            pointDisc=pointDisc,
            gan_point_loss=gan_point_loss,
            gan_shape_loss=gan_shape_loss,
            seg_loss=seg_loss,
            optimizer=optimizer,
            optimizer_D_shape=optimizer_D_shape,
            optimizer_D_point=optimizer_D_point,
            history_pool_gt=history_pool_gt,
            history_pool_nogt=history_pool_nogt,
            writer=writer,
            train_logger=train_logger,
            test_logger=test_logger,
            args=args,
        )

    if args.test:
        print("===================================")
        print("====== Loading Testing Data =======")
        print("===================================")
        testset = ShapeNetDatasetGT(
            root_list=args.test_file,
            sample_list=None,
            num_classes=16,
        )
        testloader = torch.utils.data.DataLoader(
            testset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )
        criterion = torch.nn.CrossEntropyLoss().to(device)

        run_testing_seg(
            dataset=testset,
            dataloader=testloader,
            model=model,
            criterion=criterion,
            logger=test_logger,
            test_iter=100000000,
            writer=None,
            args=args,
        )

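    # Feature-dump branch for t-SNE: runs the trained segmentation network over the
    # labelled and unlabelled training splits, collects one pooled global-shape
    # feature per sample (2048-dim, judging by the buffers below), and pickles the
    # features together with a labelled/unlabelled flag and the object category.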
    if args.tsne:
        from utils.metric import batch_get_iou, object_names
        from torch.autograd import Variable

        model = load_models(
            mode="seg",
            device=device,
            args=args,
        )

        args.batch_size = 1

        labels = []
        objects = []

        if args.gt_sample_list is not None:
            sample_gt_list = np.load(args.gt_sample_list)
        else:
            sample_gt_list = None

        trainset_gt = ShapeNetDatasetGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )

        trainset_nogt = ShapeNetDataset_noGT(
            root_list=args.train_file,
            sample_list=sample_gt_list,
            num_classes=16,
        )

        trainloader_gt = torch.utils.data.DataLoader(
            trainset_gt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )
        trainloader_nogt = torch.utils.data.DataLoader(
            trainset_nogt,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

        model.eval()

        shape_ious = np.empty(len(object_names), dtype=object)  # one IoU list per category; the np.object alias was removed in recent NumPy
        for i in range(shape_ious.shape[0]):
            shape_ious[i] = []

        all_shapes_train_gt = np.empty((len(trainset_gt), 2048))
        for batch_idx, data in enumerate(trainloader_gt):
            if batch_idx % 1000 == 0:
                print("Processing {} ...".format(batch_idx))

            pts, cls, seg = data
            pts, cls, seg = Variable(pts).float(), \
                            Variable(cls), Variable(seg).type(torch.LongTensor)
            pts, cls, seg = pts.to(args.device), cls.to(args.device), seg.long().to(args.device)

            labels.append(1)
            objects.append(int(cls.argmax(axis=2).squeeze().cpu().numpy()))

            with torch.set_grad_enabled(False):
                pred, global_shape = model(pts, cls)

            all_shapes_train_gt[batch_idx,:] = global_shape.squeeze().detach().cpu().numpy()

        all_shapes_train_nogt = np.empty((len(trainset_nogt), 2048))
        for batch_idx, data in enumerate(trainloader_nogt):
            if batch_idx % 1000 == 0:
                print("Processing {} ...".format(batch_idx))

            pts, cls = data
            pts, cls = Variable(pts).float(), Variable(cls)
            pts, cls = pts.to(args.device), cls.to(args.device)

            with torch.set_grad_enabled(False):
                pred, global_shape = model(pts, cls)

            all_shapes_train_nogt[batch_idx, :] = global_shape.squeeze().detach().cpu().numpy()

            labels.append(0)
            objects.append(int(cls.argmax(axis=2).squeeze().cpu().numpy()))

        all_shapes = np.concatenate((all_shapes_train_gt, all_shapes_train_nogt), axis=0)

        shape_info = {"shapes": all_shapes, "labels": labels, "objects": objects}

        import pickle
        try:
            with open("stack_global_shape_{}.pkl".format(len(trainset_gt)), "wb") as o:
                pickle.dump(shape_info, o, protocol=2)
        except FileNotFoundError as e:
            print(e)
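

# Illustrative sketch (not part of the original script): how the pickle written by
# the args.tsne branch above could be embedded with scikit-learn's t-SNE. The
# "shapes"/"labels"/"objects" keys match what is dumped above; the helper name,
# the default file name, and the sklearn usage are this sketch's own assumptions.
def embed_dumped_shapes(pkl_path="stack_global_shape_100.pkl"):
    import pickle

    import numpy as np
    from sklearn.manifold import TSNE

    with open(pkl_path, "rb") as f:
        shape_info = pickle.load(f)

    feats = np.asarray(shape_info["shapes"], dtype=np.float32)
    emb = TSNE(n_components=2, init="pca", perplexity=30).fit_transform(feats)
    # labels: 1 = sample from the labelled split, 0 = unlabelled split (as appended above).
    return emb, np.asarray(shape_info["labels"]), np.asarray(shape_info["objects"])
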
def main(args):
    print('===================================\n', )
    print("Root directory: {}".format(args.name))
    args.exp_dir = os.path.join(RES_DIR, args.name, "evaluation_trinity_test")
    if not os.path.isdir(args.exp_dir):
        os.makedirs(args.exp_dir)
    print("EXP PATH: {}".format(args.exp_dir))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    transforms_shape_target = dual_transforms.Compose([
        dual_transforms.CenterCrop((400, 400)),
        dual_transforms.Scale(args.image_size[0]),
    ])

    testdata = OpenEDSDataset_withLabels(
        root=os.path.join(args.target_root, "test"),
        image_size=args.image_size,
        data_to_train="",
        shape_transforms=transforms_shape_target,
        photo_transforms=None,
        train_bool=False,
    )
    test_loader = torch.utils.data.DataLoader(
        testdata,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    assert args.checkpoint_seg is not None, "Need a trained segmentation checkpoint (.pth)!"
    model_seg = load_models(
        mode="segmentation",
        device=device,
        args=args,
    )

    print("Evaluating ...................")
    miou_all, iou_all = run_testing(
        dataset=testdata,
        test_loader=test_loader,
        model=model_seg,
        args=args,
    )
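    # run_testing is assumed to return a per-image mean-IoU vector (miou_all) and a
    # per-image, per-class IoU matrix (iou_all) whose columns are ordered
    # background / sclera / iris / pupil, as used in the prints below.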

    # print('Global Mean Accuracy: {:.3f}'.format(np.array(pm.GA).mean()))
    # print('Mean IOU: {:.3f}'.format(np.array(pm.IOU).mean()))
    # print('Mean Recall: {:.3f}'.format(np.array(pm.Recall).mean()))
    # print('Mean Precision: {:.3f}'.format(np.array(pm.Precision).mean()))
    # print('Mean F1: {:.3f}'.format(np.array(pm.F1).mean()))

    print('Mean IOU: {:.3f}'.format(miou_all.mean()))
    print("Back: {:.4f}, Sclera: {:.4f}, Iris: {:.4f}, Pupil: {:.4f}".format(
        iou_all[:, 0].mean(),
        iou_all[:, 1].mean(),
        iou_all[:, 2].mean(),
        iou_all[:, 3].mean(),
    ))