Beispiel #1
0
def main():
    """Train SE-ResNet50 on ImageNet, optionally with distributed training."""
    net = se_resnet50(num_classes=1000)

    # LR is scaled linearly with the global batch size (base 0.6 at batch 1024).
    sgd = optim.SGD(lr=0.6 / 1024 * args.batch_size,
                    momentum=0.9,
                    weight_decay=1e-4)
    multistep = lr_scheduler.MultiStepLR([50, 70])

    metric_cbs = [callbacks.AccuracyCallback(), callbacks.LossCallback()]
    tqdm_rep = reporter.TQDMReporter(range(args.epochs), callbacks=metric_cbs)
    tb_rep = reporter.TensorboardReporter(metric_cbs)
    all_cbs = callbacks.CallbackList(tqdm_rep, tb_rep,
                                     callbacks.WeightSave("checkpoints"))

    if not args.distributed:
        use_dp = torch.cuda.device_count() > 1
        if use_dp:
            print("\nuse DataParallel")
        trainer = SupervisedTrainer(net, sgd, F.cross_entropy,
                                    callbacks=all_cbs,
                                    scheduler=multistep,
                                    data_parallel=use_dp)
    else:
        # DistributedSupervisedTrainer sets up torch.distributed
        if args.local_rank == 0:
            print("\nuse DistributedDataParallel")
        trainer = DistributedSupervisedTrainer(net, sgd, F.cross_entropy,
                                               callbacks=all_cbs,
                                               scheduler=multistep,
                                               init_method=args.init_method,
                                               backend=args.backend)

    # if distributed, loaders must be built after DistributedSupervisedTrainer
    debug_cap = args.batch_size * 10 if args.debug else None
    train_loader, test_loader = imagenet_loaders(args.root, args.batch_size,
                                                 distributed=args.distributed,
                                                 num_train_samples=debug_cap,
                                                 num_test_samples=debug_cap)
    for _ in tqdm_rep:
        trainer.train(train_loader)
        trainer.test(test_loader)
Beispiel #2
0
def main():
    """ImageNet training entry point for SE-ResNet50 (context-manager trainer API)."""
    if args.distributed:
        init_distributed()

    net = se_resnet50(num_classes=1000)

    # linear LR scaling: 0.6 at a global batch size of 1024
    sgd = optim.SGD(lr=0.6 / 1024 * args.batch_size,
                    momentum=0.9, weight_decay=1e-4)
    multistep = lr_scheduler.MultiStepLR([50, 70])
    debug_cap = args.batch_size * 10 if args.debug else None
    train_loader, test_loader = imagenet_loaders(args.root, args.batch_size,
                                                 distributed=args.distributed,
                                                 num_train_samples=debug_cap,
                                                 num_test_samples=debug_cap)

    # keep a handle on the TQDM reporter: it drives the epoch loop below
    epoch_iter = reporters.TQDMReporter(range(args.epochs))
    cbs = [callbacks.AccuracyCallback(),
           callbacks.AccuracyCallback(k=5),
           callbacks.LossCallback(),
           callbacks.WeightSave('.'),
           reporters.TensorboardReporter('.'),
           epoch_iter]

    with SupervisedTrainer(net, sgd, F.cross_entropy,
                           callbacks=cbs,
                           scheduler=multistep,
                           ) as trainer:
        for _ in epoch_iter:
            trainer.train(train_loader)
            trainer.test(test_loader)
    def __init__(self,
                 arch='resnet18',
                 with_se=False,
                 with_rpp=False,
                 use_cbam=False,
                 class_num=395,
                 reduction=16):
        """Build the visible-modality branch around a configurable CNN backbone.

        Args:
            arch: backbone selector; one of 'resnet18', 'resnet50',
                'se_resnet50', 'pcb_rpp', 'pcb_pyramid', 'cbam', 'scpnet'.
                NOTE(review): any other value leaves ``model_ft`` unbound and
                the assignments at the bottom raise NameError.
            with_se: only stored on the instance here; the commented-out
                resnet50 call suggests it once selected an SE backbone.
            with_rpp: forwarded to ``resnet.pcb_rpp`` (RPP pooling toggle).
            use_cbam: forwarded to the pcb_rpp / resnet50 constructors.
            class_num: number of identity classes for the classifier heads.
            reduction: accepted but unused in this constructor — TODO confirm.
        """
        super(visible_net_resnet, self).__init__()
        if arch == 'resnet18':
            print("visible_net with resenet18 architecture setting.....")
            model_ft = models.resnet18(pretrained=True)
        elif arch == 'resnet50':
            print("visible_net with resenet50 architecture setting.....")
            # model_ft = resnet.resnet50(pretrained=True,with_se=with_se)
            model_ft = resnet.pcb_rpp(pretrained=True,
                                      with_rpp=with_rpp,
                                      use_cbam=use_cbam,
                                      class_num=class_num)
        #add by zc
        elif arch == 'se_resnet50':
            print("visible_net with se_resnet50 architecture setting.....")
            model_ft = se_resnet50(pretrained=True)

        elif arch == 'pcb_rpp':
            model_ft = resnet.pcb_rpp(pretrained=True,
                                      with_rpp=with_rpp,
                                      use_cbam=use_cbam,
                                      class_num=class_num)

        # NOTE(review): 'pcb_pyramid' currently builds the exact same model
        # as 'pcb_rpp' and 'resnet50' — presumably a placeholder; confirm.
        elif arch == 'pcb_pyramid':
            model_ft = resnet.pcb_rpp(pretrained=True,
                                      with_rpp=with_rpp,
                                      use_cbam=use_cbam,
                                      class_num=class_num)

        elif arch == 'cbam':
            model_ft = resnet.resnet50(pretrained=True, use_cbam=use_cbam)

        elif arch == 'scpnet':
            model_ft = get_scp_model(pretrained=True, nr_class=class_num)
        ##end by zc
        # avg pooling to global pooling
        # if arch == 'pcb_rpp':
        #     if with_rpp:
        #         model_ft.avgpool = resnet.RPP()
        #         print("-------RPP module in visible starting------")
        #     else:
        #         model_ft.avgpool = nn.AdaptiveAvgPool2d((6, 1))
        #         print("-------No RPP module in visible------")
        # else:
        #     model_ft.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.visible = model_ft
        # expose the backbone directly for feature extraction
        self.backbone = model_ft.backbone
        self.dropout = nn.Dropout(p=0.5)
        #add by zc
        self.with_se = with_se
        self.arch = arch
        # 2048-d backbone output -> 512-d embedding head
        self.feature = FeatureBlock(2048, 512, dropout=0.5)
        '''
Beispiel #4
0
def main():
    """Train SE-ResNet50 on ImageNet using the registry-based data loaders."""
    if is_distributed():
        init_distributed()

    net = se_resnet50(num_classes=1000)

    # LR scales linearly with batch size (0.6 per 1024 samples)
    sgd = optim.SGD(lr=0.6 / 1024 * args.batch_size,
                    momentum=0.9,
                    weight_decay=1e-4)
    multistep = lr_scheduler.MultiStepLR([50, 70])
    train_loader, test_loader = DATASET_REGISTRY("imagenet")(args.batch_size)

    # keep a handle on the TQDM reporter: it drives the epoch loop below
    epoch_iter = reporters.TQDMReporter(range(args.epochs))
    cbs = [
        callbacks.AccuracyCallback(),
        callbacks.AccuracyCallback(k=5),
        callbacks.LossCallback(),
        callbacks.WeightSave("."),
        reporters.TensorboardReporter("."),
        epoch_iter,
    ]

    with SupervisedTrainer(
            net,
            sgd,
            F.cross_entropy,
            callbacks=cbs,
            scheduler=multistep,
    ) as trainer:
        for _ in epoch_iter:
            trainer.train(train_loader)
            trainer.test(test_loader)
Beispiel #5
0
def main():
    """Fine-tune SE-ResNet50 (42 classes) on an ImageFolder dataset.

    One tenth of the dataset is used; 10% of that subset is held out for
    validation via SubsetRandomSampler over a shared shuffled index list.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    augment = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    plain = transforms.Compose(
        [transforms.Resize(224),
         transforms.ToTensor(), normalize])

    # same underlying folder, different transforms per split
    train_set = ImageFolder(args.data, augment)
    valid_set = ImageFolder(args.data, plain)

    # work on 1/10th of the data; carve 10% of that off for validation
    num_samples = int(len(train_set) / 10)
    indices = list(range(num_samples))
    split = int(np.floor(0.1 * num_samples))
    np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_loader = DataLoader(train_set,
                              args.batch_size,
                              sampler=SubsetRandomSampler(train_idx),
                              num_workers=4)
    valid_loader = DataLoader(valid_set,
                              args.batch_size,
                              sampler=SubsetRandomSampler(valid_idx),
                              num_workers=4)
    print("num data:", num_samples)
    print("num train batches:", len(train_loader))
    print("num test batches:", len(valid_loader))

    net = se_resnet50(num_classes=42)
    sgd = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=1e-4)
    step = lr_scheduler.StepLR(80, 0.1)
    tqdm_rep = reporters.TQDMReporter(range(args.epochs),
                                      callbacks.AccuracyCallback())
    with Trainer(net,
                 sgd,
                 F.cross_entropy,
                 scheduler=step,
                 callbacks=[tqdm_rep, callbacks.AccuracyCallback()]) as trainer:
        for _ in tqdm_rep:
            trainer.train(train_loader)
            trainer.test(valid_loader)
            # checkpoint after every epoch
            torch.save(trainer.model.state_dict(), "se_resnet50.pkl")
Beispiel #6
0
def main():
    """Score per-patient slice folders with a trained SE-ResNet50 and dump a CSV.

    Each row of result.csv is: ground-truth class index followed by the
    per-slice scores returned by ``test``.
    """
    global args
    conf = configparser.ConfigParser()
    args = parser.parse_args()

    conf.read(args.config)
    DATA_DIR = conf.get("subject_level", "data")
    LABEL_DIR = conf.get("subject_level", "label")
    create_dir_not_exist(LABEL_DIR)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # a trained checkpoint path is mandatory
    if not args.model:
        print("Usage: --model -m\n\tpath to the model")
        sys.exit()

    net = DataParallel(se_resnet50(num_classes=3))
    checkpoint = torch.load(args.model)
    net.load_state_dict(checkpoint['state_dict'])
    net = net.cuda()

    rows = []
    for class_index, class_name in enumerate(CLASSES_NAME):
        class_dir = os.path.join(DATA_DIR, class_name)
        # keep only sub-directories: one directory per patient
        patients = [
            os.path.join(class_dir, p) for p in os.listdir(class_dir)
            if os.path.isdir(os.path.join(class_dir, p))
        ]
        print('---------- {} ----------'.format(class_name))
        for i, patient_dir in enumerate(patients):
            slices = os.listdir(patient_dir)
            # skip patients with too few slices to score reliably
            if len(slices) < 10:
                continue
            checkSuffix(slices)
            # drop macOS resource-fork artifacts ("._" prefix)
            slices = [s for s in slices if s[:2] != "._"]
            slices = [os.path.join(patient_dir, s) for s in slices]
            scores = test(slices, class_index, net)
            print('{} ----- {}'.format(i, len(patients)))
            # first column is the ground-truth class index
            rows.append(list(np.insert(scores, 0, class_index)))

    with open(os.path.join(LABEL_DIR, 'result.csv'), 'w') as f:
        csv.writer(f).writerows(rows)
    def __init__(self,
                 arch='resnet18',
                 with_se=False,
                 with_rpp=False,
                 use_cbam=False,
                 class_num=395,
                 reduction=16):
        """Build the thermal-modality branch around a configurable CNN backbone.

        Args:
            arch: backbone selector; one of 'resnet18', 'resnet50',
                'se_resnet50', 'pcb_rpp', 'pcb_pyramid', 'cbam', 'scpnet'.
            with_se: stored on the instance; does not affect backbone
                construction here.
            with_rpp: forwarded to ``resnet.pcb_rpp`` (RPP pooling toggle).
            use_cbam: forwarded to the pcb_rpp / resnet50 constructors.
            class_num: number of identity classes for the classifier heads.
            reduction: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``arch`` is not a supported name (previously an
                unknown arch fell through and crashed with NameError on
                ``model_ft``).
        """
        super(thermal_net_resnet, self).__init__()
        if arch == 'resnet18':
            print("thermal_net with resenet18 architecture setting.....")
            model_ft = models.resnet18(pretrained=True)
        elif arch in ('resnet50', 'pcb_rpp', 'pcb_pyramid'):
            # these three names currently share the same PCB/RPP backbone
            # (the branches were byte-identical before being merged)
            if arch == 'resnet50':
                print("thermal_net with resenet50 architecture setting.....")
            model_ft = resnet.pcb_rpp(pretrained=True,
                                      with_rpp=with_rpp,
                                      use_cbam=use_cbam,
                                      class_num=class_num)
        elif arch == 'se_resnet50':
            print("thermal_net with se_resnet50 architecture setting.....")
            model_ft = se_resnet50(pretrained=True)
        elif arch == 'cbam':
            model_ft = resnet.resnet50(pretrained=True, use_cbam=use_cbam)
        elif arch == 'scpnet':
            model_ft = get_scp_model(pretrained=True, nr_class=class_num)
        else:
            raise ValueError("unsupported arch: {}".format(arch))

        self.thermal = model_ft
        # expose the backbone directly for feature extraction
        self.backbone = model_ft.backbone
        self.dropout = nn.Dropout(p=0.5)
        self.with_se = with_se
        self.arch = arch
        # 2048-d backbone features -> 512-d embedding head
        self.feature = FeatureBlock(2048, 512, dropout=0.5)
Beispiel #8
0
def main():
    """ImageNet training loop for SE-ResNet50 with DataParallel over all GPUs."""
    train_loader, test_loader = get_dataloader(args.batch_size, args.root)
    device_ids = list(range(torch.cuda.device_count()))
    net = nn.DataParallel(se_resnet50(num_classes=1000),
                          device_ids=device_ids)
    # linear LR scaling: 0.6 at a global batch size of 1024
    sgd = optim.SGD(lr=0.6 / 1024 * args.batch_size,
                    momentum=0.9, weight_decay=1e-4)
    step = lr_scheduler.StepLR(30, gamma=0.1)
    saver = callbacks.WeightSave("checkpoints")
    tqdm_rep = reporter.TQDMReporter(range(args.epochs),
                                     callbacks=[callbacks.AccuracyCallback()])

    trainer = Trainer(net, sgd, F.cross_entropy, scheduler=step,
                      callbacks=callbacks.CallbackList(saver, tqdm_rep))
    for _ in tqdm_rep:
        trainer.train(train_loader)
        trainer.test(test_loader)
Beispiel #9
0
	def __init__(self):
		"""Load a pretrained SE-ResNet50 in eval mode for feature/Grad-CAM use."""
		super(SE_Resnet_feat, self).__init__()
		self.model = se_resnet50()
		state = torch.load(path_pth)
		self.model.load_state_dict(state)
		print('SEResNet50 models loaded!')

		# Grad-CAM with the softmax weights: the second-to-last parameter
		# tensor is taken as the final classifier weight matrix.
		self.weight_softmax = np.squeeze(
			list(self.model.parameters())[-2].data.numpy())

		self.test = True
		self.model.eval()
		if use_gpu:
			self.model = self.model.cuda()
Beispiel #10
0
def main():
    """Evaluate a trained SE-ResNet50 on the test set and print a per-class
    classification report (Normal / CAP / COVID-19).

    Reads the test/log directories from the "senet" config section; exits
    early if no checkpoint path was given via --model.
    """
    global args
    conf = configparser.ConfigParser()
    args = parser.parse_args()

    conf.read(args.config)
    TEST_DIR = conf.get("senet", "test")
    LOG_DIR = conf.get("senet", "log")
    create_dir_not_exist(LOG_DIR)
    test_list = [os.path.join(TEST_DIR, item) for item in os.listdir(TEST_DIR)]
    test_list = checkSuffix(test_list)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # a trained checkpoint path is mandatory
    if not args.model:
        print("Usage: --model -m\n\tpath to the model")
        sys.exit()

    model = se_resnet50(num_classes=3)
    criterion = nn.CrossEntropyLoss().cuda()

    # loss is computed inside the DataParallel wrapper
    model = DataParallel_withLoss(model, criterion)

    trained_net = torch.load(args.model)
    model.load_state_dict(trained_net['state_dict'])
    model = model.cuda()

    # (removed: vote_pred/vote_score buffers were allocated but never used)
    targetlist, scorelist, predlist = test(test_list, model, criterion)

    report = classification_report(y_true=targetlist,
                                   y_pred=predlist,
                                   target_names=["Normal", "CAP", "COVID-19"])
    print(report)
Beispiel #11
0
def get_model(model_name=None):
    """Instantiate (and load weights for) the model named by ``model_name``
    (defaults to ``args.model``) and return it in eval mode on ``device``.
    """
    if not model_name:
        model_name = args.model

    # Exact-name models: torchvision zoo entries plus torch.hub SSL/SWSL
    # variants. Lambdas keep construction lazy — only the match is built.
    _hub = 'facebookresearch/semi-supervised-ImageNet1K-models'
    exact = {
        'resnet18': lambda: torchvision.models.resnet18(pretrained=True),
        'alexnet': lambda: torchvision.models.alexnet(pretrained=True),
        'squeezenet': lambda: torchvision.models.squeezenet1_0(pretrained=True),
        'vgg16': lambda: torchvision.models.vgg16(pretrained=True),
        'densenet': lambda: torchvision.models.densenet161(pretrained=True),
        'inception': lambda: torchvision.models.inception_v3(pretrained=True),
        'googlenet': lambda: torchvision.models.googlenet(pretrained=True),
        'shufflenet': lambda: torchvision.models.shufflenet_v2_x1_0(pretrained=True),
        'mobilenet': lambda: torchvision.models.mobilenet_v2(pretrained=True),
        'resnet50_32x4d': lambda: torchvision.models.resnext50_32x4d(pretrained=True),
        'wide_resnet50_2': lambda: torchvision.models.wide_resnet50_2(pretrained=True),
        'mnasnet': lambda: torchvision.models.mnasnet1_0(pretrained=True),
        'resnext50_32x4d_ssl': lambda: torch.hub.load(_hub, 'resnext50_32x4d_ssl'),
        'resnext50_32x4d_swsl': lambda: torch.hub.load(_hub, 'resnext50_32x4d_swsl'),
        'resnet50_swsl': lambda: torch.hub.load(_hub, 'resnet50_swsl'),
    }

    if model_name in exact:
        model = exact[model_name]()
    elif 'seresnet50' in model_name:
        # substring match, as in the original chain (checked after exact names)
        model = se_resnet50(num_classes=1000)
        model.load_state_dict(
            torch.load("../checkpoint/seresnet50-60a8950a85b2b.pkl"))
    elif model_name in ('T2t_vit_t_14', 'T2t_vit_t_24'):
        model = create_model(
            model_name,
            pretrained=False,
            num_classes=args.num_classes,
            in_chans=3,
        )
        load_checkpoint(model, checkpoint_paths[model_name], True)
    else:
        model = create_model(
            model_name,
            pretrained=args.pretrained,
            num_classes=args.num_classes,
            in_chans=3,
        )
        if not args.pretrained:
            if not args.set_temperature:
                load_checkpoint(model, checkpoint_paths[model_name], True)
            else:
                load_checkpoint(
                    model,
                    checkpoint_paths[f"{model_name}_tem{args.set_temperature}"],
                    True)
                # rescale every attention block to the requested temperature
                for blk in model.blocks:
                    blk.attn.scale = 768**(-1 / args.set_temperature)
                print("Set temperature to: ", model.blocks[0].attn.scale)

    return model.eval().to(device)
    def preprocess(cls, pil_img):
        """Convert a PIL image (or array-like) to a CHW array.

        Grayscale (HxW) inputs gain a trailing channel axis before the
        HWC->CHW transpose; values are scaled to [0, 1] only when they look
        like 8-bit intensities (max > 1), which also promotes them to float.
        """
        arr = np.array(pil_img)

        # promote HxW grayscale to HxWx1 so the transpose below is uniform
        if arr.ndim == 2:
            arr = arr[:, :, np.newaxis]

        chw = arr.transpose((2, 0, 1))
        if chw.max() > 1:
            chw = chw / 255

        return chw

if __name__ == '__main__':
    # Inference setup: load an SE-ResNet50 checkpoint and push every image
    # in ``img_path`` through it as a single batch.
    # NOTE(review): ``checkpoints`` and ``img_path`` must be defined
    # elsewhere in this file — confirm before running standalone.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = se_resnet50(num_classes=1)
    model = model.to(device)
    # wrapped in DataParallel before load_state_dict — presumably the
    # checkpoint was saved from a DataParallel model ("module." keys); verify.
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(checkpoints))
    filename = os.listdir(img_path)
    img_path_list = []
    for i in filename:
        img_path_list.append(os.path.join(img_path, i))
    dataset = Xinguan(img_path_list)
    # batch_size=len(dataset): the whole folder is evaluated as one batch
    dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=len(dataset),
            shuffle=False,
            num_workers=0
        )
    # dataset items are dicts; only the 'image' tensor is used here
    images = next(iter(dataloader))['image'].to(device)
Beispiel #13
0
    # (fragment — enclosing function starts before this chunk)
    # Validation pipeline: wrap the image/label lists in a Dataset and a
    # deterministic (non-shuffling) DataLoader.
    val_dataset = MyDataset(image_path_list,
                            label_list,
                            transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=opt.BATCHSIZE,
                            shuffle=False,
                            num_workers=opt.NUMWORKER,
                            pin_memory=True)

    # Backbone selection — every branch ends with a 230-way classifier.
    if opt.NETWORK == 'resnet':
        net = resnet50(pretrained=True)
        # replace the ImageNet head with a 230-class linear layer
        fc_features = net.fc.in_features
        net.fc = torch.nn.Linear(fc_features, 230)
    elif opt.NETWORK == 'senet':
        net = se_resnet50(num_classes=230)
    elif opt.NETWORK == 'repvgg':
        net = create_RepVGG_B1g2(num_classes=230)
    else:
        net = resnet50(num_classes=230)

    net = net.to(DEVICE)

    # freezing conv features
    # for k, v in net.named_parameters():
    #    if k not in ['fc.weight', 'fc.bias']:
    #        v.requires_grad = False

    # class-weighted cross-entropy — NOTE(review): assumes ``weight_list``
    # holds per-class weights computed upstream; confirm.
    loss_func = CrossEntropyLoss(weight=weight_list).to(DEVICE)
    optimizer = Adam(net.parameters(), lr=opt.LR, weight_decay=0.0001)
    # optimizer = AdamW(net.parameters())
Beispiel #14
0
    # (fragment — enclosing function starts before this chunk)
    # Data + model setup for BraTS slice classification.
    dataloders = get_brats_train_loaders(path, test_path)
    # dataset_sizes = {'train': 26778, 'val': 6695}
    dataset_sizes = {'train': 16736, 'val': 16736, 'test': 4247}
    # dataset_sizes = {'train': 26777, 'val': 6695, 'test': 3791}
    # use gpu or not
    use_gpu = torch.cuda.is_available()
    logger.info("use_gpu:{}".format(use_gpu))

    # get model
    # keep only the first two underscore-separated tokens of the network
    # name (e.g. "se_resnet50_xxx" -> "se_resnet50") as the model key
    script_name = '_'.join([
        args.network.strip().split('_')[0],
        args.network.strip().split('_')[1]
    ])

    if script_name == "se_resnet50":
        model = se_resnet50(num_classes=args.num_class)
    else:
        raise Exception(
            "Please give correct network name such as se_resnet_xx or se_rexnext_xx"
        )

    # define loss function
    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model.parameters(),
                             lr=args.lr,
                             momentum=0.9,
                             weight_decay=0.00004)

    # Decay LR by a factor of 0.1 every 7 epochs
def main():
    """K-fold training of a 3-class SE-ResNet50 with Adam, logged to wandb.

    Reads train/valid dirs from the "senet" config section, pools and
    reshuffles them, then re-splits per fold. Tracks the best (lowest)
    validation cross-entropy and checkpoints every epoch.
    """
    global args, best_prec1
    # best_prec1 is a loss here (lower is better), so it starts high
    best_prec1 = 1e6
    args = parser.parse_args()
    # hard-coded hyper-parameters override whatever the parser produced
    args.original_lr = 1e-6
    args.lr = 1e-6
    args.momentum  = 0.95
    args.decay  = 5*1e-4
    args.start_epoch   = 0
    args.epochs = 50
    args.steps = [-1,1,20,40]
    args.scales = [1,1,0.5,0.5]
    args.workers = 0
    args.seed = time.time()
    args.print_freq = 30
    wandb.config.update(args)
    # prefix the run name with the task (or "Default" when they collide)
    wandb.run.name = f"Default_{wandb.run.name}" if (args.task == wandb.run.name) else f"{args.task}_{wandb.run.name}"

    conf= configparser.ConfigParser()

    conf.read(args.config) 
    TRAIN_DIR = conf.get("senet", "train") 
    VALID_DIR = conf.get("senet", "valid") 
    TEST_DIR = conf.get("senet", "test") 
    LOG_DIR = conf.get("senet", "log") 
    create_dir_not_exist(LOG_DIR)
    train_list = [os.path.join(TRAIN_DIR, item) for item in os.listdir(TRAIN_DIR)]
    train_list=checkSuffix(train_list)
    val_list = [os.path.join(VALID_DIR, item) for item in os.listdir(VALID_DIR)]
    val_list=checkSuffix(val_list)
    # train and valid sets are pooled; folds are drawn from the shuffled pool
    data_list= train_list+val_list
    random.shuffle(data_list)
    
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.cuda.manual_seed(args.seed)

    
    model = se_resnet50(num_classes=3)
    model = model.cuda()
    
    criterion = nn.CrossEntropyLoss().cuda()
    model = DataParallel_withLoss(model, criterion)

    # NOTE(review): model weights are NOT reset between folds, so later
    # folds keep training the same network — confirm this is intended.
    for i in range(args.k_fold):
        train_list,val_list=get_k_fold_data(i,data_list)
        args.lr=args.original_lr
        best_prec1 = 1e6
        optimizer = torch.optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.decay)

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch)
            train_loss = train(train_list, model, criterion, optimizer, epoch)
            prec1 = validate(val_list, model, criterion, epoch)
            with open(os.path.join(LOG_DIR, args.task + ".txt"), "a") as f:
                f.write("K "+str(i) +" epoch " + str(epoch) +  "  TrainLoss: " +str(float(train_loss))+
                "  ValLoss: " +str(float(prec1)))
                f.write("\n")
            wandb.log({'K':i,'epoch': epoch, 'TrainCEloss': train_loss,'ValCEloss':prec1})
            wandb.save(os.path.join(LOG_DIR, args.task + ".txt"))
            # lower validation loss == better model
            is_best = prec1 < best_prec1
            best_prec1 = min(prec1, best_prec1)
            print(' * best CELoss {CELoss:.3f} '.format(CELoss=best_prec1))
            save_checkpoint({
                'k':i,
                'epoch': epoch + 1,
                'arch': args.pre,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best,args.task, epoch = epoch,path=os.path.join(LOG_DIR, args.task))
Beispiel #16
0
    # (fragment — enclosing function starts before and continues past this
    # chunk; it is cut off mid-load of the second checkpoint)
    # creative logger
    logger = get_logger('SENetTesting', file_name='output')

    # read data
    path = args.data_path
    test_path = args.test_path
    dataloders = get_brats_train_loaders(path, test_path)
    dataset_sizes = {'train': 16736, 'val': 16736, 'test': 3791}

    # define loss function
    criterion = nn.CrossEntropyLoss()

    use_gpu = torch.cuda.is_available()
    logger.info("use_gpu:{}".format(use_gpu))

    # two SE-ResNet50 instances, loaded from separate checkpoints —
    # presumably for ensembled evaluation; confirm downstream usage.
    model1 = se_resnet50(num_classes=args.num_class)
    model2 = se_resnet50(num_classes=args.num_class)
    if os.path.isfile(args.resume1):
        logger.info(("=> loading checkpoint '{}'".format(args.resume1)))
        state = torch.load(args.resume1)
        try:
            model1.load_state_dict(state['model_state_dict'])
        # NOTE(review): BaseException is very broad — it also swallows
        # KeyboardInterrupt/SystemExit; a narrower except would be safer.
        except BaseException as e:
            print('Failed to do something: ' + str(e))
    else:
        logger.info(("=> no checkpoint found at '{}'".format(args.resume1)))

    if os.path.isfile(args.resume2):
        logger.info(("=> loading checkpoint '{}'".format(args.resume2)))
        state = torch.load(args.resume2)
        try:
Beispiel #17
0
import torch.nn.functional as F
from listfile import listfile
from DataLoader import Trainloader
from senet.se_resnet import se_resnet50

# Module-level training setup: split ./dataset into train/val file lists,
# build the loaders, and create the 29-class SE-ResNet50 plus its optimizer.
# These names (model, optimizer, loss_func, *ImgLoader) are read by the
# train/test functions defined below.
train_imgs, train_labels, val_imgs, val_labels = listfile('./dataset')

TrainImgLoader = torch.utils.data.DataLoader(
         Trainloader(train_imgs, train_labels),
         batch_size=32, shuffle= True, num_workers= 8, drop_last=False)

TestImgLoader = torch.utils.data.DataLoader(
         Trainloader(val_imgs, val_labels),
         batch_size=16, shuffle= False, num_workers= 8, drop_last=False)

model = se_resnet50(num_classes=29, pretrained=False)
model.cuda()

optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_func = F.cross_entropy

def train(img, label):
    """Train-mode forward pass over one batch (fragment — the function
    continues past this chunk; loss/backward are not visible here)."""
    model.train()
    # NOTE(review): Variable() is a no-op on modern PyTorch (>=0.4);
    # the tensors could be used directly.
    img = Variable(torch.FloatTensor(img))
    label = Variable(torch.LongTensor(label))
    img = img.cuda()
    label = label.cuda()

    optimizer.zero_grad()
    predict = model(img)
    # probabilities over the 29 classes
    predict = F.softmax(predict, dim=1)