コード例 #1
0
def set_model(input_dims_, model_type_, classes_):
    """Build a classifier of the requested architecture and a run name for it.

    Args:
        input_dims_: Spatial input size as a tuple (e.g. ``(224, 224)``); a
            trailing channel dimension of 3 is appended to form the input
            shape.
        model_type_: One of ``'ResNet50'``, ``'ResNet50_seq'``,
            ``'ResNet34'``, ``'ResNetXX'`` or ``'DG'``.
        classes_: Number of output classes, forwarded to the constructor.

    Returns:
        ``(model_, model_name_)`` where ``model_name_`` joins the model's
        ``name``, the ``'x'``-joined input dims and a random Faker person
        name (spaces replaced by underscores), all separated by ``'_'``.

    Raises:
        ValueError: If ``model_type_`` is not a supported architecture.
    """
    input_shape = input_dims_ + (3, )
    if model_type_ == 'ResNet50':
        model_ = ResNet50(input_shape=input_shape,
                          classes=classes_,
                          name='ResNet50')
    elif model_type_ == 'ResNet50_seq':
        model_ = ResNet50_seq(input_shape=input_shape,
                              classes=classes_,
                              name='ResNet50_seq')
    elif model_type_ == 'ResNet34':
        model_ = ResNet34(input_shape=input_shape,
                          classes=classes_,
                          name='ResNet34')
    elif model_type_ == 'ResNetXX':
        model_ = ResNetXX(input_shape=input_shape,
                          classes=classes_,
                          name='ResNetXX')
    elif model_type_ == 'DG':
        model_ = DG(input_shape=input_shape, classes=classes_, name='DG')
    else:
        # Fail fast with a clear message instead of reading an unbound
        # ``model_`` below.
        raise ValueError(f'Unsupported model type: {model_type_!r}')

    # The random person name makes repeated runs with the same settings
    # distinguishable.
    model_name_ = '_'.join([
        model_.name, 'x'.join(map(str, input_dims_)),
        '_'.join(Faker().name().split(' '))
    ])
    return model_, model_name_
コード例 #2
0
    def configure_model(self):
        """Instantiate the network selected by the hyper-parameters.

        Reads ``arch``, ``batch_norm``, ``dataset``, ``hidden_layers`` and
        ``hidden_size`` from ``self.hparams``.

        :return: the constructed model instance.
        :raises ValueError: if ``arch`` is unknown, or if ``arch == 'mlp'``
            and ``dataset`` is neither ``'mnist'`` nor ``'cifar10'``.
        """
        hp = self.hparams
        arch: str = hp.arch
        batch_norm = hp.batch_norm
        dataset: str = hp.dataset
        hidden_layers: int = hp.hidden_layers
        hidden_size: int = hp.hidden_size

        if arch == 'mlp':
            mlp_kwargs = dict(hidden_size=hidden_size,
                              num_hidden_layers=hidden_layers,
                              batch_norm=batch_norm)
            if dataset == 'mnist':
                # 784 = 28 * 28 flattened MNIST pixels.
                return MLP(input_size=784, **mlp_kwargs)
            if dataset == 'cifar10':
                # Relies on MLP's default input_size.
                return MLP(**mlp_kwargs)
            raise ValueError('invalid dataset specification!')

        # Constructors are wrapped in lambdas so only the selected one runs.
        builders = {
            'alexnet': lambda: AlexNet(),
            'vgg11': lambda: VGG(vgg_name='VGG11'),
            'vgg13': lambda: VGG(vgg_name='VGG13'),
            'resnet18': lambda: ResNet18(),
            'resnet34': lambda: ResNet34(),
        }
        builder = builders.get(arch)
        if builder is None:
            raise ValueError('Unsupported model!')
        return builder()
コード例 #3
0
def get_model(name: str):
    """Return a freshly constructed network for a case-insensitive *name*.

    Supported names: ``vgg11``, ``resnet18``, ``resnet34``, ``resnet50``.

    Raises:
        ValueError: if *name* is not one of the supported architectures
            (instead of silently returning ``None``).
    """
    key = name.lower()
    if key == 'vgg11':
        return VGG('VGG11')
    if key == 'resnet18':
        return ResNet18()
    if key == 'resnet34':
        return ResNet34()
    if key == 'resnet50':
        return ResNet50()
    raise ValueError(f'Unsupported model name: {name!r}')
コード例 #4
0
def main(credentials: Dict, engine_params: Dict) -> Optional:
    """Run ResNet34 inference through ``BengaliEngine`` and return the result.

    Args:
        credentials: Forwarded as ``creds`` to ``run_inference_engine``,
            which is called with ``load_from_s3=True``.
        engine_params: Engine configuration; must contain the keys
            ``'model_dir'`` and ``'submission_dir'``. It is also passed whole
            to ``BengaliEngine`` as ``params``.

    Returns:
        Whatever ``BengaliEngine.run_inference_engine`` returns (the
        submission); the CSV is also written to ``submission_dir``
        because ``to_csv=True``.
    """
    model = ResNet34(pretrained=False)
    trainer = BengaliTrainer(model=model, model_name='resnet34')
    bengali = BengaliEngine(trainer=trainer, params=engine_params)
    submission = bengali.run_inference_engine(
        model_dir=engine_params['model_dir'],
        to_csv=True,
        output_dir=engine_params['submission_dir'],
        load_from_s3=True,
        creds=credentials)
    # Hand the result back to the caller, as the return annotation implies.
    return submission
コード例 #5
0
    def __init__(self, path_wieght: str, path_data: str, similarity,
                 path_feat: str):
        """Load a trained Siamese network and build a ranker on top of it.

        :param path_wieght: path to the ``SiameseNetwork`` state dict. The
            parameter keeps its historical spelling so keyword callers keep
            working; the value is stored as ``self.path_weight``.
        :param path_data: data path forwarded to ``Ranker``.
        :param similarity: similarity function forwarded to ``Ranker`` as
            ``similarity_fn``.
        :param path_feat: path of the precomputed image-feature store opened
            with ``ImageFlickrFeatures`` (e.g. ``dbs/features_contrastive.db``).
        """
        self.path_weight = path_wieght
        self.path_data = path_data
        self.similarity = similarity
        self.flickr_dataset = ImageFlickrFeatures(path_feat)

        # Two independent ResNet34 backbones, one per Siamese branch.
        imagenet_net = ResNet34()
        sketches_net = ResNet34()

        siamese_net = SiameseNetwork(sketches_net, imagenet_net)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict copies the tensors into the model's own
        # parameters regardless of where they were loaded.
        siamese_net.load_state_dict(
            torch.load(self.path_weight, map_location='cpu'))
        self.net = siamese_net
        self.ranking = Ranker(self.path_data,
                              image_dataset_features=self.flickr_dataset,
                              feature_extractor=self.net,
                              similarity_fn=self.similarity)
コード例 #6
0
def test_attributes():
    """BengaliTrainer exposes the expected training and persistence API."""
    trainer = trainers.BengaliTrainer(model=ResNet34(pretrained=False))
    expected_attrs = (
        "train",
        "_loss_fn",
        "criterion",
        "optimizer",
        "scheduler",
        "early_stopping",
        "model",
        "load_model_locally",
        "load_model_from_s3",
        "save_model_locally",
        "save_model_to_s3",
    )
    for attr_name in expected_attrs:
        assert hasattr(trainer, attr_name)
    assert isinstance(trainer, BaseTrainer)
コード例 #7
0
def _get_specified_model(model):
    """Instantiate the architecture named by *model*.

    Recognised names are ``ResNet18``, ``ResNet34``, ``PreActResNet18``,
    ``PreActResNet34``, ``WideResNet28`` and ``WideResNet34``; any other
    value falls back to ``ResNet18()``.
    """
    # (name, constructor) pairs; lambdas defer name lookup until selected.
    candidates = (
        ('ResNet18', lambda: ResNet18()),
        ('ResNet34', lambda: ResNet34()),
        ('PreActResNet18', lambda: PreActResNet18()),
        ('PreActResNet34', lambda: PreActResNet34()),
        ('WideResNet28', lambda: WideResNet28()),
        ('WideResNet34', lambda: WideResNet34()),
    )
    for known_name, build in candidates:
        if model == known_name:
            return build()
    return ResNet18()
コード例 #8
0
def get_model_for_attack(model_name):
    """Build (and load weights for) the model identified by *model_name*.

    * ``model1`` … ``model6`` load local checkpoints from
      ``./models/weights`` (or ``models/weights``).
    * ``model_vgg16bn``, ``model_resnet18_imgnet``, ``model_inception`` and
      ``model_vitb`` construct models with ``pretrained=True``.
    * ``model_hub:<repo>:<entry>`` calls ``torch.hub.load(repo, entry,
      pretrained=True)`` and wraps the result in ``Cifar10Renormalize``;
      ``<repo>`` may itself carry a ref, e.g. ``owner/name:branch``.
    * ``model_mnist:<key>`` returns ``torch.load('mnist.pt')[key]``.
    * ``model_ex:<path>`` returns the object saved at ``<path>``.

    Raises:
        ValueError: if *model_name* matches none of the forms above.
    """
    if model_name == 'model1':
        model = ResNet34()
        load_w(model, "./models/weights/resnet34.pt")
    elif model_name == 'model2':
        model = ResNet18()
        load_w(model, "./models/weights/resnet18_AT.pt")
    elif model_name == 'model3':
        model = SmallResNet()
        load_w(model, "./models/weights/res_small.pth")
    elif model_name == 'model4':
        model = WideResNet34()
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load("./models/weights/trades_wide_resnet.pt",
                           map_location=pref.device)))
    elif model_name == 'model5':
        model = WideResNet()
        load_w(model, "./models/weights/wideres34-10-pgdHE.pt")
    elif model_name == 'model6':
        model = WideResNet28()
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load('models/weights/RST-AWP_cifar10_linf_wrn28-10.pt',
                           map_location=pref.device)))
    elif model_name == 'model_vgg16bn':
        model = vgg16_bn(pretrained=True)
    elif model_name == 'model_resnet18_imgnet':
        model = resnet18(pretrained=True)
    elif model_name == 'model_inception':
        model = inception_v3(pretrained=True)
    elif model_name == 'model_vitb':
        from mnist_vit import ViT, MegaSizer
        model = MegaSizer(
            ImageNetRenormalize(ViT('B_16_imagenet1k', pretrained=True)))
    elif model_name.startswith('model_hub:'):
        # Strip the prefix once, then take the entry point from the right so
        # a ref-qualified repo ("owner/name:ref") stays intact.
        repo, entry = model_name.split(':', 1)[1].rsplit(':', 1)
        model = torch.hub.load(repo, entry, pretrained=True)
        model = Cifar10Renormalize(model)
    elif model_name.startswith('model_mnist:'):
        key = model_name.split(':', 1)[1]
        model = torch.load('mnist.pt')[key]
    elif model_name.startswith('model_ex:'):
        # maxsplit=1 keeps paths that contain ':' (e.g. Windows drive letters).
        # SECURITY: torch.load unpickles arbitrary objects; only pass
        # trusted files here.
        model = torch.load(model_name.split(':', 1)[1])
    else:
        raise ValueError(f'Unknown model name for attack: {model_name!r}')
    return model
コード例 #9
0
def get_model_for_defense(model_name):
    """Build the defense model whose architecture is encoded in *model_name*.

    The architecture is chosen by substring:

    * ``WRN28`` selects WideResNet28.
    * ``PreActRN18`` / ``PreActRN34`` select PreActResNet18 / PreActResNet34.
    * ``RN18`` / ``RN34`` select ResNet18 / ResNet34.

    The ``PreAct`` tags are tested before the plain ``RN`` tags because
    ``'RN18'`` is also a substring of ``'PreActRN18'``. Weights are read from
    the ``'model'`` entry of ``./models/weights/{model_name}.pt``.

    Raises:
        ValueError: if no known architecture tag appears in *model_name*.
    """
    if 'WRN28' in model_name:
        model = WideResNet28()
    elif 'PreActRN18' in model_name:
        model = PreActResNet18()
    elif 'PreActRN34' in model_name:
        model = PreActResNet34()
    elif 'RN18' in model_name:
        model = ResNet18()
    elif 'RN34' in model_name:
        model = ResNet34()
    else:
        raise ValueError(
            f'Unsupported model name: {model_name}. Check your spelling!')
    # Load onto the model's own device so checkpoints saved on a GPU also
    # load on CPU-only machines.
    device = next(model.parameters()).device
    checkpoint = torch.load(f"./models/weights/{model_name}.pt",
                            map_location=device)
    model.load_state_dict(checkpoint['model'])

    return model
コード例 #10
0
ファイル: model.py プロジェクト: jiangyangzhou/defense_base
def get_model_for_attack(model_name):
    """Build one of the fixed attack-target models and load its weights.

    Supported names are ``model1`` … ``model6``; checkpoints are read from
    ``models/weights``. ``model4`` and ``model6`` pass their state dicts
    through ``filter_state_dict`` before loading.

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    # Checkpoints are loaded onto the CPU so GPU-saved files also work on
    # CPU-only hosts; load_state_dict copies them into the model's params.
    if model_name == 'model1':
        model = ResNet34()
        model.load_state_dict(
            torch.load("models/weights/resnet34.pt", map_location='cpu'))
    elif model_name == 'model2':
        model = ResNet18()
        model.load_state_dict(
            torch.load('models/weights/resnet18_AT.pt', map_location='cpu'))
    elif model_name == 'model3':
        model = SmallResNet()
        model.load_state_dict(
            torch.load('models/weights/res_small.pth', map_location='cpu'))
    elif model_name == 'model4':
        model = WideResNet34()
        model.load_state_dict(
            filter_state_dict(
                torch.load('models/weights/trades_wide_resnet.pt',
                           map_location='cpu')))
    elif model_name == 'model5':
        model = WideResNet()
        model.load_state_dict(
            torch.load('models/weights/wideres34-10-pgdHE.pt',
                       map_location='cpu'))
    elif model_name == 'model6':
        model = WideResNet28()
        model.load_state_dict(
            filter_state_dict(
                torch.load('models/weights/RST-AWP_cifar10_linf_wrn28-10.pt',
                           map_location='cpu')))
    else:
        raise ValueError(f'Unknown model name for attack: {model_name!r}')
    return model
コード例 #11
0
ファイル: main_resnets.py プロジェクト: n00blet/Models
def main():
    """Train a ResNet34 face-landmark model on LFPW from CLI arguments.

    The steps are:

    1. Parse ``sys.argv`` and build and compile the model.
    2. Create the results folder under ``MODELS_FOLDER/face_landmarks``.
    3. Train with CSV logging, early stopping and best-checkpoint saving.
    4. Report the validation RMSE.

    Exits with status 1 if the model folder cannot be created.
    """
    args = parse_args(args=sys.argv[1:])

    input_shape, n_epochs, n_top, units, pool, l2, patience = (
        args.input_shape, args.n_epochs, args.n_dense, args.units, args.pool,
        args.l2_reg, args.patience)

    print('Model training with parameters:')
    pprint(vars(args))

    model = ResNet34(input_shape=input_shape)
    # Use the pooling mode requested on the command line.
    model.create(pool=pool, n_top=n_top, units=units, l2_reg=l2)
    model.compile(optimizer=args.optimizer)

    ok = model.create_model_folder(root=join(MODELS_FOLDER, 'face_landmarks'),
                                   subfolder=args.identifier)

    # A falsy result means the results folder could not be set up.
    if not ok:
        sys.exit(1)

    model.save_parameters(model.parameters_path)

    callbacks = [
        CSVLogger(filename=model.history_path),
        EarlyStopping(patience=patience, verbose=1),
        ModelCheckpoint(filepath=model.weights_path,
                        save_best_only=True,
                        save_weights_only=False)
    ]

    model.train(n_epochs=n_epochs,
                train_folder=LFPW_TRAIN,
                valid_folder=LFPW_VALID,
                callbacks=callbacks)

    avg_rmse = model.score(LFPW_VALID)
    print(f'Trained model validation RMSE: {avg_rmse:2.4f}')
    print(f'The folder with results: {model.subfolder}')
    print(f'Training history file: {model.history_path}')
コード例 #12
0
    train_flickr = FlickrDataset(args.flickr,
                                 class_mapping=train_sketches.class_mapping)
    dataset = ContrastiveDataset(flickr_dataset=train_flickr,
                                 sketches_dataset=train_sketches,
                                 n_similar=args.n_similar,
                                 m_different=args.m_different)

    n_val = int(len(dataset) * args.val_size)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=args.batch_size)
    val_loader = DataLoader(val, batch_size=args.batch_size // 4)

    # load backbones
    print("[*] Initializing weights...")
    imagenet_net = ResNet34()
    sketches_net = ResNet34()
    # sketches_net.load_state_dict(torch.load(args.sketches_backbone_weights))
    print("[+] Weights loaded")

    print("[*] Initializing model, loss and optimizer")
    contrastive_net = SiameseNetwork(sketches_net, imagenet_net)
    contrastive_net.to(args.device)
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(contrastive_net.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    else:
        optimizer = torch.optim.Adam(contrastive_net.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=args.t_0)
コード例 #13
0
from models import ResNet34
from config import DefaultConfig
from data.dataset import DogCat
from torch.utils.data import Dataset, DataLoader
import torch as t
from torch.autograd import Variable
import matplotlib.pyplot as plt
#from .models.ResNet34 import DogCat

opt = DefaultConfig()
print('----opt=', opt)
lr = opt.lr
print('---opt.train_data_root=', opt.train_data_root)

# step1: model
net = ResNet34.ResNet()

# step2: train/validation splits drawn from the same data root, built in
# that order, and their loaders
train_dataset, val_dataset = (DogCat(opt.train_data_root, train=is_train)
                              for is_train in (True, False))

train_dataloader = DataLoader(train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.num_workers)
val_dataloader = DataLoader(val_dataset,
                            batch_size=4,
                            shuffle=True,
                            num_workers=opt.num_workers)

# step3: target function and optimizer
コード例 #14
0
ファイル: train.py プロジェクト: pomelyu/cat_vs_dog
        if opt.use_gpu:
            data = data.cuda()
            label = label.cuda()

        prediction = model(data)

        confusion_matrix.add(prediction.detach().squeeze() >= 0.5,
                             label.long())

    cm_value = confusion_matrix.value()
    accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / cm_value.sum()

    return confusion_matrix, accuracy


if __name__ == "__main__":
    # Script entry point: parse CLI options, build train/validation loaders
    # from the same data path, create the ResNet34 classifier (optionally
    # restoring saved weights and moving it to the GPU) and train it.
    # NOTE: opt, model and the loaders are module-level globals here.
    opt = TrainOptions().parse()

    train_dataloader = dataloader.create_train_dataloader(
        opt.train_data_path, opt.batch_size)
    valid_dataloader = dataloader.create_valid_dataloader(
        opt.train_data_path, opt.batch_size)

    model = ResNet34()
    # A falsy load_model_path skips restoring weights.
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model = model.cuda()

    train(model, train_dataloader, valid_dataloader, opt)
                                 num_workers=4,
                                 shuffle=False,
                                 drop_last=True,
                                 pin_memory=True)
    src_data_loader_eval = DataLoader(dataset_blur_val,
                                      batch_size=params.batch_size,
                                      num_workers=4,
                                      pin_memory=True)

    # src_data_loader = get_data_loader(params.src_dataset)
    # src_data_loader_eval = get_data_loader(params.src_dataset, train=False)
    # tgt_data_loader = get_data_loader(params.tgt_dataset)
    # tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False)

    # load models
    src_encoder = init_model(net=ResNet34(),
                             restore=params.src_encoder_restore)
    src_classifier = init_model(net=ResNetClassifier(),
                                restore=params.src_classifier_restore)
    blur_src_encoder = init_model(net=ResNet34(),
                                  restore=params.blur_src_encoder_restore)
    blur_src_classifier = init_model(
        net=ResNetClassifier(), restore=params.blur_src_classifier_restore)
    tgt_encoder = init_model(net=ResNet34(),
                             restore=params.tgt_encoder_restore)
    critic = init_model(Discriminator(input_dims=params.d_input_dims,
                                      hidden_dims=params.d_hidden_dims,
                                      output_dims=params.d_output_dims),
                        restore=params.d_model_restore)

    # train source model
コード例 #16
0
 def build_model(self):
     """Create ``self.model`` as a ResNet34 attribute model.

     ``attr_dim`` is the number of configured attributes, or 40 when
     ``self.config.attrs`` is empty (presumably the full CelebA attribute
     set — confirm).
     """
     n_attrs = len(self.config.attrs)
     self.model = ResNet34(attr_dim=n_attrs or 40)
コード例 #17
0
ファイル: main.py プロジェクト: xuyanwu/Complementary-GAN
def train_f_data_g(opt, Q):
    """Train the classifier ``netd`` alongside a pre-trained generator ``netg``.

    Loads generator checkpoint ``g_{opt.iter:03d}.pth``. Then, for each of
    ``opt.num_epoches`` epochs, it:

    * rebuilds the complementary-label ``CIFAR10_Complementary`` training
      set;
    * runs one epoch of ``train_data_g``;
    * records the test accuracy.

    The SGD learning rate is divided by 10 every ``int(num_epoches / 3)``
    epochs. Accuracies are written, one per line, to ``Nacc_f_train.txt``.

    Args:
        opt: parsed options (paths, dataset, network sizes, schedule).
        Q: forwarded unchanged to ``train_data_g``; presumably the
            complementary-label transition matrix from ``generate_c_data``
            — confirm.

    Raises:
        ValueError: if ``opt.data_r`` is neither MNIST nor CIFAR10 and
            ``opt.image_size`` is not 32, 64 or 128.
    """
    comp_dir = str(opt.p1 * 100) + '%complementary/'
    run_prefix = comp_dir + str(opt.p1)
    os.makedirs(os.path.join(opt.savingroot, opt.dataset,
                             run_prefix + '_chkpts_fake_data'),
                exist_ok=True)

    if opt.data_r == 'MNIST':
        netd = D_MNIST(opt.ndf, opt.nc, num_classes=opt.num_class).cuda()
        netg = G_MNIST(opt.nz, opt.ngf, opt.nc).cuda()
    elif opt.data_r == 'CIFAR10':
        netd = ResNet18(opt.num_class).cuda()
        netg = Generator32(n_class=opt.num_class,
                           size=opt.image_size,
                           SN=True,
                           code_dim=opt.nz).cuda()
    else:
        netg = Generator(n_class=opt.num_class,
                         size=opt.image_size,
                         SN=True,
                         code_dim=opt.nz).cuda()
        if opt.image_size == 32:
            netd = ResNet34(opt.num_class).cuda()
        elif opt.image_size == 64:
            netd = ResNet34_64(opt.num_class).cuda()
        elif opt.image_size == 128:
            netd = ResNet34_128(opt.num_class).cuda()
        else:
            raise ValueError(f'Unsupported image_size {opt.image_size!r}; '
                             'expected 32, 64 or 128')

    g_ckpt = os.path.join(opt.savingroot, opt.dataset,
                          run_prefix + f'_chkpts/g_{opt.iter:03d}.pth')
    print(g_ckpt)
    netg.load_state_dict(torch.load(g_ckpt))
    netg.eval()

    netd = nn.DataParallel(netd)
    netg = nn.DataParallel(netg)

    optd = optim.SGD(netd.module.parameters(),
                     lr=opt.lr,
                     momentum=0.9,
                     weight_decay=opt.wd)

    print('training_start')
    step = 0
    acc = []

    test_loader = torch.utils.data.DataLoader(CIFAR10_Complementary(
        os.path.join(opt.savingroot, opt.data_r, 'data'),
        train=False,
        transform=transforms.Compose([
            transforms.Resize(opt.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])),
                                              batch_size=128,
                                              num_workers=2)

    # NOTE(review): result unused here; the call is kept because it may
    # have side effects (e.g. consuming RNG state) — confirm before removing.
    fixed = embed_z(opt)

    # Decay the LR by 10x at each third of training. With fewer than three
    # epochs the period is 0, so decay is skipped rather than taking a
    # modulo by zero.
    decay_every = int(opt.num_epoches / 3)

    for epoch in range(opt.num_epoches):
        dataset = CIFAR10_Complementary(os.path.join(opt.savingroot,
                                                     opt.data_r, 'data'),
                                        transform=tsfm,
                                        p1=opt.p1,
                                        p2=opt.p2)

        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=2,
                                             pin_memory=True,
                                             worker_init_fn=np.random.seed)
        print(f'Epoch {epoch:03d}.')
        if decay_every and epoch % decay_every == 0 and epoch != 0:
            for param_group in optd.param_groups:
                param_group['lr'] = param_group['lr'] / 10
                print(param_group['lr'])
        step = train_data_g(netd, netg, optd, epoch, step, opt, loader, Q)

        acc.append(test_acc(netd, test_loader))

    acc_path = os.path.join(opt.savingroot, opt.dataset,
                            comp_dir + 'Nacc_f_train.txt')
    with open(acc_path, 'w') as f:
        for cont in acc:
            f.write(str(cont) + '\n')
コード例 #18
0
ファイル: main.py プロジェクト: xuyanwu/Complementary-GAN
def train_gan(opt):
    """Train the conditional GAN against a restored complementary-label classifier.

    The steps are:

    1. Create the ``_images`` and ``_chkpts`` output directories.
    2. Generate complementary-label data with ``Q = generate_c_data(opt)``.
    3. Build the generator ``netg``, its discriminator ``netd_g`` and the
       classifier ``netd_c`` for ``opt.data_r`` / ``opt.image_size``.
    4. Restore ``netd_c`` from checkpoint ``d_{num_epoches-1:03d}.pth``.
    5. Run ``train_g`` on the image dataset selected by ``opt.data_r``.

    Returns:
        ``Q`` as returned by ``generate_c_data``.

    Raises:
        ValueError: if ``opt.data_r`` is neither MNIST nor CIFAR10 and
            ``opt.image_size`` is not 32, 64 or 128.
    """
    comp_dir = str(opt.p1 * 100) + '%complementary/'
    run_prefix = comp_dir + str(opt.p1)
    os.makedirs(os.path.join(opt.savingroot, opt.dataset,
                             run_prefix + '_images'),
                exist_ok=True)
    os.makedirs(os.path.join(opt.savingroot, opt.dataset,
                             run_prefix + '_chkpts'),
                exist_ok=True)

    Q = generate_c_data(opt)

    # Build networks: netg (generator), netd_g (GAN discriminator) and
    # netd_c (classifier).
    if opt.data_r == 'MNIST':
        netd_g = D_MNIST(opt.ndf, opt.nc, num_classes=opt.num_class).cuda()
        netd_c = D_MNIST(opt.ndf, opt.nc, num_classes=opt.num_class).cuda()
        netg = G_MNIST(opt.nz, opt.ngf, opt.nc).cuda()
    elif opt.data_r == 'CIFAR10':
        netd_c = ResNet18(opt.num_class).cuda()
        netd_g = Discriminator32(n_class=opt.num_class,
                                 size=opt.image_size,
                                 SN=True).cuda()
        netg = Generator32(n_class=opt.num_class,
                           size=opt.image_size,
                           SN=True,
                           code_dim=opt.nz).cuda()
    else:
        netd_g = Discriminator(n_class=opt.num_class,
                               size=opt.image_size,
                               SN=True).cuda()
        netg = Generator(n_class=opt.num_class,
                         size=opt.image_size,
                         SN=True,
                         code_dim=opt.nz).cuda()

        if opt.image_size == 32:
            netd_c = ResNet34(opt.num_class).cuda()
        elif opt.image_size == 64:
            netd_c = ResNet34_64(opt.num_class).cuda()
        elif opt.image_size == 128:
            netd_c = ResNet34_128(opt.num_class).cuda()
        else:
            raise ValueError(f'Unsupported image_size {opt.image_size!r}; '
                             'expected 32, 64 or 128')

    optd_g = optim.Adam(netd_g.parameters(), lr=0.0004, betas=(0.0, 0.9))
    # optd_c, test_loader and loader are only needed by the (currently
    # disabled) classifier-training loop noted below.
    optd_c = optim.SGD(netd_c.parameters(),
                       lr=opt.lr,
                       momentum=0.9,
                       weight_decay=opt.wd)
    optg = optim.Adam(netg.parameters(), lr=0.0001, betas=(0.0, 0.9))

    print('training_start')

    test_loader = torch.utils.data.DataLoader(CIFAR10_Complementary(
        os.path.join(opt.savingroot, opt.data_r, 'data'),
        train=False,
        transform=transforms.Compose([
            transforms.Resize(opt.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])),
                                              batch_size=128,
                                              num_workers=2)

    dataset = CIFAR10_Complementary(os.path.join(opt.savingroot, opt.data_r,
                                                 'data'),
                                    transform=tsfm,
                                    p1=opt.p1,
                                    p2=opt.p2)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=2,
                                         worker_init_fn=np.random.seed)

    # The per-epoch classifier-training loop (train_c with a 10x LR decay at
    # each third of training, logging accuracies to acc.txt) is disabled;
    # netd_c is restored from its final-epoch checkpoint instead.
    netd_c.load_state_dict(
        torch.load(
            os.path.join(
                opt.savingroot, opt.dataset,
                run_prefix + f'_chkpts/d_{(opt.num_epoches-1):03d}.pth')))

    # Image dataset for GAN training. For any other data_r, the
    # complementary-label ``dataset`` built above is used as-is.
    if opt.data_r == 'MNIST':
        dataset = dset.MNIST(root=opt.dataroot, download=True, transform=tsfm)
    elif opt.data_r == 'CIFAR10':
        dataset = dset.CIFAR10(root=opt.dataroot,
                               download=True,
                               transform=tsfm)
    elif opt.data_r in ('IMAGENET100', 'VGG-FACE'):
        dataset = TINY_IMAGENET_Complementary_g(os.path.join(
            opt.savingroot, opt.data_r, 'data'),
                                                transform=tsfm,
                                                p1=opt.p1,
                                                p2=opt.p2)

    train_g(netd_g, netd_c.eval(), netg, optd_g, optg, dataset, opt)

    return Q
コード例 #19
0
from models import ResNet34
from PIL import Image
import torch as t
from torch.autograd import Variable
import matplotlib.pyplot as plt

model = ResNet34.ResNet()
# map_location='cpu' lets a GPU-trained checkpoint load on CPU-only hosts;
# this script never moves the model to a GPU.
model.load_state_dict(t.load('params99.pth', map_location='cpu'))
model.eval()

# Single test image used by both the PIL and matplotlib reads below.
TEST_IMG_PATH = '/home/cat_and_dog/data/test_imgs/s0009_cat.jpg'

img_data = Image.open(TEST_IMG_PATH)
img_data.save('test.png')
print('--img_data=', img_data)

# Import the submodule explicitly: a bare ``import matplotlib`` does not
# guarantee that ``matplotlib.image`` has been loaded.
import matplotlib.image
im = matplotlib.image.imread(TEST_IMG_PATH)
print('---img=', im)