コード例 #1
0
 def _load_pretrained_weight(self):
     """Build a 4-input-channel SE-ResNet50 initialised from 3-channel weights.

     The ImageNet-pretrained 3-channel network is loaded first. Its
     ``layer0.conv1.weight`` tensor is then widened along dim 1 by appending
     a copy of its first slice (``[:, :1]``), and the resulting state dict is
     loaded into a freshly constructed ``input_channel=4`` network.

     Returns:
         The 4-channel network with the patched weights loaded.
     """
     first_conv_key = 'layer0.conv1.weight'

     # Source of the pretrained weights (3 input channels).
     donor = se_resnet50(pretrained='imagenet', input_channel=3)
     donor_state = donor.state_dict().copy()

     rgb_kernel = donor_state[first_conv_key]
     print('raw_weight size: ', rgb_kernel.size())

     # Re-use the first slice along dim 1 as the extra (fourth) slice.
     widened_kernel = torch.nn.Parameter(
         torch.cat((rgb_kernel, rgb_kernel[:, :1, :, :]), dim=1))
     print('new_weight size: ', widened_kernel.size())

     # Same keys in the same order; only the first conv entry is swapped.
     patched_state = OrderedDict()
     for entry_name, entry_value in donor_state.items():
         if entry_name == first_conv_key:
             patched_state[entry_name] = widened_kernel
         else:
             patched_state[entry_name] = entry_value

     target = se_resnet50(pretrained=None, input_channel=4)
     target.load_state_dict(patched_state)
     return target
コード例 #2
0
 def __init__(
     self,
     model_path='./models/seresnet50/seresnet50_final.pth',
     img_size=224,
 ):
     """Load a 4-class SE-ResNet50 checkpoint and build the input pipelines.

     Args:
         model_path: Path of the state-dict checkpoint handed to
             ``torch.load``.
         img_size: Side length (pixels) both pipelines resize images to.

     Attributes set:
         device: CUDA when available, otherwise CPU.
         net: The network on ``device``, switched to eval mode.
         transform: Pipeline using ``PadImage`` (no channel reordering).
         transform_crop: Pipeline using ``CropPadImage`` followed by a
             channel reorder ``[2, 1, 0]``.
     """
     self.model_path = model_path
     self.img_size = img_size
     self.device = torch.device(
         'cuda' if torch.cuda.is_available() else 'cpu')
     print('devices=', self.device)
     self.net = se_resnet.se_resnet50(num_classes=4).to(self.device)
     # map_location lets a checkpoint that was saved from CUDA tensors be
     # restored on a CPU-only host, matching the device fallback above.
     self.net.load_state_dict(
         torch.load(self.model_path, map_location=self.device))
     self.net.eval()
     # Reverses the order of the first tensor dimension (channels after
     # ToTensor), i.e. swaps channel 0 and channel 2.
     to_bgr_transform = transforms.Lambda(lambda x: x[[2, 1, 0]])
     # NOTE(review): this pipeline deliberately (?) skips the channel swap
     # that transform_crop applies - confirm the asymmetry is intended.
     self.transform = transforms.Compose([
         transforms.ToPILImage(),
         PadImage(),
         transforms.Resize([self.img_size, self.img_size], interpolation=3),
         transforms.ToTensor(),
         # to_bgr_transform,
         transforms.Normalize((0.5, ), (0.5, ))
     ])
     # Pipeline applied repeatedly to one image to build a batch of crops.
     self.transform_crop = transforms.Compose([
         transforms.ToPILImage(),
         CropPadImage(),
         transforms.Resize([self.img_size, self.img_size], interpolation=3),
         transforms.ToTensor(), to_bgr_transform,
         transforms.Normalize((0.5, ), (0.5, ))
     ])
コード例 #3
0
ファイル: imagenet.py プロジェクト: lxtGH/senet.pytorch
def main(batch_size, data_root):
    """Train SE-ResNet50 (1000 classes) on an ImageFolder-style dataset.

    Args:
        batch_size: Batch size used by both the training and validation
            loaders.
        data_root: Directory that contains the ``train`` and ``val``
            sub-directories read by ``datasets.ImageFolder``.
    """
    # Both pipelines share the same per-channel normalisation.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        # RandomSizedCrop was only a deprecated alias of RandomResizedCrop
        # and has been removed from current torchvision releases.
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # NOTE(review): there is no Resize before CenterCrop(224), so only the
    # central 224x224 pixels of each validation image are used - confirm
    # the validation images are already pre-scaled.
    transform_test = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    traindir = os.path.join(data_root, 'train')
    valdir = os.path.join(data_root, 'val')
    train = datasets.ImageFolder(traindir, transform_train)
    val = datasets.ImageFolder(valdir, transform_test)
    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=8)
    # Evaluation does not benefit from shuffling; a fixed order keeps
    # validation runs reproducible.
    test_loader = torch.utils.data.DataLoader(val,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=8)

    se_resnet = se_resnet50(num_classes=1000)
    optimizer = optim.SGD(params=se_resnet.parameters(),
                          lr=0.6,
                          momentum=0.9,
                          weight_decay=1e-4)
    # Multiply the learning rate by 0.1 every 30 scheduler steps.
    scheduler = StepLR(optimizer, 30, gamma=0.1)
    trainer = Trainer(se_resnet, optimizer, F.cross_entropy, save_dir=".")
    trainer.loop(100, train_loader, test_loader, scheduler)
コード例 #4
0
def main(batch_size, root, lrate):
    """Train a 340-class SE-ResNet50, resuming from the newest checkpoint.

    The implementation of tensorboardX and topK accuracy is in utils.py.

    All run artefacts live under ``./lr<lrate>/``: checkpoints in
    ``checkpoint/`` and the two tensorboard logs in
    ``tensorboard_log/batch/`` and ``tensorboard_log/epoch/``.

    Args:
        batch_size: Batch size passed to ``get_dataloader``; also scales the
            learning rate (``lrate / 1024 * batch_size``).
        root: Dataset root passed to ``get_dataloader``.
        lrate: Base learning rate; also names the run directory.
    """
    run_dir = "./lr" + str(lrate)
    checkpoint_dir = run_dir + "/checkpoint/"

    # get_checkPoint returns 0 when there is nothing to resume from
    # (see the comparison below); otherwise a path usable by torch.load.
    checkpoint_newest = get_checkPoint(checkpoint_dir)

    # A timestamped sub-directory was used previously; logs now go straight
    # into the batch/ and epoch/ directories.
    TIMESTAMP = ""
    writer1 = SummaryWriter(run_dir + '/tensorboard_log/batch/' + TIMESTAMP)
    writer2 = SummaryWriter(run_dir + '/tensorboard_log/epoch/' + TIMESTAMP)

    train_loader, test_loader = get_dataloader(batch_size, root)
    gpus = list(range(torch.cuda.device_count()))

    seresnet50 = nn.DataParallel(se_resnet50(num_classes=340), device_ids=gpus)
    optimizer = optim.SGD(params=seresnet50.parameters(),
                          lr=lrate / 1024 * batch_size,
                          momentum=0.9,
                          weight_decay=1e-4)

    if checkpoint_newest == 0:
        # Fresh run: 50 as the first loop argument, starting at epoch 1.
        scheduler = optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1)
        loop_epochs = 50
        first_epoch = 1
    else:
        # Resume: restore model/optimizer state and continue after the
        # epoch recorded in the checkpoint.
        print("The path of the pretrained model %s" % checkpoint_newest)
        print("load pretrained model......")
        checkpoint = torch.load(checkpoint_newest)
        seresnet50.load_state_dict(checkpoint['weight'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        resumed_epoch = checkpoint['epoch']
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              30,
                                              gamma=0.1,
                                              last_epoch=resumed_epoch)
        print("The current epoch is %d" % resumed_epoch)
        loop_epochs = 100
        first_epoch = resumed_epoch + 1

    trainer = Trainer(seresnet50,
                      optimizer,
                      F.cross_entropy,
                      save_dir=checkpoint_dir,
                      writer1=writer1,
                      writer2=writer2,
                      save_freq=1)
    trainer.loop(loop_epochs, train_loader, test_loader, first_epoch,
                 scheduler)
コード例 #5
0
ファイル: imagenet.py プロジェクト: wentaozhu/senet.pytorch
def main(batch_size, root):
    """Train a 1000-class SE-ResNet50 with DataParallel on devices 0 and 1.

    Args:
        batch_size: Batch size passed to ``get_dataloader``.
        root: Dataset root passed to ``get_dataloader``.
    """
    train_loader, test_loader = get_dataloader(batch_size, root)

    # The device ids are fixed, so two visible GPUs are expected.
    se_resnet = nn.DataParallel(se_resnet50(num_classes=1000),
                                device_ids=[0, 1])

    sgd_settings = dict(lr=0.6, momentum=0.9, weight_decay=1e-4)
    optimizer = optim.SGD(params=se_resnet.parameters(), **sgd_settings)
    # Multiply the learning rate by 0.1 every 30 scheduler steps.
    scheduler = StepLR(optimizer, 30, gamma=0.1)

    Trainer(se_resnet, optimizer, F.cross_entropy, save_dir=".").loop(
        100, train_loader, test_loader, scheduler)
コード例 #6
0
def main(batch_size, root):
    """Train a 345-class SE-ResNet50 across every visible CUDA device.

    Args:
        batch_size: Batch size passed to ``get_dataloader``; also scales the
            learning rate (``0.6 / 1024 * batch_size``).
        root: Dataset root passed to ``get_dataloader``.
    """
    train_loader, test_loader = get_dataloader(batch_size, root)

    # One DataParallel replica per device reported by torch.
    device_ids = list(range(torch.cuda.device_count()))
    se_resnet = nn.DataParallel(se_resnet50(num_classes=345),
                                device_ids=device_ids)

    scaled_lr = 0.6 / 1024 * batch_size
    optimizer = optim.SGD(params=se_resnet.parameters(),
                          lr=scaled_lr,
                          momentum=0.9,
                          weight_decay=1e-4)
    # Multiply the learning rate by 0.1 every 30 scheduler steps.
    scheduler = optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1)

    Trainer(se_resnet, optimizer, F.cross_entropy, save_dir=".").loop(
        100, train_loader, test_loader, scheduler)
コード例 #7
0
def train_cnn_ivr():
    """Train a 365-class SE-ResNet50 with a center-loss criterion on cuda:0.

    Data comes from ``load_data_for_training_cnn`` (batch size 16, with the
    low-accuracy label list file); the ``train_cnn`` / ``val_cnn`` loaders
    are handed to ``train_cnn`` together with the index-to-label mapping.
    """
    loaders, cnnidx2label = load_data_for_training_cnn(
        batch_size=16,
        lowAccLabel_fp='../data/list_tc/label/accLeccThan20Label_filter.txt')

    # An InceptionResNetV2(num_classes=365, num_feature=1024, drop_rate=0.2)
    # backbone was used here previously.
    backbone = se_resnet50(num_classes=365)
    center_criterion = CenterLoss(num_classes=365,
                                  feat_dim=1024,
                                  use_gpu=False)

    target_device = torch.device('cuda:0')

    train_cnn(
        backbone,
        center_criterion,
        loaders['train_cnn'],
        loaders['val_cnn'],
        cnnidx2label,
        target_device,
        multi_gpu=None,
        repick=True,
    )
コード例 #8
0
def feature_extract():
    """Extract features for the train_C / test_C splits with a trained model.

    A 365-class SE-ResNet50 is restored from a fixed checkpoint path, moved
    to ``cuda:0``, and run through ``extract_features`` once per split; each
    call receives the output path ``../data/<split>_feat.npy``.
    """
    # An InceptionResNetV2(num_classes=365, num_feature=2048) backbone was
    # used here previously.
    net = se_resnet50(num_classes=365)
    gpu_index = 0
    device = torch.device(f'cuda:{gpu_index}')

    weights = torch.load(
        '../data/classficaData/B3_IVR2_1/pth/irv2_9_0.678389.pth')
    net.load_state_dict(weights)
    net = net.to(device)

    loaders, cnnidx2label = load_data_for_feature_extract()
    print("load data is  ok")

    # The train_A / train_B / test_A / test_B splits were processed the same
    # way in the past and are currently disabled.
    for split in ('train_C', 'test_C'):
        extract_features(net, loaders[split], device,
                         f'../data/{split}_feat.npy')
コード例 #9
0
# Training loader: batches are produced by a BatchSampler wrapping
# RandomSamplerWithEpochSize(label_lists.train, epoch_size).
# NOTE(review): presumably BatchSampler's positional args are
# (sampler, batch_size, drop_last), i.e. drop_last=True - confirm.
train_dl = DataLoader(label_lists.train,
                      num_workers=8,
                      batch_sampler=BatchSampler(
                          RandomSamplerWithEpochSize(label_lists.train,
                                                     epoch_size), bs, True))
# Validation / test loaders.
# NOTE(review): presumably the positional args are (dataset, batch_size,
# shuffle), i.e. shuffle=False - confirm against the DataLoader in use.
valid_dl = DataLoader(label_lists.valid, bs, False, num_workers=8)
test_dl = DataLoader(label_lists.test, bs, False, num_workers=8)
data_bunch = ImageDataBunch(train_dl, valid_dl, test_dl)

# Persist the class list next to the data so it can be reloaded later.
classes = data_bunch.classes
pd.to_pickle(classes, PATH / "classes.pkl")

from fastai.callbacks import SaveModelCallback, EarlyStoppingCallback
import sys
# The se_resnet module is imported from a local checkout, so its directory
# must be on sys.path before the import below runs.
sys.path.append("./senet.pytorch/")
from se_resnet import se_resnet50
# 340 is passed positionally to the constructor.
# NOTE(review): presumably this is num_classes - confirm.
model = se_resnet50(340)
# SaveModelCallback is configured with every="epoch" under the name
# "senet-v2".
learn = Learner(
    data_bunch,
    model,
    metrics=[accuracy, map3],
    callback_fns=[partial(SaveModelCallback, every="epoch", name="senet-v2")])

# Resume from a previously saved learner state, then train for 72 epochs.
ckpt = "final-senet-stage-1"
print(f"Loading ckpt : {ckpt}")
learn.load(ckpt)
learn.fit_one_cycle(72, max_lr=5e-4)

# NOTE(review): `name` equals `ckpt`, so this save appears to overwrite the
# checkpoint that was loaded above - confirm that is intended.
name = 'final-senet-stage-1'
learn.save(f'{name}')
コード例 #10
0
ファイル: test.py プロジェクト: xiangwenliu/SE-ResNet-pytorch
def run_on_opencv_imgbatch(img, num_crops=3):
    """Classify one image by averaging predictions over several crops.

    ``transform_crop`` is applied ``num_crops`` times to the same image, the
    results are stacked into one batch, and the softmax outputs of ``net``
    are averaged over the batch before taking the arg-max.

    Relies on the module-level ``transform_crop``, ``net`` and ``device``.

    Args:
        img: Image handed unchanged to ``transform_crop`` (the caller in
            this file passes the result of ``cv2.imread``).
        num_crops: Number of transformed copies to average over. Defaults
            to 3, the value that used to be hard-coded.

    Returns:
        The predicted class index as a Python int.
    """
    crops = [transform_crop(img) for _ in range(num_crops)]
    # Stack along a new leading (batch) dimension.
    images = torch.stack(crops, 0).to(device)
    net.eval()
    # Inference only: skip autograd bookkeeping so no graph is retained.
    with torch.no_grad():
        outputs = net(images)
        preds = torch.softmax(outputs, 1)
    # Average the per-crop class probabilities, then pick the best class.
    mean_probs = torch.mean(preds, 0)
    return mean_probs.argmax().item()


# Evaluation entry point: restores the network named by the CLI arguments
# and runs it over every image listed in the evaluation list.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = se_resnet.se_resnet50(num_classes=args.class_num).to(device)
    # Disabled smoke test: forward a random batch and print the output size.
    # input = torch.randn(8, 3, 214, 214).to(device)
    # outputs = net(input)
    # print(outputs.size())
    net.load_state_dict(torch.load(args.model_path))
    # Iterated below as (relative image name, label) pairs.
    dic_img = img_label_dic(args.eval_list)
    total = 0
    correct = 0
    for k, v in dic_img.items():
        print(k)
        # Plain string concatenation: args.eval_imgpath must already end
        # with a path separator.
        imgpath = args.eval_imgpath + k
        if not os.path.exists(imgpath):
            # Missing files are reported and skipped rather than aborting.
            print('not exist image {}'.format(k))
            continue
        pred = run_on_opencv_imgbatch(cv2.imread(imgpath))
        labels = v
        # NOTE(review): the visible snippet ends here; `pred`, `labels`,
        # `total` and `correct` are not used yet - the accuracy bookkeeping
        # is presumably in the part of the file that is not shown.
コード例 #11
0
            transforms.Resize(224),
            transforms.ToTensor(),
            normalize,
            ]))
# Validation batches are served in a fixed order (shuffle=False) with
# pin_memory enabled.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=test_batch,
    shuffle=False,
    num_workers=num_workers['val'],
    pin_memory=True,
)

# Per-split sample counts and loaders, keyed by split name.
dataset_sizes = dict(train=len(train_dataset), val=len(val_dataset))
dataloaders = dict(train=train_loader, val=val_loader)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = se_resnet50(num_classes=1000).to(device)

# Report the number of parameters that require gradients.
print(sum(weight.numel()
          for weight in model_ft.parameters()
          if weight.requires_grad))

# Loss function.
criterion = nn.CrossEntropyLoss()

# Every model parameter is handed to the optimizer.
Init_lr = 0.005
optimizer_ft = optim.SGD(model_ft.parameters(), lr=Init_lr, momentum=0.9)


def adjust_learning_rate(optimizer, epoch, Init_lr):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = Init_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
コード例 #12
0
    def __init__(self, classCount, isTrained):
        """Wrap an ``se_resnet.se_resnet50`` backbone.

        Args:
            classCount: Forwarded to the backbone as ``num_classes``.
            isTrained: Forwarded to the backbone as ``pretrained``.
        """
        super(SE_ResNet50, self).__init__()

        backbone = se_resnet.se_resnet50(pretrained=isTrained,
                                         num_classes=classCount)
        self.se_resnet50 = backbone
コード例 #13
0
def create_seresnet50():
    """Return a new ``se_resnet50`` built with the module-level ``n_classes``."""
    network = se_resnet50(n_classes)
    return network
コード例 #14
0
class test_model(nn.Module):
    """Evaluation wrapper that builds the backbone selected by ``opts.case``.

    Args:
        opts: Options object; ``case``, ``pretrained`` and ``num_classes``
            are read from it.
        device: Device the backbone and the loss module are moved to.
        model_path: Forwarded to the backbone constructor as ``model_path``.

    The constructed backbone is reachable both as ``self.network`` and as
    ``self.model`` (the same module object).
    """

    def __init__(self, opts, device, model_path):
        super(test_model, self).__init__()
        self.opts = opts
        self.device = device
        self.model_path = model_path
        # model
        self.network = self.model_choice(self.opts.case)
        # loss function
        self.loss_func = torch.nn.CrossEntropyLoss().to(device)

    def model_choice(self, case):
        """Build the backbone named by ``case`` and move it to the device.

        Args:
            case: Architecture name, matched case-insensitively against
                resnet50/101/152, aa_resnet50/101/152 and
                se_resnet50/101/152.

        Returns:
            The constructed backbone (also stored on ``self.model``).

        Raises:
            ValueError: If ``case`` names none of the supported
                architectures. Previously this surfaced as an
                AttributeError on ``self.model``.
        """
        case = case.lower()
        # Pick the constructor first so the (identical) build call below is
        # written only once.
        # resnet
        if case == 'resnet50':
            builder = resnet50
        elif case == 'resnet101':
            builder = resnet101
        elif case == 'resnet152':
            builder = resnet152
        # aa_resnet
        elif case == 'aa_resnet50':
            builder = aa_resnet50
        elif case == 'aa_resnet101':
            builder = aa_resnet101
        elif case == 'aa_resnet152':
            builder = aa_resnet152
        # se_resnet
        elif case == 'se_resnet50':
            builder = se_resnet50
        elif case == 'se_resnet101':
            builder = se_resnet101
        elif case == 'se_resnet152':
            builder = se_resnet152
        else:
            raise ValueError(
                "unsupported model case %r; expected one of resnet50/101/152, "
                "aa_resnet50/101/152 or se_resnet50/101/152" % case)

        self.model = builder(pretrained=self.opts.pretrained,
                             num_classes=self.opts.num_classes,
                             model_path=self.model_path).to(self.device)
        return self.model
コード例 #15
0
class train_model(nn.Module):
    """Training wrapper: backbone selected by ``opts.case``, Adam, CE loss.

    Args:
        opts: Options object; ``case``, ``pretrained``, ``num_classes``,
            ``checkpoint`` and ``lr`` are read from it.
        device: Device the backbone and the loss module are moved to.

    The constructed backbone is reachable both as ``self.network`` and as
    ``self.model`` (the same module object).
    """

    def __init__(self, opts, device):
        super(train_model, self).__init__()
        self.opts = opts
        self.device = device
        # model
        self.network = self.model_choice(self.opts.case)
        # optimizer
        self.optimizer = torch.optim.Adam(self.network.parameters(),
                                          lr=self.opts.lr)
        # loss function
        self.loss_func = torch.nn.CrossEntropyLoss().to(device)

    def model_choice(self, case):
        """Build the backbone named by ``case`` and move it to the device.

        ``opts.checkpoint`` is forwarded to the constructor as
        ``model_path``.

        Args:
            case: Architecture name, matched case-insensitively against
                resnet50/101/152, aa_resnet50/101/152 and
                se_resnet50/101/152.

        Returns:
            The constructed backbone (also stored on ``self.model``).

        Raises:
            ValueError: If ``case`` names none of the supported
                architectures. Previously this surfaced as an
                AttributeError on ``self.model``.
        """
        case = case.lower()
        # Pick the constructor first so the (identical) build call below is
        # written only once.
        # resnet
        if case == 'resnet50':
            builder = resnet50
        elif case == 'resnet101':
            builder = resnet101
        elif case == 'resnet152':
            builder = resnet152
        # aa_resnet
        elif case == 'aa_resnet50':
            builder = aa_resnet50
        elif case == 'aa_resnet101':
            builder = aa_resnet101
        elif case == 'aa_resnet152':
            builder = aa_resnet152
        # se_resnet
        elif case == 'se_resnet50':
            builder = se_resnet50
        elif case == 'se_resnet101':
            builder = se_resnet101
        elif case == 'se_resnet152':
            builder = se_resnet152
        else:
            raise ValueError(
                "unsupported model case %r; expected one of resnet50/101/152, "
                "aa_resnet50/101/152 or se_resnet50/101/152" % case)

        self.model = builder(pretrained=self.opts.pretrained,
                             num_classes=self.opts.num_classes,
                             model_path=self.opts.checkpoint).to(self.device)
        return self.model