def main():
    """Predict the class of a single image with a trained scene-classification CNN.

    Reads the image path from the global ``args``, infers the dataset and the
    ground-truth class name from the path layout ``.../<dataset>/<class>/<file>``,
    builds the requested network, loads the trained weights, and prints the
    predicted class.
    """
    # path for image
    img_path = args.img_path
    if img_path is None:
        # Fix: the original only printed a warning and then crashed on
        # ``None.split(...)``; bail out instead.
        print("you haven't choose any image for prediction!")
        return
    # Dataset folder and ground-truth class folder are encoded in the path:
    # .../<dataset>/<class_name>/<file>
    data_use = img_path.split('/')[-3]
    class_name = img_path.split('/')[-2]
    if data_use == '30class_rgb':
        class_names = aid_class_names
    elif data_use == '45class_rgb':
        class_names = nwpu_class_names
    else:
        # Fix: the original fell through with `class_names` undefined and
        # died later with a NameError; fail with a clear message instead.
        raise ValueError("unknown dataset folder: {}".format(data_use))

    # choose cnn for prediction
    if args.net == 1:
        model = bcnn_vgg.BCNN(class_num=len(class_names), pretrained=None)
        print("Using model bcnn_vgg for prediction.")
    elif args.net == 2:
        # pretrained model needs
        model = se_resnet.se_resnet50(pretrained=None)
        model.fc = nn.Linear(2048, len(class_names))
        print("Using model se_resnet for prediction.")
    else:
        raise ValueError("unknown net: {}".format(args.net))

    # Load the trained weights.
    # Fix: the original tested args.pretrained but loaded args.resume.
    if args.pretrained is not None:
        print("=> loading pretrained model '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint)
    # Fix: switch to inference mode (freezes BatchNorm stats / dropout).
    model.eval()

    # Load the image; OpenCV returns an H x W x C array in BGR channel order.
    img = cv2.imread(args.img_path)
    # Image preprocessing pipeline (resize -> normalize -> tensor).
    img_transform = transforms.Compose([
        MyAugmentations.Resize(224),
        MyAugmentations.Normalize(mean=dataset_mean, std=dataset_std),
        MyAugmentations.ToTensor(),
    ])
    # BGR -> RGB. Fix: the original indexed the FIRST (row) axis with
    # (2, 1, 0), which reorders image rows, not the color channels.
    img = img[:, :, ::-1]
    net_input = img_transform(img)

    # Forward pass; no gradients are needed at inference time.
    with torch.no_grad():
        output = model(net_input.unsqueeze(0))
    pred_idx = output.argmax(dim=1).item()
    print("图像类别为:" + class_name)
    print("图像预测类别为:" + class_names[pred_idx])
def train(args):
    """Train a backbone + ``clssimp`` head for multi-label classification.

    Builds the train/val transform pipelines, assembles the requested dataset
    (optionally concatenated with stylized / DeepAugment copies and wrapped in
    AugMix), constructs the chosen architecture, then runs a BCE training loop
    and checkpoints the weights after every epoch.
    """
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    # ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.augmix:
        # NOTE(review): no ToTensor/normalize in this branch — presumably the
        # AugMix wrapper applies them itself; confirm against AugMix.
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop((args.img_size), scale=(0.5, 2.0)),
        ])
    elif args.speckle:
        # Inject speckle noise with probability 0.5 after tensor conversion.
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop((args.img_size), scale=(0.5, 2.0)),
            transforms.ToTensor(),
            transforms.RandomApply(
                [transforms.Lambda(lambda x: speckle_noise_torch(x))], p=0.5),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop((args.img_size), scale=(0.5, 2.0)),
            transforms.ToTensor(),
            normalize,
        ])
    if args.cutout:
        # Cutout-style augmentation appended to whichever pipeline was built.
        train_transform.transforms.append(transforms.RandomErasing())
    # NOTE(review): transforms.Scale is the deprecated alias of Resize
    # (removed in recent torchvision) — pinned torchvision assumed.
    val_transform = transforms.Compose([
        transforms.Scale((args.img_size, args.img_size)),
        transforms.ToTensor(),
        normalize,
    ])
    label_transform = transforms.Compose([
        ToLabel(),
    ])
    print("Loading Data")
    if args.dataset == "deepfashion2":
        loader = fashion2loader(
            "../",
            transform=train_transform,
            label_transform=label_transform,
            #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
            #load=True,
        )
        if args.augmix:
            # Wrap the dataset so each item yields (augmented, original) pairs.
            loader = AugMix(loader, args.augmix)
        if args.stylize:
            # Add a stylized copy of the dataset on top of the original.
            style_loader = fashion2loader(
                root="../../stylize-datasets/output/",
                transform=train_transform,
                label_transform=label_transform,
                #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
                #load=True,
            )
            loader = torch.utils.data.ConcatDataset([loader, style_loader])
        valloader = fashion2loader(
            "../",
            split="validation",
            transform=val_transform,
            label_transform=label_transform,
            #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )
    elif args.dataset == "deepaugment":
        # Original data plus two DeepAugment-distorted copies (EDSR / CAE).
        loader = fashion2loader(
            "../",
            transform=train_transform,
            label_transform=label_transform,
            #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
            #load=True,
        )
        loader1 = fashion2loader(
            root="../../deepaugment/EDSR/",
            transform=train_transform,
            label_transform=label_transform,
            #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
            #load=True,
        )
        loader2 = fashion2loader(
            root="../../deepaugment/CAE/",
            transform=train_transform,
            label_transform=label_transform,
            #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
            #load=True,
        )
        loader = torch.utils.data.ConcatDataset([loader, loader1, loader2])
        if args.augmix:
            loader = AugMix(loader, args.augmix)
        if args.stylize:
            style_loader = fashion2loader(
                root="../../stylize-datasets/output/",
                transform=train_transform,
                label_transform=label_transform,
                #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
                #load=True,
            )
            loader = torch.utils.data.ConcatDataset([loader, style_loader])
        valloader = fashion2loader(
            "../",
            split="validation",
            transform=val_transform,
            label_transform=label_transform,
            #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=(True,True,True,True),
            scales=args.scales,
            occlusion=args.occlusion,
            zoom=args.zoom,
            viewpoint=args.viewpoint,
            negate=args.negate,
        )
    else:
        raise AssertionError
    print("Loading Done")
    n_classes = args.num_classes
    train_loader = data.DataLoader(loader,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=True,
                                   shuffle=True)
    # NOTE(review): len(train_loader) counts batches, not images.
    print("number of images = ", len(train_loader))
    print("number of classes = ", n_classes)
    print("Loading arch = ", args.arch)
    # Backbone = children()[0:8] (everything up to the last conv stage, i.e.
    # avgpool/fc are dropped); classification is done by the `clssimp` head.
    if args.arch == "resnet101":
        orig_resnet = torchvision.models.resnet101(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "resnet50":
        orig_resnet = torchvision.models.resnet50(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "resnet152":
        orig_resnet = torchvision.models.resnet152(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "se":
        model = se_resnet50(pretrained=True)
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "BiT-M-R50x1":
        # Big Transfer model; weights come from a local .npz snapshot.
        model = bit_models.KNOWN_MODELS[args.arch](head_size=2048,
                                                   zero_head=True)
        model.load_from(np.load(f"{args.arch}.npz"))
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "BiT-M-R101x1":
        model = bit_models.KNOWN_MODELS[args.arch](head_size=2048,
                                                   zero_head=True)
        model.load_from(np.load(f"{args.arch}.npz"))
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    if args.load == 1:
        # Resume from previously saved backbone/head weights.
        model.load_state_dict(
            torch.load(args.save_dir + args.arch + str(args.disc) + ".pth"))
        clsfier.load_state_dict(
            torch.load(args.save_dir + args.arch + "clssegsimp" +
                       str(args.disc) + ".pth"))
    gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_dataparallel = len(gpu_ids) > 1
    print("using data parallel = ", use_dataparallel, device, gpu_ids)
    if use_dataparallel:
        # Re-map to local ordinals 0..N-1 (CUDA_VISIBLE_DEVICES already
        # restricts which physical GPUs are visible to this process).
        gpu_ids = [int(x) for x in range(len(gpu_ids))]
        model = nn.DataParallel(model, device_ids=gpu_ids)
        clsfier = nn.DataParallel(clsfier, device_ids=gpu_ids)
    model.to(device)
    clsfier.to(device)
    if args.finetune:
        # Head-only training: backbone parameters are not optimized.
        if args.opt == "adam":
            optimizer = torch.optim.Adam([{
                'params': clsfier.parameters()
            }], lr=args.lr)
        else:
            optimizer = torch.optim.SGD(clsfier.parameters(),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=True)
    else:
        if args.opt == "adam":
            # Backbone trains at lr/10, head at the full lr.
            optimizer = torch.optim.Adam([{
                'params': model.parameters(),
                'lr': args.lr / 10
            }, {
                'params': clsfier.parameters()
            }], lr=args.lr)
        else:
            optimizer = torch.optim.SGD(itertools.chain(
                model.parameters(), clsfier.parameters()),
                                        args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        # Cosine decay from lr_max to lr_min over total_steps.
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    if args.use_scheduler:
        # Per-iteration cosine schedule over the whole run.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: cosine_annealing(
                step,
                args.n_epochs * len(train_loader),
                1,  # since lr_lambda computes multiplicative factor
                1e-6 / (args.lr * args.batch_size / 256.)))
    bceloss = nn.BCEWithLogitsLoss()
    for epoch in range(args.n_epochs):
        for i, (images, labels) in enumerate(tqdm(train_loader)):
            if args.augmix:
                # AugMix loader yields (augmented, original) image pairs;
                # run both through the network as one concatenated batch.
                x_mix1, x_orig = images
                images = torch.cat((x_mix1, x_orig), 0).to(device)
            else:
                images = images[0].to(device)
            labels = labels.to(device).float()
            optimizer.zero_grad()
            outputs = model(images)
            outputs = clsfier(outputs)
            if args.augmix:
                # Split the concatenated logits back into mixed / clean halves.
                l_mix1, outputs = torch.split(outputs, x_orig.size(0))
            if args.loss == "bce":
                if args.augmix:
                    # Randomly (50/50) train on the clean or the mixed logits.
                    if random.random() > 0.5:
                        loss = bceloss(outputs, labels)
                    else:
                        loss = bceloss(l_mix1, labels)
                else:
                    loss = bceloss(outputs, labels)
            else:
                print("Invalid loss please use --loss bce")
                exit()
            loss.backward()
            optimizer.step()
            if args.use_scheduler:
                scheduler.step()
        print(len(train_loader))
        # Reports the loss of the LAST batch of the epoch only.
        print("Epoch [%d/%d] Loss: %.4f" % (epoch + 1, args.n_epochs,
                                            loss.data))
        # Checkpoint after every epoch (overwrites the previous one).
        save_root = os.path.join(args.save_dir, args.arch)
        if not os.path.exists(save_root):
            os.makedirs(save_root)
        if use_dataparallel:
            # Unwrap DataParallel so checkpoint keys have no "module." prefix.
            torch.save(model.module.state_dict(),
                       os.path.join(save_root, str(args.disc) + ".pth"))
            torch.save(
                clsfier.module.state_dict(),
                os.path.join(save_root,
                             "clssegsimp" + str(args.disc) + ".pth"))
        else:
            torch.save(model.state_dict(),
                       os.path.join(save_root, str(args.disc) + ".pth"))
            torch.save(
                clsfier.state_dict(),
                os.path.join(save_root,
                             'clssegsimp' + str(args.disc) + ".pth"))
def validate(args):
    """Evaluate a saved backbone + ``clssimp`` head on DeepFashion2.

    Loads the checkpoint written by train(), collects per-class ground truths
    and sigmoid scores over the validation set, computes per-class average
    precision (area under the precision-recall curve), and prints both the
    per-class APs and a label-frequency-weighted sum of them.
    """
    # Setup Dataloader
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # NOTE(review): transforms.Scale is the deprecated alias of Resize.
    val_transform = transforms.Compose([
        transforms.Scale((args.img_size, args.img_size)),
        transforms.ToTensor(),
        normalize,
    ])
    label_transform = transforms.Compose([
        ToLabel(),
        # normalize,
    ])
    if args.dataset == "deepfashion2":
        if not args.concat_data:
            valloader = fashion2loader(
                "../",
                split="validation",
                transform=val_transform,
                label_transform=label_transform,
                #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1),
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
            )
        else:
            # lets concat train and val for appropriate labels
            loader1 = fashion2loader(
                "../",
                transform=val_transform,
                label_transform=label_transform,
                #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=True,
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
                #load=True,
            )
            loader2 = fashion2loader(
                "../",
                split="validation",
                transform=val_transform,
                label_transform=label_transform,
                #scales=(-1), occlusion=(-1), zoom=(-1), viewpoint=(-1), negate=True,
                scales=args.scales,
                occlusion=args.occlusion,
                zoom=args.zoom,
                viewpoint=args.viewpoint,
                negate=args.negate,
            )
            valloader = torch.utils.data.ConcatDataset([loader1, loader2])
    else:
        raise AssertionError
    n_classes = args.num_classes
    valloader = data.DataLoader(valloader,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False)
    # NOTE(review): len(valloader) counts batches here, not samples.
    print("Number of samples = ", len(valloader))
    print("Loading arch = ", args.arch)
    # Same backbone construction as train(): children()[0:8] keeps the conv
    # stages; the `clssimp` head does the classification.
    if args.arch == 'resnet101':
        orig_resnet = torchvision.models.resnet101(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == 'resnet50':
        orig_resnet = torchvision.models.resnet50(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == 'resnet152':
        orig_resnet = torchvision.models.resnet152(pretrained=True)
        features = list(orig_resnet.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == 'se':
        model = se_resnet50(pretrained=True)
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "BiT-M-R50x1":
        model = bit_models.KNOWN_MODELS[args.arch](head_size=2048,
                                                   zero_head=True)
        model.load_from(np.load(f"{args.arch}.npz"))
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    elif args.arch == "BiT-M-R101x1":
        model = bit_models.KNOWN_MODELS[args.arch](head_size=2048,
                                                   zero_head=True)
        model.load_from(np.load(f"{args.arch}.npz"))
        features = list(model.children())
        model = nn.Sequential(*features[0:8])
        clsfier = clssimp(2048, n_classes)
    # Load the weights saved by train() (note the "/" separator here, unlike
    # the concatenated path used by train()'s --load branch).
    model.load_state_dict(
        torch.load(args.save_dir + args.arch + "/" + str(args.disc) + ".pth"))
    clsfier.load_state_dict(
        torch.load(args.save_dir + args.arch + "/" + 'clssegsimp' +
                   str(args.disc) + ".pth"))
    model.eval()
    clsfier.eval()
    if torch.cuda.is_available():
        model.cuda(0)
        clsfier.cuda(0)
    model.eval()
    # Per-class ground truths and sigmoid scores, accumulated over batches.
    gts = {i: [] for i in range(0, n_classes)}
    preds = {i: [] for i in range(0, n_classes)}
    # gts, preds = [], []
    # NOTE(review): loop is not wrapped in torch.no_grad(), so autograd
    # buffers are kept during evaluation — consider adding it; verify first.
    for i, (images, labels) in tqdm(enumerate(valloader)):
        images = images[0].cuda()
        labels = labels.cuda().float()
        outputs = model(images)
        outputs = clsfier(outputs)
        # Multi-label scores in [0, 1] (F.sigmoid is the deprecated spelling
        # of torch.sigmoid).
        outputs = F.sigmoid(outputs)
        pred = outputs.data.cpu().numpy()
        gt = labels.data.cpu().numpy()
        for label in range(0, n_classes):
            gts[label].extend(gt[:, label])
            preds[label].extend(pred[:, label])
    # Average precision per class = area under the precision-recall curve.
    FinalMAPs = []
    for i in range(0, n_classes):
        precision, recall, thresholds = metrics.precision_recall_curve(
            gts[i], preds[i])
        FinalMAPs.append(metrics.auc(recall, precision))
    print(FinalMAPs)
    # Weight each class AP by its share of positive labels and print the
    # weighted sum; NaN APs are mapped to 0 by nan_to_num.
    tmp = []
    for i in range(len(gts)):
        tmp.append(gts[i])
    gts = np.array(tmp)
    FinalMAPs = np.array(FinalMAPs)
    denom = gts.sum()
    gts = gts.sum(axis=-1)
    gts = gts / denom
    res = np.nan_to_num(FinalMAPs * gts)
    print((res).sum())
def main():
    """Train the selected network on each (dataset, ratio) pair, five runs each.

    For every run: build fresh data loaders from the per-run split files,
    train for args.epochs epochs, log metrics to TensorBoard and a text log,
    and keep the checkpoint with the best validation accuracy.
    """
    # For each dataset and ratio in data_use_ratio, train five times
    for i in range(1, 6):
        for data_use, ratio in data_use_ratio:
            if data_use == 'aid':
                class_names = aid_class_names
                data_path = "../MINN/datasets/30class_rgb/"
            elif data_use == 'nwpu':
                class_names = nwpu_class_names
                data_path = "../MINN/datasets/45class_rgb/"
            elif data_use == 'ucm':
                # NOTE(review): no data_path is set for 'ucm', so a stale
                # path from a previous iteration would be reused — TODO: add
                # the UCM dataset directory.
                class_names = ucm_class_names
            else:
                print('Please choose datasets for training use!')
                # Fix: skip unknown datasets instead of falling through with
                # undefined class_names/data_path (NameError in the original).
                continue
            # Dir to save log file and model parameters
            logdir = log_dir + '/log_' + data_use
            save_dir = log_dir + '/save_' + data_use
            if not os.path.exists(logdir):
                os.makedirs(logdir)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            TFwriter = SummaryWriter(
                logdir)  # Save loss and acc for Visualization
            log_file = open(
                logdir + '/log_' + data_use + ratio + '_' + str(i) + '.txt',
                'w')
            # Record the run configuration both on screen and in the log.
            log_file.write('datasets:' + data_use)
            print('datasets:' + data_use)
            log_file.write('\nratio:' + ratio)
            print('ratio:' + ratio)
            log_file.write('\nepochs:' + str(args.epochs))
            print('epochs:' + str(args.epochs))
            log_file.write('\nlearning rate:' + str(args.lr))
            print('learning rate:' + str(args.lr))
            log_file.write('\nbatch size:' + str(args.batch_size))
            print('batch size:' + str(args.batch_size))
            if args.net == 1:
                model = bcnn_vgg.BCNN(class_num=len(class_names),
                                      pretrained=None)
                print("Using model bcnn_vgg for traning.")
                log_file.write("\nUsing model bcnn_vgg for traning.")
            elif args.net == 2:
                # pretrained model needs
                model = se_resnet.se_resnet50(pretrained=None)
                model.fc = nn.Linear(2048, len(class_names))
                print("Using model se_resnet for traning.")
                log_file.write("\nUsing model se_resnet for traning.")
            # Using GPUs and cuda to accelerate training
            if torch.cuda.is_available():
                model.cuda()
                cudnn.benchmark = True
            # continue training from breaking
            if args.resume is not None:
                print("=> loading pretrained model '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                model.load_state_dict(checkpoint)
            # Loading RS dataset: per-run split files live under dir_file/.
            print("Loading dataset...")
            train_file = 'dir_file/' + data_use + '_train' + ratio + '_' + str(
                i) + '.txt'
            test_file = 'dir_file/' + data_use + '_test' + ratio + '_' + str(
                i) + '.txt'
            train_loader = torch.utils.data.DataLoader(
                CLSDataPrepare(root=data_path,
                               txt_path=train_file,
                               img_transform=MyAugmentations.TrainAugmentation(
                                   size=224, _mean=dataset_mean,
                                   _std=dataset_std)),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                collate_fn=classifier_collate)
            val_loader = torch.utils.data.DataLoader(
                CLSDataPrepare(root=data_path,
                               txt_path=test_file,
                               img_transform=MyAugmentations.TestAugmentation(
                                   size=224, _mean=dataset_mean,
                                   _std=dataset_std)),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                collate_fn=classifier_collate)
            # define loss function (criterion), optimizer and LR decay step
            criterion = nn.CrossEntropyLoss().cuda()
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
            # Training and Validation, save the best model
            best_prec = 0
            print("Start training...")
            try:
                for epoch in range(args.start_epoch, args.epochs):
                    # train for one epoch
                    start = time.time()
                    train(train_loader, model, criterion, optimizer, epoch,
                          TFwriter)
                    # validate model
                    prec1, test_loss = validate(val_loader, model, criterion,
                                                len(class_names))
                    # Fix: step the scheduler AFTER the optimizer has stepped
                    # for the epoch; the original pre-epoch call uses the
                    # deprecated pre-1.1 PyTorch ordering and shifts the
                    # decay schedule by one epoch.
                    scheduler.step()
                    end = time.time()
                    print("time for one epoch:%.2fmin" % ((end - start) / 60))
                    # OA, Kappa, class_specific_PA, class_specific_UA = get_OAKappa_by_conf(Confusion_Matrix)
                    TFwriter.add_scalar('#test_loss', test_loss, epoch)
                    TFwriter.add_scalar('#accuracy', prec1, epoch)
                    print('after %d epochs,accuracy = %f, test_loss = %f' %
                          (epoch, prec1, test_loss))
                    message = '\nafter {} epochs,accuracy = {:.2f}, test_loss = {:.8f}'.format(
                        epoch, prec1, test_loss)
                    log_file.write(message)
                    # remember best prec@1 and save checkpoint
                    if prec1 > best_prec:
                        best_prec = prec1
                        torch.save(
                            model.state_dict(),
                            os.path.join(
                                save_dir,
                                'checkpoint_{}_{}.pth'.format(ratio, i)))
                print(best_prec)
            finally:
                # Fix: the log file and SummaryWriter were opened every run
                # but never closed (resource leak across the 5xN runs).
                log_file.close()
                TFwriter.close()
from torch import nn from collections import OrderedDict from model.se_resnet import se_resnet50 from torch.autograd import Variable class Equal(nn.Module): def __init__(self, x): self.x = x def forward(self, x): return x num_class = 20 model = se_resnet50(num_classes=1000) data = torch.load("/home/lxt/Github/games/model/pretrained/weight-99.pkl") state_dict = torch.load( "/home/lxt/Github/games/model/pretrained/weight-99.pkl")["weight"] new_state_dict = OrderedDict() for k, v in state_dict.items(): print(k, v.size()) name = k[7:] new_state_dict[name] = v if name == "fc": break for k, v in new_state_dict.items():