Example #1
# Assumed imports for this snippet; Net, TsinghuaDog, train and evaluate are
# project-local helpers that are not shown here.
import argparse

from jittor import transform
from jittor.optim import SGD
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--num_classes', type=int, default=130)

    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-5)

    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--eval', action='store_true')

    parser.add_argument('--dataroot', type=str, default='/content/drive/MyDrive/dogflg/data2/')
    parser.add_argument('--model_path', type=str, default='./best_model.bin')

    parser.add_argument('--sampleratio', type=float, default=0.8)

    args = parser.parse_args()
    
    transform_train = transform.Compose([
        transform.Resize((256, 256)),
        transform.CenterCrop(224),
        transform.RandomHorizontalFlip(),
        transform.ToTensor(),
        transform.ImageNormalize(0.485, 0.229),
        # transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    root_dir = args.dataroot
    train_loader = TsinghuaDog(root_dir, batch_size=args.batch_size, train=True, part='train', shuffle=True, transform=transform_train, sample_rate=args.sampleratio)

    transform_test = transform.Compose([
        transform.Resize((256, 256)),
        transform.CenterCrop(224),
        transform.ToTensor(),
        transform.ImageNormalize(0.485, 0.229),
        # transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_loader = TsinghuaDog(root_dir, batch_size=args.batch_size, train=False, part='val', shuffle=False, transform=transform_test, sample_rate=args.sampleratio)

    epochs = args.epochs
    model = Net(num_classes=args.num_classes)
    lr = args.lr
    weight_decay = args.weight_decay
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.99, weight_decay=weight_decay)
    if args.resume:
        model.load(args.model_path)
        print('model loaded', args.model_path)

    #random save for test
    #model.save(args.model_path)
    if args.eval:
        evaluate(model, val_loader, save_path=args.model_path)
        return 
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch)
        evaluate(model, val_loader, epoch, save_path=args.model_path)
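Note on the normalization above: jittor's ImageNormalize broadcasts a scalar mean/std over all channels, while the list form (shown commented out above) normalizes each channel separately. A minimal sketch, assuming ImageNormalize accepts a CHW float array as in recent jittor versions:

import numpy as np
from jittor import transform

chw = np.random.rand(3, 8, 8).astype('float32')       # CHW image in [0, 1]
scalar_norm = transform.ImageNormalize(0.485, 0.229)  # same mean/std for every channel
channel_norm = transform.ImageNormalize([0.485, 0.456, 0.406],
                                        [0.229, 0.224, 0.225])
print(scalar_norm(chw).shape, channel_norm(chw).shape)  # both (3, 8, 8)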
Example #2
    def test_not_pil_image(self):
        img = jt.random((30, 40, 3))
        result = transform.Compose([
            transform.RandomAffine(20),
            transform.ToTensor(),
        ])(img)

        img = jt.random((30, 40, 3))
        result = transform.Compose([
            transform.ToPILImage(),
            transform.Gray(),
            transform.Resize(20),
            transform.ToTensor(),
        ])(img)
Example #3
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--num_classes', type=int, default=130)

    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-5)

    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--eval', action='store_true')

    parser.add_argument('--dataroot', type=str, default='/home/gmh/dataset/TsinghuaDog/')
    parser.add_argument('--model_path', type=str, default='./best_model.pkl')

    args = parser.parse_args()
    
    transform_train = transform.Compose([
        transform.Resize((512, 512)),
        transform.RandomCrop(448),
        transform.RandomHorizontalFlip(),
        transform.ToTensor(),
        transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    root_dir = args.dataroot
    train_loader = TsinghuaDog(root_dir, batch_size=args.batch_size, train=True, part='train', shuffle=True, transform=transform_train)
    
    transform_test = transform.Compose([
        transform.Resize((512, 512)),
        transform.CenterCrop(448),
        transform.ToTensor(),
        transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_loader = TsinghuaDog(root_dir, batch_size=args.batch_size, train=False, part='val', shuffle=False, transform=transform_test)
    epochs = args.epochs
    model = Net(num_classes=args.num_classes)
    lr = args.lr
    weight_decay = args.weight_decay
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if args.resume:
        model.load(args.model_path)
    if args.eval:
        evaluate(model, val_loader)
        return 
    for epoch in range(epochs):
        train(model, train_loader, optimizer, epoch)
        evaluate(model, val_loader, epoch)
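Boolean flags defined with action='store_true' are set by presence alone (argparse's type=bool would turn any non-empty string, including 'False', into True). A typical invocation of this script (the script name is hypothetical):

# python train.py --batch_size 8 --epochs 50 --lr 2e-3             # train from scratch
# python train.py --resume --eval --model_path ./best_model.pkl    # evaluate a checkpoint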
Example #4
    def test_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2

        img = np.ones([height, width, 3])
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        # imgnarrow = img[oh1:oh1 + oheight, ow1:ow1 + owidth, :]
        # imgnarrow.fill(0)
        img[oh1:oh1 + oheight, ow1:ow1 + owidth, :] = 0
        # img = jt.array(img)
        result = transform.Compose([
            transform.ToPILImage(),
            transform.CenterCrop((oheight, owidth)),
            transform.ToTensor(),
        ])(img)
        self.assertEqual(
            result.sum(), 0,
            f"height: {height} width: {width} oheight: {oheight} owidth: {owidth}"
        )
        oheight += 1
        owidth += 1
        result = transform.Compose([
            transform.ToPILImage(),
            transform.CenterCrop((oheight, owidth)),
            transform.ToTensor(),
        ])(img)
        sum1 = result.sum()
        # TODO: not pass
        # self.assertGreater(sum1, 1,
        #                    f"height: {height} width: {width} oheight: {oheight} owidth: {owidth}")
        oheight += 1
        owidth += 1
        result = transform.Compose([
            transform.ToPILImage(),
            transform.CenterCrop((oheight, owidth)),
            transform.ToTensor(),
        ])(img)
        sum2 = result.sum()
        self.assertGreater(
            sum2, 0,
            f"height: {height} width: {width} oheight: {oheight} owidth: {owidth}"
        )
        self.assertGreaterEqual(
            sum2, sum1,
            f"height: {height} width: {width} oheight: {oheight} owidth: {owidth}"
        )
Example #5
    def build_transform(self):
        """
        Creates a basic transformation that was used to train the models
        """
        cfg = self.cfg

        # Images are loaded with OpenCV, so they are already in BGR order.
        # To get the BGR255 format we only need to multiply by 255; to get
        # RGB in the [0, 1] range we only need to flip the channels.
        if cfg.INPUT.TO_BGR255:
            to_bgr_transform = T.Lambda(lambda x: x * 255)
        else:
            to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])

        normalize_transform = T.ImageNormalize(
            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
        )
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        transform = T.Compose(
            [
                T.ToPILImage(),
                Resize(min_size, max_size),
                T.ToTensor(),
                to_bgr_transform,
                normalize_transform,
            ]
        )
        return transform
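A small sketch of what the two Lambda branches do to a CHW tensor in [0, 1] that was loaded in BGR order (illustrative values only):

import numpy as np

x = np.random.rand(3, 4, 4).astype('float32')  # BGR image in [0, 1], CHW layout
bgr255 = x * 255                               # TO_BGR255 path: scale to [0, 255]
rgb01 = x[[2, 1, 0]]                           # other path: reorder channels BGR -> RGB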
Example #6
    def test_TenCrop(self):
        img = jt.random((30, 40, 3))
        result = transform.Compose([
            transform.ToPILImage(),
            transform.TenCrop(20),
            transform.ToTensor(),
        ])(img)
Example #7
def transforms_imagenet_train(
        img_size=224,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        interpolation='random',
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform
    """
    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
    primary_tfl = [
        RandomResizedCropAndInterpolation(img_size,
                                          scale=scale,
                                          ratio=ratio,
                                          interpolation=interpolation)
    ]
    if hflip > 0.:
        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.:
        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]

    final_tfl = [
        transforms.ToTensor(),
        transforms.ImageNormalize(mean=mean, std=std)
    ]
    return transforms.Compose(primary_tfl + final_tfl)
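A usage sketch for the pipeline above (RandomResizedCropAndInterpolation and transforms come from the surrounding module; the input image here is synthetic):

from PIL import Image
import numpy as np

pil_img = Image.fromarray((np.random.rand(256, 256, 3) * 255).astype('uint8'))
train_tf = transforms_imagenet_train(img_size=224, hflip=0.5)
out = train_tf(pil_img)   # 224x224 CHW tensor, normalized with ImageNet statistics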
Example #8
    def test_1_channel_tensor_to_pil_image(self):
        to_tensor = transform.ToTensor()
        shape = (4, 4, 1)

        img_data_float = jt.array(np.random.rand(*shape), dtype='float32')
        img_data_byte = jt.array(np.random.randint(0, 255, shape),
                                 dtype='uint8')
        img_data_short = jt.array(np.random.randint(0, 32767, shape),
                                  dtype='int16')
        img_data_int = jt.array(np.random.randint(0, 2147483647, shape),
                                dtype='int32')

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [
            img_data_float.multiply(255).int().float().divide(255).numpy(),
            img_data_byte.float().divide(255.0).numpy(),
            img_data_short.numpy(),
            img_data_int.numpy()
        ]
        expected_modes = ['F', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs,
                                                   expected_modes):
            for t in [transform.ToPILImage(), transform.ToPILImage(mode=mode)]:
                img = t(img_data)
                self.assertEqual(img.mode, mode)
                np.testing.assert_allclose(expected_output[:, :, 0],
                                           to_tensor(img)[0],
                                           atol=0.01)
        # 'F' mode for float32 tensors
        img_F_mode = transform.ToPILImage(mode='F')(img_data_float)
        self.assertEqual(img_F_mode.mode, 'F')
Example #9
def get_loader(root_dir,
               label_file,
               batch_size,
               img_size=0,
               num_thread=4,
               pin=True,
               test=False,
               split='train'):
    if test is False:
        raise NotImplementedError
    transform = transforms.Compose([
        transforms.Resize((400, 400)),
        transforms.ToTensor(),
        transforms.ImageNormalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
    ])
    dataset = SemanLineDatasetTest(root_dir,
                                   label_file,
                                   transform=transform,
                                   t_transform=None)
    dataset.set_attrs(batch_size=batch_size, shuffle=False)
    print('Get dataset success.')
    return dataset
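A hedged usage sketch (the paths and label-file name are hypothetical; only the test branch is implemented above):

# loader = get_loader('/data/seman_lines', 'test_labels.txt',
#                     batch_size=8, test=True, split='test')
# for batch in loader:
#     ...   # batch layout is defined by SemanLineDatasetTest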
Example #10
    def test_RandomPerspective(self):
        img = jt.random((30, 40, 3))
        result = transform.Compose([
            transform.ToPILImage(),
            transform.RandomPerspective(p=1),
            transform.ToTensor(),
        ])(img)
Example #11
    def test_RandomAffine(self):
        img = jt.random((30, 40, 3))
        result = transform.Compose([
            transform.ToPILImage(),
            transform.RandomAffine(20),
            transform.ToTensor(),
        ])(img)
Example #12
    def test_2d_tensor_to_pil_image(self):
        to_tensor = transform.ToTensor()

        img_data_float = jt.array(np.random.rand(4, 4), dtype='float32')
        img_data_byte = jt.array(np.random.randint(0, 255, (4, 4)),
                                 dtype='uint8')
        img_data_short = jt.array(np.random.randint(0, 32767, (4, 4)),
                                  dtype='int16')
        img_data_int = jt.array(np.random.randint(0, 2147483647, (4, 4)),
                                dtype='int32')

        inputs = [img_data_float, img_data_byte, img_data_short, img_data_int]
        expected_outputs = [
            img_data_float.multiply(255).int().float().divide(255).numpy(),
            img_data_byte.float().divide(255.0).numpy(),
            img_data_short.numpy(),
            img_data_int.numpy()
        ]
        expected_modes = ['F', 'L', 'I;16', 'I']

        for img_data, expected_output, mode in zip(inputs, expected_outputs,
                                                   expected_modes):
            for t in [transform.ToPILImage(), transform.ToPILImage(mode=mode)]:
                img = t(img_data)
                self.assertEqual(img.mode, mode)
                self.assertTrue(
                    np.allclose(expected_output,
                                to_tensor(img),
                                atol=0.01,
                                rtol=0.01))
Example #13
    def test_resize(self):
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        osize = random.randint(5, 12) * 2

        img = jt.ones([height, width, 3])
        result = transform.Compose([
            transform.ToPILImage(),
            transform.Resize(osize),
            transform.ToTensor(),
        ])(img)
        self.assertIn(osize, result.shape)
        if height < width:
            self.assertLessEqual(result.shape[1], result.shape[2])
        elif width < height:
            self.assertGreaterEqual(result.shape[1], result.shape[2])

        result = transform.Compose([
            transform.ToPILImage(),
            transform.Resize([osize, osize]),
            transform.ToTensor(),
        ])(img)
        self.assertIn(osize, result.shape)
        self.assertEqual(result.shape[1], osize)
        self.assertEqual(result.shape[2], osize)

        oheight = random.randint(5, 12) * 2
        owidth = random.randint(5, 12) * 2
        result = transform.Compose([
            transform.ToPILImage(),
            transform.Resize((oheight, owidth)),
            transform.ToTensor(),
        ])(img)
        self.assertEqual(result.shape[1], oheight)
        self.assertEqual(result.shape[2], owidth)

        result = transform.Compose([
            transform.ToPILImage(),
            transform.Resize([oheight, owidth]),
            transform.ToTensor(),
        ])(img)
        self.assertEqual(result.shape[1], oheight)
        self.assertEqual(result.shape[2], owidth)
Example #14
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--num_classes', type=int, default=130)

    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-5)

    parser.add_argument('--resume', action='store_true', default=True)
    parser.add_argument('--eval', action='store_true')

    parser.add_argument('--dataroot',
                        type=str,
                        default='/content/drive/MyDrive/dogfl/data/TEST_A/')
    parser.add_argument('--model_path', type=str, default='./best_model.bin')
    parser.add_argument('--out_file', type=str, default='./result.json')

    args = parser.parse_args()

    root_dir = args.dataroot

    transform_test = transform.Compose([
        transform.Resize((512, 512)),
        transform.CenterCrop(448),
        transform.ToTensor(),
        transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    name_list = []
    for _, _, _name_list in os.walk(root_dir):
        name_list = _name_list
    val_loader = TsinghuaDogExam(root_dir,
                                 batch_size=args.batch_size,
                                 train=False,
                                 name_list=name_list,
                                 shuffle=False,
                                 transform=transform_test)

    model = Net(num_classes=args.num_classes)
    if args.resume:
        model.load(args.model_path)
        print('model loaded', args.model_path)

    top5_class_list = evaluate(model, val_loader)
    # label start from 1, however it doesn't
    pred_result = dict(zip(name_list, top5_class_list))

    with open(args.out_file, 'w') as fout:
        json.dump(pred_result, fout, ensure_ascii=False, indent=4)
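Note that the os.walk loop above keeps only the file list of the last directory visited; for a flat test folder it is equivalent to the sketch below (behaviour differs if root_dir contains subdirectories):

import os

name_list = [f for f in os.listdir(root_dir)
             if os.path.isfile(os.path.join(root_dir, f))]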
Example #15
def get_transform(new_size=None):
    """
    obtain the image transforms required for the input data
    :param new_size: size of the resized images
    :return: image_transform => composed transform object
    """
    # from torchvision.transforms import ToTensor, Normalize, Compose, Resize, RandomHorizontalFlip

    if new_size is not None:
        image_transform = transform.Compose([
            transform.RandomHorizontalFlip(),
            transform.Resize(new_size),
            transform.ToTensor(),
            transform.ImageNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

    else:
        image_transform = transform.Compose([
            transform.RandomHorizontalFlip(),
            transform.ToTensor(),
            transform.ImageNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
    return image_transform
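The mean=std=0.5 normalization above maps [0, 1] inputs to [-1, 1], matching the tanh output range common in GAN generators; a quick check of the arithmetic:

# (x - 0.5) / 0.5 maps 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0
assert (0.0 - 0.5) / 0.5 == -1.0 and (1.0 - 0.5) / 0.5 == 1.0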
Example #16
def im_detect_bbox(model, images, target_scale, target_max_size):
    """
    Performs bbox detection on the original image.
    """
    transform = TT.Compose([
        T.Resize(target_scale, target_max_size),
        TT.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN,
                    std=cfg.INPUT.PIXEL_STD,
                    to_bgr255=cfg.INPUT.TO_BGR255)
    ])
    images = [transform(image) for image in images]
    images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
    return model(images)
Example #17
    def test_random_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = np.ones((height, width, 3))
        result = transform.Compose([
            transform.ToPILImage(),
            transform.RandomCrop((oheight, owidth)),
            transform.ToTensor(),
        ])(img)
        self.assertEqual(result.shape[1], oheight)
        self.assertEqual(result.shape[2], owidth)

        result = transform.Compose([
            transform.ToPILImage(),
            transform.RandomCrop((oheight, owidth)),
            transform.ToTensor(),
        ])(img)
        self.assertEqual(result.shape[1], oheight)
        self.assertEqual(result.shape[2], owidth)

        result = transform.Compose([
            transform.ToPILImage(),
            transform.RandomCrop((height, width)),
            transform.ToTensor()
        ])(img)
        self.assertEqual(result.shape[1], height)
        self.assertEqual(result.shape[2], width)
        self.assertTrue(np.allclose(img, result.transpose(1, 2, 0)))

        with self.assertRaises(AssertionError):
            result = transform.Compose([
                transform.ToPILImage(),
                transform.RandomCrop((height + 1, width + 1)),
                transform.ToTensor(),
            ])(img)
Example #18
    def build_transform():
        if cfg.INPUT.TO_BGR255:
            to_bgr_transform = T.Lambda(lambda x: x * 255)
        else:
            to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])

        normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN,
                                          std=cfg.INPUT.PIXEL_STD)
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        transform = T.Compose([
            T.ToPILImage(),
            Resize(min_size, max_size),
            T.ToTensor(),
            to_bgr_transform,
            normalize_transform,
        ])
        return transform
Example #19
    def test_to_tensor(self):
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transform.ToTensor()

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())

        with self.assertRaises(ValueError):
            trans(np.random.rand(height))
        with self.assertRaises(ValueError):
            trans(np.random.rand(1, 1, height, width))

        for channels in test_channels:
            input_data = np.random.randint(
                low=0, high=255, size=(height, width, channels)).astype(
                    np.float32) / np.float32(255.0)
            img = transform.ToPILImage()(input_data)
            output = trans(img)
            expect = input_data.transpose(2, 0, 1)
            self.assertTrue(np.allclose(expect, output),
                            f"{expect.shape}\n{output.shape}")

            ndarray = np.random.randint(low=0,
                                        high=255,
                                        size=(channels, height,
                                              width)).astype(np.uint8)
            output = trans(ndarray)
            expected_output = ndarray / 255.0
            np.testing.assert_allclose(output, expected_output)

            ndarray = np.random.rand(channels, height,
                                     width).astype(np.float32)
            output = trans(ndarray)
            expected_output = ndarray
            self.assertTrue(np.allclose(output, expected_output))

        # separate test for mode '1' PIL images
        input_data = np.random.binomial(1, 0.5, size=(height, width,
                                                      1)).astype(np.uint8)
        img = transform.ToPILImage()(input_data * 255).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data[:, :, 0], output[0]),
                        f"{input_data.shape}\n{output.shape}")
Example #20
def im_detect_bbox_hflip(model, images, target_scale, target_max_size):
    """
    Performs bbox detection on the horizontally flipped image.
    Function signature is the same as for im_detect_bbox.
    """
    transform = TT.Compose([
        T.Resize(target_scale, target_max_size),
        TT.RandomHorizontalFlip(1.0),
        TT.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN,
                    std=cfg.INPUT.PIXEL_STD,
                    to_bgr255=cfg.INPUT.TO_BGR255)
    ])
    images = [transform(image) for image in images]
    images = to_image_list(images, cfg.DATALOADER.SIZE_DIVISIBILITY)
    boxlists = model(images)

    # Invert the detections computed on the flipped image
    boxlists_inv = [boxlist.transpose(0) for boxlist in boxlists]
    return boxlists_inv
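Why the inversion is needed: a detection at horizontal coordinate x in the flipped image sits at width - x in the original, which is the remapping boxlist.transpose(0) applies. A worked example with made-up numbers:

# image width 640; box detected on the flipped image: xmin=100, xmax=200
# back in the original image: xmin = 640 - 200 = 440, xmax = 640 - 100 = 540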
Example #21
def transforms_imagenet_eval(img_size=224,
                             crop_pct=0.9,
                             interpolation=Image.BICUBIC,
                             mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)):
    crop_pct = crop_pct or 0.875

    if isinstance(img_size, tuple):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))
    return transforms.Compose([
        Resize(scale_size, interpolation),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.ImageNormalize(mean=mean, std=std)
    ])
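A worked example of the scale_size arithmetic above:

import math

# the classic ImageNet eval setup: resize to 256, center-crop 224
assert int(math.floor(224 / 0.875)) == 256
# with this function's default crop_pct=0.9 the resize target is 248
assert int(math.floor(224 / 0.9)) == 248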
Example #22
def test_dataset():
    root = '/home/gmh/dataset/TsinghuaDog'
    part = 'train'
    # from torchvision import transforms
    rgb_mean = [0.5, 0.5, 0.5]
    rgb_std = [0.5, 0.5, 0.5]

    transform_val = transform.Compose([
        transform.Resize((299, 299)),
        transform.ToTensor(),
        transform.ImageNormalize(rgb_mean, rgb_std),
    ])

    dataloader = TsinghuaDog(root,
                             batch_size=16,
                             train=False,
                             part=part,
                             shuffle=True,
                             transform=transform_val)
    # def __init__(self, root_dir, batch_size, part='train', train=True, shuffle=False, transform=None, num_workers=1):

    for images, labels in dataloader:
        # print(images.size(),labels.size(),labels)
        pass
Example #23
    nsteps = max_step - init_step + 1

    lr = 1e-3
    mixing = True

    code_size = 512
    batch_size = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    batch_default = 32

    phase = 150_000
    max_iter = 100_000

    transform = transform.Compose([
        transform.ToPILImage(),
        transform.RandomHorizontalFlip(),
        transform.ToTensor(),
        transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    netG = StyledGenerator(code_dim=code_size)
    netD = Discriminator(from_rgb_activate=True)
    g_running = StyledGenerator(code_size)
    g_running.eval()

    d_optimizer = jt.optim.Adam(netD.parameters(), lr=lr, betas=(0.0, 0.99))
    g_optimizer = jt.optim.Adam(netG.generator.parameters(),
                                lr=lr,
                                betas=(0.0, 0.99))
    g_optimizer.add_param_group({
        'params': netG.style.parameters(),
        'lr': lr * 0.01,
    })
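The batch_size dict above maps training resolution to a per-step batch size for progressive growing; a sketch of the intended lookup (the current resolution is hypothetical here):

resolution = 8                                  # hypothetical current phase resolution
bs = batch_size.get(resolution, batch_default)  # 256 at 8x8; unknown sizes fall back to 32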
Example #24
## Split
train_list, valid_list = train_test_split(train_list,
                                          test_size=0.2,
                                          stratify=labels,
                                          random_state=42)

print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

## Image Augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCropAndResize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


## Load Datasets
class CatsDogsDataset(Dataset):
    def __init__(self,
                 file_list,
                 transform=None,
                 batch_size=1,
                 shuffle=False,
                 num_workers=0):
        super(CatsDogsDataset, self).__init__(batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers)
        self.file_list = file_list
Example #25
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--num_classes', type=int, default=130)

    parser.add_argument('--lr', type=float, default=2e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-5)

    parser.add_argument('--resume', action='store_true', default=True)
    parser.add_argument('--eval', action='store_true')

    parser.add_argument('--dataroot', type=str, default='/content/drive/MyDrive/dogfl/data/TEST_A/')
    parser.add_argument('--model_path', type=str, default='model/res152_8/model.bin:model/res152_10/model.bin:model/seres152_8/model.bin:model/seres152_10/model.bin')
    parser.add_argument('--model_name', type=str, default='Net1_z:Net1_z:Net10_z:Net10_z')

    parser.add_argument('--out_file', type=str, default='./result.json')


    args = parser.parse_args()

    root_dir = args.dataroot

    transform_test = transform.Compose([
        transform.Resize((256, 256)),
        transform.CenterCrop(224),
        transform.ToTensor(),
        transform.ImageNormalize(0.485, 0.229),
    ])

    model_zoo = args.model_path.split(':')
    model_name_class = args.model_name.split(':')
    print(model_zoo, model_name_class)
    zoo_pred_result = []
    for i, model_path_ in enumerate(model_zoo):
        name_list = []
        for _, _, _name_list in os.walk(root_dir):
            name_list = _name_list
        val_loader = TsinghuaDogExam(root_dir, batch_size=args.batch_size, train=False, name_list=name_list, shuffle=False, transform=transform_test)

        model = eval(model_name_class[i])(num_classes=args.num_classes)  # instantiate the class named in --model_name
        if args.resume:
            model.load(model_path_)


        top5_class_list = evaluate(model, val_loader)
        # label start from 1, however it doesn't
        pred_result = dict(zip(name_list, top5_class_list))
        zoo_pred_result.append(pred_result)

    # vote the best
    president = zoo_pred_result[0]
    vote_result = {}
    for key in president.keys():
        val_list = []
        for i in range(5):
            candidates = [pred[key][i] for pred in zoo_pred_result]
            val, _ = Counter(candidates).most_common(1)[0]
            val_list.append(val)
        vote_result[key] = val_list



    with open(args.out_file, 'w') as fout:
        json.dump(vote_result, fout, ensure_ascii=False, indent=4)
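A toy run of the same per-rank Counter vote on made-up top-5 predictions from three models:

from collections import Counter

preds = [
    {'a.jpg': [3, 7, 1, 9, 2]},   # model 1
    {'a.jpg': [3, 5, 1, 9, 4]},   # model 2
    {'a.jpg': [8, 7, 1, 9, 2]},   # model 3
]
voted = {
    key: [Counter([p[key][i] for p in preds]).most_common(1)[0][0] for i in range(5)]
    for key in preds[0]
}
print(voted)   # {'a.jpg': [3, 7, 1, 9, 2]}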