Example #1
def main():
    # path for image
    img_path = args.img_path
    if img_path is None:
        print("You haven't chosen an image for prediction!")
        return
    data_use = img_path.split('/')[-3]
    class_name = img_path.split('/')[-2]
    if data_use == '30class_rgb':
        class_names = aid_class_names
    elif data_use == '45class_rgb':
        class_names = nwpu_class_names
    else:
        raise ValueError("Unknown dataset directory: " + data_use)

    # choose cnn for prediction
    if args.net == 1:
        model = bcnn_vgg.BCNN(class_num=len(class_names), pretrained=None)
        print("Using model bcnn_vgg for prediction.")
    elif args.net == 2:
        # build SE-ResNet without pretrained weights and replace its classifier head
        model = se_resnet.se_resnet50(pretrained=None)
        model.fc = nn.Linear(2048, len(class_names))
        print("Using model se_resnet for prediction.")

    # load trained weights for prediction
    if args.pretrained is not None:
        print("=> loading trained model '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint)
    model.eval()

    # load the image (OpenCV returns an HWC, BGR uint8 array)
    img = cv2.imread(img_path)

    # preprocess the image
    img_transform = transforms.Compose([
        MyAugmentations.Resize(224),
        MyAugmentations.Normalize(mean=dataset_mean, std=dataset_std),
        MyAugmentations.ToTensor(),
    ])
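    # `MyAugmentations` is a project-local module (not shown). Since a raw cv2
    # array is fed to it below, its Resize/Normalize/ToTensor presumably operate
    # on HWC numpy images rather than the PIL images torchvision's transforms expect.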
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model expects RGB
    input_tensor = img_transform(img)

    # run the image through the network and report the prediction
    with torch.no_grad():
        output = model(input_tensor.unsqueeze(0))
    pred_id = output.argmax(dim=1).item()
    print("Ground-truth class: " + class_name)
    print("Predicted class: " + class_names[pred_id])
Example #2
def train(args):
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.augmix:
        # ToTensor and normalize are omitted here on the assumption that the
        # AugMix wrapper below applies them itself after mixing
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(args.img_size, scale=(0.5, 2.0)),
        ])
    elif args.speckle:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(args.img_size, scale=(0.5, 2.0)),
            transforms.ToTensor(),
            transforms.RandomApply([transforms.Lambda(speckle_noise_torch)],
                                   p=0.5),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(args.img_size, scale=(0.5, 2.0)),
            transforms.ToTensor(),
            normalize,
        ])
    if args.cutout:
        # RandomErasing operates on tensors, so this assumes a transform that
        # ends in ToTensor (i.e., not the bare augmix branch above)
        train_transform.transforms.append(transforms.RandomErasing())
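    # `speckle_noise_torch` is defined elsewhere in the repo; a minimal sketch
    # of the assumed behavior (multiplicative Gaussian noise on a [0, 1] tensor):
    #
    #     def speckle_noise_torch(x, std=0.1):
    #         return torch.clamp(x + x * torch.randn_like(x) * std, 0.0, 1.0)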

    val_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    label_transform = transforms.Compose([
        ToLabel(),
    ])
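    # `ToLabel` is a project-local transform (not shown); given the multi-label
    # BCE loss below, it presumably converts the annotation to a multi-hot
    # float tensor, e.g.:
    #
    #     class ToLabel:
    #         def __call__(self, labels):
    #             return torch.as_tensor(labels, dtype=torch.float32)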
    print("Loading Data")
    if args.dataset == "deepfashion2":
        loader = fashion2loader(
            "../",
            transform=train_transform,
            label_transform=label_transform,
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )
        if args.augmix:
            loader = AugMix(loader, args.augmix)
        if args.stylize:
            style_loader = fashion2loader(
                root="../../stylize-datasets/output/",
                transform=train_transform,
                label_transform=label_transform,
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
            )
            loader = torch.utils.data.ConcatDataset([loader, style_loader])
        valloader = fashion2loader(
            "../",
            split="validation",
            transform=val_transform,
            label_transform=label_transform,
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )
    elif args.dataset == "deepaugment":
        loader = fashion2loader(
            "../",
            transform=train_transform,
            label_transform=label_transform,
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )
        loader1 = fashion2loader(
            root="../../deepaugment/EDSR/",
            transform=train_transform,
            label_transform=label_transform,
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )
        loader2 = fashion2loader(
            root="../../deepaugment/CAE/",
            transform=train_transform,
            label_transform=label_transform,
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )
        loader = torch.utils.data.ConcatDataset([loader, loader1, loader2])
        if args.augmix:
            loader = AugMix(loader, args.augmix)
        if args.stylize:
            style_loader = fashion2loader(
                root="../../stylize-datasets/output/",
                transform=train_transform,
                label_transform=label_transform,
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
            )
            loader = torch.utils.data.ConcatDataset([loader, style_loader])
        valloader = fashion2loader(
            "../",
            split="validation",
            transform=val_transform,
            label_transform=label_transform,
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )

    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")
    print("Loading Done")

    n_classes = args.num_classes
    train_loader = data.DataLoader(loader,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   shuffle=True)

    print("number of images = ", len(train_loader))
    print("number of classes = ", n_classes)

    print("Loading arch = ", args.arch)
    if args.arch == "resnet101":
        orig_resnet = torchvision.models.resnet101(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "resnet50":
        orig_resnet = torchvision.models.resnet50(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "resnet152":
        orig_resnet = torchvision.models.resnet152(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "se":
        model = se_resnet50(pretrained=True)
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "BiT-M-R50x1":
        model = bit_models.KNOWN_MODELS[args.arch](head_size=2048,
                                                   zero_head=True)
        model.load_from(np.load(f"{args.arch}.npz"))
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "BiT-M-R101x1":
        model = bit_models.KNOWN_MODELS[args.arch](head_size=2048,
                                                   zero_head=True)
        model.load_from(np.load(f"{args.arch}.npz"))
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
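    # `clssimp` is defined elsewhere in the repo. Since it consumes the
    # 2048-channel backbone feature map and produces n_classes logits, a
    # plausible minimal version (an assumption, not the original) would be:
    #
    #     class clssimp(nn.Module):
    #         def __init__(self, in_ch, num_classes):
    #             super().__init__()
    #             self.pool = nn.AdaptiveAvgPool2d(1)
    #             self.fc = nn.Linear(in_ch, num_classes)
    #
    #         def forward(self, x):
    #             return self.fc(self.pool(x).flatten(1))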

    if args.load == 1:
        # load paths mirror the save paths used at the end of each epoch
        model.load_state_dict(
            torch.load(os.path.join(args.save_dir, args.arch,
                                    str(args.disc) + ".pth")))
        clsfier.load_state_dict(
            torch.load(os.path.join(args.save_dir, args.arch,
                                    "clssegsimp" + str(args.disc) + ".pth")))

    gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_dataparallel = len(gpu_ids) > 1
    print("using data parallel = ", use_dataparallel, device, gpu_ids)
    if use_dataparallel:
        gpu_ids = list(range(len(gpu_ids)))
        model = nn.DataParallel(model, device_ids=gpu_ids)
        clsfier = nn.DataParallel(clsfier, device_ids=gpu_ids)
    model.to(device)
    clsfier.to(device)

    if args.finetune:
        if args.opt == "adam":
            optimizer = torch.optim.Adam([{
                'params': clsfier.parameters()
            }],
                                         lr=args.lr)
        else:
            optimizer = torch.optim.SGD(clsfier.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=True)
    else:
        if args.opt == "adam":
            optimizer = torch.optim.Adam([{
                'params': model.parameters(),
                'lr': args.lr / 10
            }, {
                'params': clsfier.parameters()
            }],
                                         lr=args.lr)
        else:
            optimizer = torch.optim.SGD(itertools.chain(
                model.parameters(), clsfier.parameters()),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    if args.use_scheduler:
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: cosine_annealing(
                step,
                args.n_epochs * len(train_loader),
                1,  # since lr_lambda computes multiplicative factor
                1e-6 / (args.lr * args.batch_size / 256.)))
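        # LambdaLR multiplies the optimizer's base lr by the returned factor,
        # so the effective lr starts at args.lr (factor 1) and anneals toward
        # args.lr * 1e-6 / (args.lr * args.batch_size / 256.) = 2.56e-4 / args.batch_size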

    bceloss = nn.BCEWithLogitsLoss()
    for epoch in range(args.n_epochs):
        for i, (images, labels) in enumerate(tqdm(train_loader)):
            if args.augmix:
                x_mix1, x_orig = images
                images = torch.cat((x_mix1, x_orig), 0).to(device)
            else:
                images = images[0].to(device)
            labels = labels.to(device).float()

            optimizer.zero_grad()

            outputs = model(images)
            outputs = clsfier(outputs)
            if args.augmix:
                l_mix1, outputs = torch.split(outputs, x_orig.size(0))

            if args.loss == "bce":
                if args.augmix:
                    if random.random() > 0.5:
                        loss = bceloss(outputs, labels)
                    else:
                        loss = bceloss(l_mix1, labels)
                else:
                    loss = bceloss(outputs, labels)
            else:
                raise ValueError("Invalid loss; only --loss bce is supported")

            loss.backward()
            optimizer.step()
            if args.use_scheduler:
                scheduler.step()

        print("Epoch [%d/%d] Loss: %.4f" %
              (epoch + 1, args.n_epochs, loss.item()))

        save_root = os.path.join(args.save_dir, args.arch)
        if not os.path.exists(save_root):
            os.makedirs(save_root)
        if use_dataparallel:
            torch.save(model.module.state_dict(),
                       os.path.join(save_root,
                                    str(args.disc) + ".pth"))
            torch.save(
                clsfier.module.state_dict(),
                os.path.join(save_root,
                             "clssegsimp" + str(args.disc) + ".pth"))
        else:
            torch.save(model.state_dict(),
                       os.path.join(save_root,
                                    str(args.disc) + ".pth"))
            torch.save(
                clsfier.state_dict(),
                os.path.join(save_root,
                             'clssegsimp' + str(args.disc) + ".pth"))
Example #3
def validate(args):
    # Setup Dataloader
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    val_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    label_transform = transforms.Compose([
        ToLabel(),
        # normalize,
    ])
    if args.dataset == "deepfashion2":
        if not args.concat_data:
            valloader = fashion2loader(
                "../",
                split="validation",
                transform=val_transform,
                label_transform=label_transform,
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
            )
        else:  # concatenate train and validation to cover the full label set
            loader1 = fashion2loader(
                "../",
                transform=val_transform,
                label_transform=label_transform,
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
            )
            loader2 = fashion2loader(
                "../",
                split="validation",
                transform=val_transform,
                label_transform=label_transform,
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
            )
            valloader = torch.utils.data.ConcatDataset([loader1, loader2])
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    n_classes = args.num_classes
    valloader = data.DataLoader(valloader,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False)

    print("Number of samples = ", len(valloader))
    print("Loading arch = ", args.arch)

    if args.arch in ('resnet50', 'resnet101', 'resnet152'):
        orig_resnet = getattr(torchvision.models, args.arch)(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
    elif args.arch == 'se':
        orig_model = se_resnet50(pretrained=True)
        features = list(orig_model.children())
        model = nn.Sequential(*features[0:8])
    elif args.arch in ("BiT-M-R50x1", "BiT-M-R101x1"):
        model = bit_models.KNOWN_MODELS[args.arch](head_size=2048,
                                                   zero_head=True)
        model.load_from(np.load(f"{args.arch}.npz"))
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
    else:
        raise ValueError(f"Unknown arch: {args.arch}")
    clsfier = clssimp(2048, n_classes)

    model.load_state_dict(
        torch.load(os.path.join(args.save_dir, args.arch,
                                str(args.disc) + ".pth")))
    clsfier.load_state_dict(
        torch.load(os.path.join(args.save_dir, args.arch,
                                "clssegsimp" + str(args.disc) + ".pth")))

    model.eval()
    clsfier.eval()

    if torch.cuda.is_available():
        model.cuda(0)
        clsfier.cuda(0)

    gts = {i: [] for i in range(0, n_classes)}
    preds = {i: [] for i in range(0, n_classes)}
    # gts, preds = [], []
    for i, (images, labels) in enumerate(tqdm(valloader)):
        images = images[0].cuda()
        labels = labels.cuda().float()

        outputs = model(images)
        outputs = clsfier(outputs)
        outputs = torch.sigmoid(outputs)
        pred = outputs.data.cpu().numpy()
        gt = labels.data.cpu().numpy()

        for label in range(0, n_classes):
            gts[label].extend(gt[:, label])
            preds[label].extend(pred[:, label])

    FinalMAPs = []
    for i in range(0, n_classes):
        precision, recall, thresholds = metrics.precision_recall_curve(
            gts[i], preds[i])
        FinalMAPs.append(metrics.auc(recall, precision))
    print(FinalMAPs)
    gts = np.array([gts[i] for i in range(n_classes)])

    # class-frequency-weighted mean of the per-class APs
    FinalMAPs = np.array(FinalMAPs)
    class_freq = gts.sum(axis=-1) / gts.sum()
    res = np.nan_to_num(FinalMAPs * class_freq)
    print(res.sum())
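As a sanity check for the per-class AP values above, sklearn's average_precision_score computes a closely related quantity (step interpolation of the PR curve rather than trapezoidal AUC). A toy example with made-up scores:

from sklearn import metrics
import numpy as np

gt = np.array([1, 0, 1, 1, 0])
scores = np.array([0.9, 0.4, 0.8, 0.35, 0.5])
print(metrics.average_precision_score(gt, scores))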
Example #4
def main():
    # For each dataset and ratio in data_use_ratio, train five times
    for i in range(1, 6):
        for data_use, ratio in data_use_ratio:
            if data_use == 'aid':
                class_names = aid_class_names
                data_path = "../MINN/datasets/30class_rgb/"
            elif data_use == 'nwpu':
                class_names = nwpu_class_names
                data_path = "../MINN/datasets/45class_rgb/"
            elif data_use == 'ucm':
                class_names = ucm_class_names
                # NOTE: data_path is not set for 'ucm' in this snippet
            else:
                raise ValueError('Please choose a dataset for training!')

            # Dir to save log file and model parameters
            logdir = log_dir + '/log_' + data_use
            save_dir = log_dir + '/save_' + data_use
            if not os.path.exists(logdir):
                os.makedirs(logdir)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            TFwriter = SummaryWriter(
                logdir)  # Save loss and acc for Visualization

            log_file = open(
                logdir + '/log_' + data_use + ratio + '_' + str(i) + '.txt',
                'w')
            log_file.write('datasets:' + data_use)
            print('datasets:' + data_use)
            log_file.write('\nratio:' + ratio)
            print('ratio:' + ratio)
            log_file.write('\nepochs:' + str(args.epochs))
            print('epochs:' + str(args.epochs))
            log_file.write('\nlearning rate:' + str(args.lr))
            print('learning rate:' + str(args.lr))
            log_file.write('\nbatch size:' + str(args.batch_size))
            print('batch size:' + str(args.batch_size))

            if args.net == 1:
                model = bcnn_vgg.BCNN(class_num=len(class_names),
                                      pretrained=None)
                print("Using model bcnn_vgg for training.")
                log_file.write("\nUsing model bcnn_vgg for training.")
            elif args.net == 2:
                # build SE-ResNet without pretrained weights and replace its classifier head
                model = se_resnet.se_resnet50(pretrained=None)
                model.fc = nn.Linear(2048, len(class_names))
                print("Using model se_resnet for training.")
                log_file.write("\nUsing model se_resnet for training.")

            # Using GPUs and cuda to accelerate training
            if torch.cuda.is_available():
                # using gpu for training
                model.cuda()
                cudnn.benchmark = True

            # resume training from a checkpoint
            if args.resume is not None:
                print("=> resuming from checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                model.load_state_dict(checkpoint)

            # Loading RS dataset
            print("Loading dataset...")
            train_file = 'dir_file/' + data_use + '_train' + ratio + '_' + str(
                i) + '.txt'
            test_file = 'dir_file/' + data_use + '_test' + ratio + '_' + str(
                i) + '.txt'
            train_loader = torch.utils.data.DataLoader(
                CLSDataPrepare(root=data_path,
                               txt_path=train_file,
                               img_transform=MyAugmentations.TrainAugmentation(
                                   size=224,
                                   _mean=dataset_mean,
                                   _std=dataset_std)),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                collate_fn=classifier_collate)

            val_loader = torch.utils.data.DataLoader(
                CLSDataPrepare(root=data_path,
                               txt_path=test_file,
                               img_transform=MyAugmentations.TestAugmentation(
                                   size=224,
                                   _mean=dataset_mean,
                                   _std=dataset_std)),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                collate_fn=classifier_collate)
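            # `classifier_collate` comes from the project's data code (not
            # shown); a minimal sketch of the assumed behavior:
            #
            #     def classifier_collate(batch):
            #         imgs, targets = zip(*batch)
            #         return torch.stack(imgs, 0), torch.tensor(targets)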

            # define the loss function (criterion), the optimizer, and the learning-rate schedule
            criterion = nn.CrossEntropyLoss().cuda()
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

            # Training and Validation, save the best model
            best_prec = 0
            print("Start training...")
            for epoch in range(args.start_epoch, args.epochs):
                # train for one epoch
                start = time.time()
                train(train_loader, model, criterion, optimizer, epoch,
                      TFwriter)
                # step the LR schedule after the epoch's optimizer updates
                scheduler.step()
                # validate model
                prec1, test_loss = validate(val_loader, model, criterion,
                                            len(class_names))
                end = time.time()
                print("time for one epoch:%.2fmin" % ((end - start) / 60))

                # OA, Kappa, class_specific_PA, class_specific_UA = get_OAKappa_by_conf(Confusion_Matrix)
                TFwriter.add_scalar('#test_loss', test_loss, epoch)
                TFwriter.add_scalar('#accuracy', prec1, epoch)
                print('after %d epochs, accuracy = %f, test_loss = %f' %
                      (epoch, prec1, test_loss))
                message = '\nafter {} epochs, accuracy = {:.2f}, test_loss = {:.8f}'.format(
                    epoch, prec1, test_loss)
                log_file.write(message)

                # remember best prec@1 and save checkpoint
                if prec1 > best_prec:
                    best_prec = prec1
                    torch.save(
                        model.state_dict(),
                        os.path.join(save_dir,
                                     'checkpoint_{}_{}.pth'.format(ratio, i)))
            print("best accuracy:", best_prec)
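The StepLR schedule above halves the learning rate every 30 epochs. A standalone sketch (not part of the original script) to inspect the resulting schedule:

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.5)
for epoch in range(90):
    opt.step()
    sched.step()
    if epoch % 30 == 29:
        print(epoch + 1, opt.param_groups[0]['lr'])  # 0.005, 0.0025, 0.00125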
Example #5
import torch
from torch import nn
from collections import OrderedDict
from model.se_resnet import se_resnet50


class Equal(nn.Module):
    """Identity module, e.g. for swapping a layer out for a no-op."""

    def __init__(self):
        super(Equal, self).__init__()

    def forward(self, x):
        return x


num_class = 20
model = se_resnet50(num_classes=1000)

# load the checkpoint once and pull out the weights
checkpoint = torch.load("/home/lxt/Github/games/model/pretrained/weight-99.pkl")
state_dict = checkpoint["weight"]

new_state_dict = OrderedDict()

for k, v in state_dict.items():
    print(k, v.size())
    name = k[7:]  # strip the "module." prefix added by nn.DataParallel
    if name.startswith("fc"):
        # skip the classifier head: the checkpoint has 1000 classes, while
        # this model will be given num_class = 20
        continue
    new_state_dict[name] = v

for k, v in new_state_dict.items():
    print(k, v.size())
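Once remapped, the weights can be loaded into the model; a sketch of the remaining steps (strict=False tolerates the skipped head, which is then replaced to output num_class logits):

model.load_state_dict(new_state_dict, strict=False)
model.fc = nn.Linear(model.fc.in_features, num_class)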