def main():
    """Train (and optionally test) a classifier on SEN12MS or BigEarthNet.

    Everything is driven by the module-level ``args`` namespace:

    1. write the arguments to ``<logs_dir>/<run name>_arguments.txt``;
    2. build the train/val (and, with ``args.eval``, test) datasets/loaders;
    3. build the model selected by ``args.model`` (Moco* variants are
       initialised from ``<args.pt_dir>/<args.pt_name>.pth``);
    4. optionally resume model/optimizer state from ``args.resume``;
    5. for every epoch: train, validate, save a checkpoint, log to wandb;
    6. with ``args.eval``, run ``eval`` on the test loader.

    Side effect: sets the global ``sv_name_eval`` to the timestamp run name.

    Raises:
        NameError: unknown ``args.dataset`` or ``args.model``.
        FileNotFoundError: a Moco* model is requested and the pre-trained
            weight file does not exist.
    """
    global args
    global sv_name_eval

    # save configuration to file
    sv_name = datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S')
    sv_name_eval = sv_name
    print('saving file name is ', sv_name)
    write_arguments_to_file(args, os.path.join(logs_dir, sv_name + '_arguments.txt'))

    # ----------------------------------- data
    label_type = args.label_type
    use_s1 = args.sensor_type in ('s1', 's1s2')
    use_s2 = args.sensor_type in ('s2', 's1s2')
    dataset = args.dataset
    data_dir = os.path.join("data", dataset, "data")

    print(f"Using {dataset} dataset")
    # mean/std of the training set (for data normalization)
    if dataset == 'sen12ms':
        bands_mean = {
            's1_mean': [-11.76858, -18.294598],
            's2_mean': [
                1226.4215, 1137.3799, 1139.6792, 1350.9973, 1932.9058,
                2211.1584, 2154.9846, 2409.1128, 2001.8622, 1356.0801
            ]
        }
        bands_std = {
            's1_std': [4.525339, 4.3586307],
            's2_std': [
                741.6254, 740.883, 960.1045, 946.76056, 985.52747,
                1082.4341, 1057.7628, 1136.1942, 1132.7898, 991.48016
            ]
        }
    elif dataset == 'bigearthnet':
        # THE S2 BAND STATISTICS WERE PROVIDED BY THE BIGEARTHNET TEAM
        # Source: https://git.tu-berlin.de/rsim/bigearthnet-models-tf/-/blob/master/BigEarthNet.py
        bands_mean = {
            's1_mean': [-12.619993, -19.290445],
            's2_mean': [
                340.76769064, 429.9430203, 614.21682446, 590.23569706,
                950.68368468, 1792.46290469, 2075.46795189, 2218.94553375,
                2266.46036911, 2246.0605464, 1594.42694882, 1009.32729131
            ]
        }
        bands_std = {
            's1_std': [5.115911, 5.464428],
            's2_std': [
                554.81258967, 572.41639287, 582.87945694, 675.88746967,
                729.89827633, 1096.01480586, 1273.45393088, 1365.45589904,
                1356.13789355, 1302.3292881, 1079.19066363, 818.86747235
            ]
        }
    else:
        raise NameError(f"unknown dataset: {dataset}")

    # load datasets
    imgTransform = transforms.Compose(
        [ToTensor(), Normalize(bands_mean, bands_std)])

    # Keyword arguments shared by every split. The two dataset classes differ
    # only in the name of the "simplified label scheme" flag.
    split_kwargs = dict(imgTransform=imgTransform,
                        label_type=label_type,
                        threshold=args.threshold,
                        use_s1=use_s1,
                        use_s2=use_s2,
                        use_RGB=args.use_RGB,
                        sensor_type=args.sensor_type,
                        use_fusion=args.use_fusion)
    if dataset == 'sen12ms':
        dataset_cls = SEN12MS
        split_kwargs['IGBP_s'] = args.simple_scheme
    else:  # bigearthnet (validated above)
        dataset_cls = BigEarthNet
        split_kwargs['CLC_s'] = args.simple_scheme

    # ``data_size`` is only passed for train/val; the test split is built
    # without it (presumably so it is always evaluated in full).
    train_dataGen = dataset_cls(data_dir, args.label_split_dir, subset="train",
                                data_size=args.data_size, **split_kwargs)
    val_dataGen = dataset_cls(data_dir, args.label_split_dir, subset="val",
                              data_size=args.data_size, **split_kwargs)
    test_dataGen = None
    if args.eval:
        test_dataGen = dataset_cls(data_dir, args.label_split_dir,
                                   subset="test", **split_kwargs)

    # number of input channels
    n_inputs = train_dataGen.n_inputs
    print('input channels =', n_inputs)
    wandb.config.update({"input_channels": n_inputs})

    # set up dataloaders
    train_data_loader = DataLoader(train_dataGen,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   shuffle=True,
                                   pin_memory=True)
    val_data_loader = DataLoader(val_dataGen,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 pin_memory=True)
    test_data_loader = None
    if args.eval:
        test_data_loader = DataLoader(test_dataGen,
                                      batch_size=args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False,
                                      pin_memory=True)

    # -------------------------------- ML setup
    # cuda
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.backends.cudnn.enabled = True
        cudnn.benchmark = True

    # define number of classes
    if dataset == 'sen12ms':
        numCls = 10 if args.simple_scheme else 17
    else:
        numCls = 19 if args.simple_scheme else 43
    # class names for the classification report: '1' .. str(numCls)
    ORG_LABELS = [str(i) for i in range(1, numCls + 1)]

    print('num_class: ', numCls)
    wandb.config.update({"n_class": numCls})

    # define model
    if args.model == 'VGG16':
        model = VGG16(n_inputs, numCls)
    elif args.model == 'VGG19':
        model = VGG19(n_inputs, numCls)
    elif args.model == 'Supervised':
        model = ResNet50(n_inputs, numCls)
    elif args.model == 'Supervised_1x1':
        model = ResNet50_1x1(n_inputs, numCls)
    elif args.model == 'ResNet101':
        model = ResNet101(n_inputs, numCls)
    elif args.model == 'ResNet152':
        model = ResNet152(n_inputs, numCls)
    elif args.model == 'DenseNet121':
        model = DenseNet121(n_inputs, numCls)
    elif args.model == 'DenseNet161':
        model = DenseNet161(n_inputs, numCls)
    elif args.model == 'DenseNet169':
        model = DenseNet169(n_inputs, numCls)
    elif args.model == 'DenseNet201':
        model = DenseNet201(n_inputs, numCls)
    # finetune moco pre-trained model
    elif args.model.startswith("Moco"):
        pt_path = os.path.join(args.pt_dir, f"{args.pt_name}.pth")
        print(pt_path)
        # explicit check instead of `assert`, which is stripped under `python -O`
        if not os.path.exists(pt_path):
            raise FileNotFoundError(f"pre-trained weights not found: {pt_path}")
        if args.model == 'Moco':
            print("transfer backbone weights but no conv 1x1 input module")
            model = Moco(torch.load(pt_path), n_inputs, numCls)
        elif args.model == 'Moco_1x1':
            print("transfer backbone weights and input module weights")
            model = Moco_1x1(torch.load(pt_path), n_inputs, numCls)
        elif args.model == 'Moco_1x1RND':
            print(
                "transfer backbone weights but initialize input module random with random weights"
            )
            # NOTE(review): this builds exactly the same module as 'Moco_1x1';
            # the random re-initialisation announced above is not visible
            # here - confirm where (or whether) it happens.
            model = Moco_1x1(torch.load(pt_path), n_inputs, numCls)
        else:  # Assume Moco2 at present
            raise NameError("no model")
    else:
        raise NameError("no model")

    print(model)

    # move model to GPU if is available
    if use_cuda:
        model = model.cuda()

    # define loss function
    if label_type == 'multi_label':
        lossfunc = torch.nn.BCEWithLogitsLoss()
    else:
        lossfunc = torch.nn.CrossEntropyLoss()

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)

    best_acc = 0
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # keep writing checkpoints under the resumed run's name, i.e. the
            # first two '_'-separated fields of the checkpoint file name
            checkpoint_nm = os.path.basename(args.resume)
            sv_name = '_'.join(checkpoint_nm.split('_')[:2])
            print('saving file name is ', sv_name)
            # NOTE(review): sv_name_eval still holds the fresh timestamp, not
            # the resumed run name - confirm which one `eval` should use.

            # The saved 'epoch' has already been trained, so continue with the
            # next one.
            start_epoch = checkpoint['epoch'] + 1
            best_acc = checkpoint['best_prec']
            # The training loop below saves 'model_state_dict' /
            # 'optimizer_state_dict'; fall back to the legacy key names.
            model_key = 'model_state_dict' if 'model_state_dict' in checkpoint else 'state_dict'
            optim_key = 'optimizer_state_dict' if 'optimizer_state_dict' in checkpoint else 'optimizer'
            model.load_state_dict(checkpoint[model_key])
            optimizer.load_state_dict(checkpoint[optim_key])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # ----------------------------- executing Train/Val.
    scheduler = None
    if args.use_lr_step:
        # Multiply the lr by lr_step_gamma every lr_step_size epochs.
        # Ex: if the initial lr is 0.0001, step size is 25 and gamma is 0.1:
        # 0.0001 - first 25 epochs
        # 0.00001 - 25 to 50 epochs
        # 0.000001 - 50 to 75 epochs
        # 0.0000001 - 75 to 100 epochs
        # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        # NOTE(review): on resume the step counter restarts at start_epoch; the
        # lr itself is restored from the optimizer state.
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.lr_step_size,
                                              gamma=args.lr_step_gamma)

    for epoch in range(start_epoch, args.epochs):
        if scheduler is not None:
            print('Epoch {}/{} lr: {}'.format(epoch, args.epochs - 1,
                                              optimizer.param_groups[0]['lr']))
        else:
            print('Epoch {}/{}'.format(epoch, args.epochs - 1))
        print('-' * 25)

        train(train_data_loader, model, optimizer, lossfunc, label_type, epoch, use_cuda)
        micro_f1 = val(val_data_loader, model, optimizer, label_type, epoch, use_cuda)

        if scheduler is not None:
            # Step only after this epoch's optimizer updates (PyTorch >= 1.1);
            # stepping first skips the initial lr. Stepping before the save
            # means the checkpoint holds the lr for the next epoch.
            scheduler.step()

        is_best_acc = micro_f1 > best_acc
        best_acc = max(best_acc, micro_f1)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.model,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_prec': best_acc
            }, is_best_acc, sv_name)
        wandb.log({'epoch': epoch, 'micro_f1': micro_f1})

    print("=============")
    print("done training")
    print("=============")

    if args.eval:
        eval(test_data_loader, model, label_type, numCls, use_cuda, ORG_LABELS)
def choose_nets(nets_name, num_classes=100):
    """Instantiate the network called *nets_name* (case-insensitive).

    The constructor is imported lazily from the project's ``models`` package,
    so only the selected architecture's module is loaded.

    Args:
        nets_name: architecture name, e.g. ``'vgg16'``, ``'ResNet50'`` or
            ``'inception-resnet-v2'``; compared after lower-casing.
        num_classes: number of output classes passed to the constructor.

    Returns:
        The model built by ``<constructor>(num_classes)``.

    Raises:
        NotImplementedError: *nets_name* is not a known architecture.
    """
    import importlib

    # lower-cased name -> (module, constructor attribute). All keys must be
    # lower case, because the lookup key is lower-cased.
    registry = {
        'vgg11': ('models.VGG', 'VGG11'),
        'vgg13': ('models.VGG', 'VGG13'),
        'vgg16': ('models.VGG', 'VGG16'),
        'vgg19': ('models.VGG', 'VGG19'),
        'resnet18': ('models.ResNet', 'ResNet18'),
        'resnet34': ('models.ResNet', 'ResNet34'),
        'resnet50': ('models.ResNet', 'ResNet50'),
        'resnet101': ('models.ResNet', 'ResNet101'),
        'resnet152': ('models.ResNet', 'ResNet152'),
        'googlenet': ('models.GoogLeNet', 'GoogLeNet'),
        'inceptionv3': ('models.InceptionV3', 'inceptionv3'),
        'mobilenet': ('models.MobileNet', 'mobilenet'),
        'mobilenetv2': ('models.MobileNetV2', 'mobilenetv2'),
        'seresnet18': ('models.SEResNet', 'seresnet18'),
        'seresnet34': ('models.SEResNet', 'seresnet34'),
        'seresnet50': ('models.SEResNet', 'seresnet50'),
        'seresnet101': ('models.SEResNet', 'seresnet101'),
        'seresnet152': ('models.SEResNet', 'seresnet152'),
        'densenet121': ('models.DenseNet', 'densenet121'),
        'densenet161': ('models.DenseNet', 'densenet161'),
        'densenet169': ('models.DenseNet', 'densenet169'),
        'densenet201': ('models.DenseNet', 'densenet201'),
        'squeezenet': ('models.SqueezeNet', 'squeezenet'),
        'inceptionv4': ('models.InceptionV4', 'inceptionv4'),
        'inception-resnet-v2': ('models.InceptionV4', 'inception_resnet_v2'),
    }

    try:
        module_name, ctor_name = registry[nets_name.lower()]
    except KeyError:
        raise NotImplementedError(nets_name) from None

    constructor = getattr(importlib.import_module(module_name), ctor_name)
    return constructor(num_classes)
def main():
    """Train a SEN12MS classifier, logging to TensorBoard.

    Driven by the module-level ``args`` namespace:

    1. write the arguments to ``<logs_dir>/<run name>_arguments.txt``;
    2. build the SEN12MS train/val datasets and loaders;
    3. build the model selected by ``args.model`` and an Adam optimizer;
    4. optionally resume model/optimizer state from ``args.resume``;
    5. for every epoch: train, validate, save a checkpoint.

    TensorBoard events go to ``<logs_dir>/runs/<run name>/{training,val}``.

    Raises:
        NameError: unknown ``args.model``.
    """
    global args

    # save configuration to file
    sv_name = datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S')
    print('saving file name is ', sv_name)
    write_arguments_to_file(args, os.path.join(logs_dir, sv_name + '_arguments.txt'))

    # ----------------------------------- data
    # define mean/std of the training set (for data normalization)
    label_type = args.label_type
    bands_mean = {'s1_mean': [-11.76858, -18.294598],
                  's2_mean': [1226.4215, 1137.3799, 1139.6792, 1350.9973, 1932.9058,
                              2211.1584, 2154.9846, 2409.1128, 2001.8622, 1356.0801]}
    bands_std = {'s1_std': [4.525339, 4.3586307],
                 's2_std': [741.6254, 740.883, 960.1045, 946.76056, 985.52747,
                            1082.4341, 1057.7628, 1136.1942, 1132.7898, 991.48016]}

    # load datasets (train and val differ only in `subset`)
    imgTransform = transforms.Compose([ToTensor(), Normalize(bands_mean, bands_std)])
    split_kwargs = dict(imgTransform=imgTransform, label_type=label_type,
                        threshold=args.threshold, use_s1=args.use_s1,
                        use_s2=args.use_s2, use_RGB=args.use_RGB,
                        IGBP_s=args.IGBP_simple)
    train_dataGen = SEN12MS(args.data_dir, args.label_split_dir, subset="train", **split_kwargs)
    val_dataGen = SEN12MS(args.data_dir, args.label_split_dir, subset="val", **split_kwargs)

    # number of input channels
    n_inputs = train_dataGen.n_inputs
    print('input channels =', n_inputs)

    # set up dataloaders
    train_data_loader = DataLoader(train_dataGen, batch_size=args.batch_size,
                                   num_workers=args.num_workers, shuffle=True,
                                   pin_memory=True)
    val_data_loader = DataLoader(val_dataGen, batch_size=args.batch_size,
                                 num_workers=args.num_workers, shuffle=False,
                                 pin_memory=True)

    # -------------------------------- ML setup
    # cuda
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.backends.cudnn.enabled = True
        cudnn.benchmark = True

    # define number of classes
    numCls = 10 if args.IGBP_simple else 17
    print('num_class: ', numCls)

    # define model
    if args.model == 'VGG16':
        model = VGG16(n_inputs, numCls)
    elif args.model == 'VGG19':
        model = VGG19(n_inputs, numCls)
    elif args.model == 'ResNet50':
        model = ResNet50(n_inputs, numCls)
    elif args.model == 'ResNet101':
        model = ResNet101(n_inputs, numCls)
    elif args.model == 'ResNet152':
        model = ResNet152(n_inputs, numCls)
    elif args.model == 'DenseNet121':
        model = DenseNet121(n_inputs, numCls)
    elif args.model == 'DenseNet161':
        model = DenseNet161(n_inputs, numCls)
    elif args.model == 'DenseNet169':
        model = DenseNet169(n_inputs, numCls)
    elif args.model == 'DenseNet201':
        model = DenseNet201(n_inputs, numCls)
    else:
        raise NameError("no model")

    # move model to GPU if is available
    if use_cuda:
        model = model.cuda()

    # define loss function
    if label_type == 'multi_label':
        lossfunc = torch.nn.BCEWithLogitsLoss()
    else:
        lossfunc = torch.nn.CrossEntropyLoss()

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)

    best_acc = 0
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # keep writing under the resumed run's name (first two '_' fields)
            checkpoint_nm = os.path.basename(args.resume)
            sv_name = '_'.join(checkpoint_nm.split('_')[:2])
            print('saving file name is ', sv_name)
            # the saved 'epoch' has already been trained; continue with the next
            start_epoch = checkpoint['epoch'] + 1
            best_acc = checkpoint['best_prec']
            # save_checkpoint() below writes 'model_state_dict' /
            # 'optimizer_state_dict'; fall back to the legacy key names.
            model_key = 'model_state_dict' if 'model_state_dict' in checkpoint else 'state_dict'
            optim_key = 'optimizer_state_dict' if 'optimizer_state_dict' in checkpoint else 'optimizer'
            model.load_state_dict(checkpoint[model_key])
            optimizer.load_state_dict(checkpoint[optim_key])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set up tensorboard logging
    train_writer = SummaryWriter(os.path.join(logs_dir, 'runs', sv_name, 'training'))
    val_writer = SummaryWriter(os.path.join(logs_dir, 'runs', sv_name, 'val'))

    # ----------------------------- executing Train/Val.
    try:
        for epoch in range(start_epoch, args.epochs):
            print('Epoch {}/{}'.format(epoch, args.epochs - 1))
            print('-' * 10)

            train(train_data_loader, model, optimizer, lossfunc, label_type, epoch, use_cuda, train_writer)
            micro_f1 = val(val_data_loader, model, optimizer, label_type, epoch, use_cuda, val_writer)

            is_best_acc = micro_f1 > best_acc
            best_acc = max(best_acc, micro_f1)
            save_checkpoint({
                'epoch': epoch,
                'arch': args.model,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_prec': best_acc
            }, is_best_acc, sv_name)
    finally:
        # flush pending events even if training is interrupted
        train_writer.close()
        val_writer.close()
def main():
    """Evaluate a trained SEN12MS checkpoint on the test split.

    Steps:

    1. Read the training configuration from ``args.config_file``. Each line is
       ``<key><sep> <value>``; the last character of the key token is dropped.
    2. Rebuild the configured model and load ``args.checkpoint_pth``.
    3. Predict the whole test split and compute the metrics.
    4. Push the metrics to the wandb run summary and pickle them to
       ``test_scores.pkl`` in the working directory.

    Raises:
        ValueError: the configured label type is not in ``label_choices``.
        NameError: the configured model name is unknown.
    """
    global args

    # -------------------------- load config from file
    config = {}
    with open(args.config_file, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            key, value = line.split()
            # drop the key's trailing separator (presumably ':' as written by
            # write_arguments_to_file - verify)
            config[key[0:-1]] = value

    # Convert string to boolean
    boo_use_s1 = config['use_s1'] == 'True'
    boo_use_s2 = config['use_s2'] == 'True'
    boo_use_RGB = config['use_RGB'] == 'True'
    boo_IGBP_simple = config['IGBP_simple'] == 'True'

    # define label_type ("major_vote" is evaluated as a single-label problem)
    cf_label_type = config['label_type']
    if cf_label_type == "major_vote":
        cf_label_type = "single_label"
    if cf_label_type not in label_choices:
        raise ValueError(f"unsupported label_type in config: {cf_label_type}")

    wandb.init(config=config)
    wandb.config.update(args, allow_val_change=True)

    # define threshold
    cf_threshold = float(config['threshold'])

    # number of classes and their names in the classification report
    numCls = 10 if boo_IGBP_simple else 17
    ORG_LABELS = [str(i) for i in range(1, numCls + 1)]

    # ----------------------------------- data
    # define mean/std of the training set (for data normalization)
    bands_mean = {
        's1_mean': [-11.76858, -18.294598],
        's2_mean': [
            1226.4215, 1137.3799, 1139.6792, 1350.9973, 1932.9058, 2211.1584,
            2154.9846, 2409.1128, 2001.8622, 1356.0801
        ]
    }
    bands_std = {
        's1_std': [4.525339, 4.3586307],
        's2_std': [
            741.6254, 740.883, 960.1045, 946.76056, 985.52747, 1082.4341,
            1057.7628, 1136.1942, 1132.7898, 991.48016
        ]
    }

    # load test dataset
    imgTransform = transforms.Compose(
        [ToTensor(), Normalize(bands_mean, bands_std)])
    test_dataGen = SEN12MS(args.data_dir,
                           args.label_split_dir,
                           imgTransform=imgTransform,
                           label_type=cf_label_type,
                           threshold=cf_threshold,
                           subset="test",
                           use_s1=boo_use_s1,
                           use_s2=boo_use_s2,
                           use_RGB=boo_use_RGB,
                           IGBP_s=boo_IGBP_simple)

    # number of input channels
    n_inputs = test_dataGen.n_inputs
    print('input channels =', n_inputs)

    # set up dataloaders
    test_data_loader = DataLoader(test_dataGen,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=False,
                                  pin_memory=True)

    # -------------------------------- ML setup
    # cuda
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.backends.cudnn.enabled = True
        cudnn.benchmark = True

    print('num_class: ', numCls)

    # define model
    if config['model'] == 'VGG16':
        model = VGG16(n_inputs, numCls)
    elif config['model'] == 'VGG19':
        model = VGG19(n_inputs, numCls)
    elif config['model'] == 'ResNet50' or config['model'] == 'Moco':
        model = ResNet50(n_inputs, numCls)
    elif config['model'] == 'ResNet101':
        model = ResNet101(n_inputs, numCls)
    elif config['model'] == 'ResNet152':
        model = ResNet152(n_inputs, numCls)
    elif config['model'] == 'DenseNet121':
        model = DenseNet121(n_inputs, numCls)
    elif config['model'] == 'DenseNet161':
        model = DenseNet161(n_inputs, numCls)
    elif config['model'] == 'DenseNet169':
        model = DenseNet169(n_inputs, numCls)
    elif config['model'] == 'DenseNet201':
        model = DenseNet201(n_inputs, numCls)
    else:
        raise NameError("no model")

    # move model to GPU if is available
    if use_cuda:
        model = model.cuda()

    # import model weights; on a CPU-only host remap GPU tensors to the CPU
    checkpoint = torch.load(args.checkpoint_pth,
                            map_location=None if use_cuda else 'cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.checkpoint_pth, checkpoint['epoch']))
    print(model)

    # set model to evaluation mode
    model.eval()

    # define metrics
    prec_score_ = Precision_score()
    recal_score_ = Recall_score()
    f1_score_ = F1_score()
    f2_score_ = F2_score()
    hamming_loss_ = Hamming_loss()
    subset_acc_ = Subset_accuracy()
    acc_score_ = Accuracy_score()  # from original script, not recommended, seems not correct
    one_err_ = One_error()
    coverage_err_ = Coverage_error()
    rank_loss_ = Ranking_loss()
    labelAvgPrec_score_ = LabelAvgPrec_score()
    calssification_report_ = calssification_report(ORG_LABELS)

    # -------------------------------- prediction
    y_true = []
    predicted_probs = []
    with torch.no_grad():
        for data in tqdm(test_data_loader, desc="test"):
            bands = data["image"]
            labels = data["label"]

            # only the inputs need to live on the model's device
            if use_cuda:
                bands = bands.to(torch.device("cuda"))

            logits = model(bands)

            # convert logits to probabilities
            if cf_label_type == 'multi_label':
                probs = torch.sigmoid(logits).cpu().numpy()
            else:
                probs = torch.softmax(logits, dim=1).cpu().numpy()

            # keep true & predicted rows aligned
            predicted_probs += list(probs)
            y_true += list(labels.cpu().numpy())

    predicted_probs = np.asarray(predicted_probs)
    # convert predicted probabilities into one/multi-hot labels
    if cf_label_type == 'multi_label':
        y_predicted = (predicted_probs >= 0.5).astype(np.float32)
    else:
        # one-hot encode the arg-max class of every sample
        y_predicted = np.zeros_like(predicted_probs, dtype=np.float32)
        best_cls = np.argmax(predicted_probs, axis=-1)
        y_predicted[np.arange(len(best_cls)), best_cls] = 1

    y_true = np.asarray(y_true)

    # --------------------------- evaluation with metrics
    # general
    macro_f1, micro_f1, sample_f1 = f1_score_(y_predicted, y_true)
    macro_f2, micro_f2, sample_f2 = f2_score_(y_predicted, y_true)
    macro_prec, micro_prec, sample_prec = prec_score_(y_predicted, y_true)
    macro_rec, micro_rec, sample_rec = recal_score_(y_predicted, y_true)
    hamming_loss = hamming_loss_(y_predicted, y_true)
    subset_acc = subset_acc_(y_predicted, y_true)
    macro_acc, micro_acc, sample_acc = acc_score_(y_predicted, y_true)

    # ranking-based
    one_error = one_err_(predicted_probs, y_true)
    coverage_error = coverage_err_(predicted_probs, y_true)
    rank_loss = rank_loss_(predicted_probs, y_true)
    labelAvgPrec = labelAvgPrec_score_(predicted_probs, y_true)

    cls_report = calssification_report_(y_predicted, y_true)

    if cf_label_type == 'multi_label':
        # the results derived from the multilabel confusion matrix are not recommended
        conf_mat, cls_acc, aa = multi_conf_mat(y_predicted, y_true, n_classes=numCls)
        oa = OA_multi(y_predicted, y_true)  # this oa can be Jaccard index
        extra_info = {
            "multilabel_conf_mat": conf_mat,
            "class-wise Acc": cls_acc,
            "AverageAcc": aa,
            "OverallAcc": oa
        }
    else:
        conf_mat = conf_mat_nor(y_predicted, y_true, n_classes=numCls)
        # average accuracy; zero-sample classes are not excluded
        aa = get_AA(y_predicted, y_true, n_classes=numCls)
        extra_info = {"conf_mat": conf_mat, "AverageAcc": aa}

    info = {
        "macroPrec": macro_prec,
        "microPrec": micro_prec,
        "samplePrec": sample_prec,
        "macroRec": macro_rec,
        "microRec": micro_rec,
        "sampleRec": sample_rec,
        "macroF1": macro_f1,
        "microF1": micro_f1,
        "sampleF1": sample_f1,
        "macroF2": macro_f2,
        "microF2": micro_f2,
        "sampleF2": sample_f2,
        "HammingLoss": hamming_loss,
        "subsetAcc": subset_acc,
        "macroAcc": macro_acc,
        "microAcc": micro_acc,
        "sampleAcc": sample_acc,
        "oneError": one_error,
        "coverageError": coverage_error,
        "rankLoss": rank_loss,
        "labelAvgPrec": labelAvgPrec,
        "clsReport": cls_report,
    }
    info.update(extra_info)  # label-type specific keys come last, as before

    wandb.run.summary.update(info)
    print("saving metrics...")
    with open("test_scores.pkl", "wb") as f:
        pkl.dump(info, f)
def choose_nets(nets_name, num_classes, operation):
    """Instantiate the network called *nets_name* (case-insensitive).

    Args:
        nets_name: architecture name, e.g. ``'bit-m-r50x1'``, ``'resnet50'``
            or ``'inception-resnet-v2'``; compared after lower-casing.
        num_classes: number of output classes.
        operation: only used by the BiT model. With ``'train'``, the
            pre-trained BiT weights are loaded from
            ``./models/big_transfer/pre-trained/<name>.h5`` before the
            classification head is replaced.

    Returns:
        The constructed model. Non-BiT models are built as
        ``<constructor>(num_classes)``, imported lazily from the project's
        ``models`` package.

    Raises:
        NotImplementedError: *nets_name* is not a known architecture.
    """
    import importlib

    nets_name = nets_name.lower()

    if nets_name == 'bit-m-r50x1':
        from models.big_transfer.BigTransfer import ResnetV2
        # width multiplier: the trailing digit of the name ('x1') times 4
        filters_factor = int(nets_name[-1]) * 4
        model = ResnetV2(
            num_units=(3, 4, 6, 3),  # From line no. 273 in BigTransfer
            # presumably the pre-training (ImageNet-21k) head size, so the
            # .h5 weights load before the head is swapped below
            num_outputs=21843,
            filters_factor=filters_factor,
            name="resnet",
            trainable=True,
            dtype=tf.float32)
        model.build((None, None, None, 3))
        if operation == 'train':
            bit_model_file = os.path.join('./models/big_transfer/pre-trained',
                                          f'{nets_name}.h5')
            print('BiT pre-trained model file location:', bit_model_file)
            model.load_weights(bit_model_file)
        # replace the pre-training head with a zero-initialised num_classes head
        model._head = tf.keras.layers.Dense(units=num_classes,
                                            use_bias=True,
                                            kernel_initializer="zeros",
                                            trainable=True,
                                            name="head/dense")
        return model

    # lower-cased name -> (module, constructor attribute); keys must be lower case
    registry = {
        'vgg11': ('models.VGG', 'VGG11'),
        'vgg13': ('models.VGG', 'VGG13'),
        'vgg16': ('models.VGG', 'VGG16'),
        'vgg19': ('models.VGG', 'VGG19'),
        'resnet18': ('models.ResNet', 'ResNet18'),
        'resnet34': ('models.ResNet', 'ResNet34'),
        'resnet50': ('models.ResNet', 'ResNet50'),
        'resnet101': ('models.ResNet', 'ResNet101'),
        'resnet152': ('models.ResNet', 'ResNet152'),
        'googlenet': ('models.GoogLeNet', 'GoogLeNet'),
        'inceptionv3': ('models.InceptionV3', 'inceptionv3'),
        'mobilenet': ('models.MobileNet', 'mobilenet'),
        'mobilenetv2': ('models.MobileNetV2', 'mobilenetv2'),
        'seresnet18': ('models.SEResNet', 'seresnet18'),
        'seresnet34': ('models.SEResNet', 'seresnet34'),
        'seresnet50': ('models.SEResNet', 'seresnet50'),
        'seresnet101': ('models.SEResNet', 'seresnet101'),
        'seresnet152': ('models.SEResNet', 'seresnet152'),
        'densenet121': ('models.DenseNet', 'densenet121'),
        'densenet161': ('models.DenseNet', 'densenet161'),
        'densenet169': ('models.DenseNet', 'densenet169'),
        'densenet201': ('models.DenseNet', 'densenet201'),
        'squeezenet': ('models.SqueezeNet', 'squeezenet'),
        'inceptionv4': ('models.InceptionV4', 'inceptionv4'),
        'inception-resnet-v2': ('models.InceptionV4', 'inception_resnet_v2'),
    }

    try:
        module_name, ctor_name = registry[nets_name]
    except KeyError:
        raise NotImplementedError(nets_name) from None

    constructor = getattr(importlib.import_module(module_name), ctor_name)
    return constructor(num_classes)