Beispiel #1
0
def main():
    """Train a land-cover scene classifier on SEN12MS or BigEarthNet.

    Reads all hyper-parameters from the module-level ``args`` namespace,
    builds train/val (and, with ``args.eval``, test) dataloaders, trains the
    model selected by ``args.model``, checkpoints the best validation
    micro-F1, and finally runs the evaluation pass when ``args.eval`` is set.

    Side effects: writes an argument dump and checkpoints under ``logs_dir``,
    logs metrics to wandb, and publishes the run name via the module-level
    ``sv_name_eval``.
    """
    global args
    global sv_name_eval

    # Timestamped run name: used for the argument dump and checkpoint files,
    # and exported so the eval step can locate this run's artifacts.
    sv_name = datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S')
    sv_name_eval = sv_name
    print('saving file name is ', sv_name)

    write_arguments_to_file(args,
                            os.path.join(logs_dir, sv_name + '_arguments.txt'))

    # ----------------------------------- data
    # define mean/std of the training set (for data normalization)
    label_type = args.label_type
    use_s1 = (args.sensor_type == 's1') | (args.sensor_type == 's1s2')
    use_s2 = (args.sensor_type == 's2') | (args.sensor_type == 's1s2')

    dataset = args.dataset
    data_dir = os.path.join("data", dataset, "data")

    bands_mean = {}
    bands_std = {}
    train_dataGen = None
    val_dataGen = None
    test_dataGen = None

    print(f"Using {dataset} dataset")
    if dataset == 'sen12ms':
        bands_mean = {
            's1_mean': [-11.76858, -18.294598],
            's2_mean': [
                1226.4215, 1137.3799, 1139.6792, 1350.9973, 1932.9058,
                2211.1584, 2154.9846, 2409.1128, 2001.8622, 1356.0801
            ]
        }
        bands_std = {
            's1_std': [4.525339, 4.3586307],
            's2_std': [
                741.6254, 740.883, 960.1045, 946.76056, 985.52747, 1082.4341,
                1057.7628, 1136.1942, 1132.7898, 991.48016
            ]
        }
    elif dataset == 'bigearthnet':
        # THE S2 BAND STATISTICS WERE PROVIDED BY THE BIGEARTHNET TEAM
        # Source: https://git.tu-berlin.de/rsim/bigearthnet-models-tf/-/blob/master/BigEarthNet.py
        bands_mean = {
            's1_mean': [-12.619993, -19.290445],
            's2_mean': [
                340.76769064, 429.9430203, 614.21682446, 590.23569706,
                950.68368468, 1792.46290469, 2075.46795189, 2218.94553375,
                2266.46036911, 2246.0605464, 1594.42694882, 1009.32729131
            ]
        }
        bands_std = {
            's1_std': [5.115911, 5.464428],
            's2_std': [
                554.81258967, 572.41639287, 582.87945694, 675.88746967,
                729.89827633, 1096.01480586, 1273.45393088, 1365.45589904,
                1356.13789355, 1302.3292881, 1079.19066363, 818.86747235
            ]
        }
    else:
        raise NameError(f"unknown dataset: {dataset}")

    # load datasets
    imgTransform = transforms.Compose(
        [ToTensor(), Normalize(bands_mean, bands_std)])
    if dataset == 'sen12ms':
        train_dataGen = SEN12MS(data_dir,
                                args.label_split_dir,
                                imgTransform=imgTransform,
                                label_type=label_type,
                                threshold=args.threshold,
                                subset="train",
                                use_s1=use_s1,
                                use_s2=use_s2,
                                use_RGB=args.use_RGB,
                                IGBP_s=args.simple_scheme,
                                data_size=args.data_size,
                                sensor_type=args.sensor_type,
                                use_fusion=args.use_fusion)

        val_dataGen = SEN12MS(data_dir,
                              args.label_split_dir,
                              imgTransform=imgTransform,
                              label_type=label_type,
                              threshold=args.threshold,
                              subset="val",
                              use_s1=use_s1,
                              use_s2=use_s2,
                              use_RGB=args.use_RGB,
                              IGBP_s=args.simple_scheme,
                              data_size=args.data_size,
                              sensor_type=args.sensor_type,
                              use_fusion=args.use_fusion)

        if args.eval:
            # NOTE(review): unlike train/val, the test generator is built
            # without data_size — presumably the full test split is always
            # evaluated; confirm against the SEN12MS dataset class.
            test_dataGen = SEN12MS(data_dir,
                                   args.label_split_dir,
                                   imgTransform=imgTransform,
                                   label_type=label_type,
                                   threshold=args.threshold,
                                   subset="test",
                                   use_s1=use_s1,
                                   use_s2=use_s2,
                                   use_RGB=args.use_RGB,
                                   IGBP_s=args.simple_scheme,
                                   sensor_type=args.sensor_type,
                                   use_fusion=args.use_fusion)
    else:
        # Assume bigearthnet (the dataset name was validated above)
        train_dataGen = BigEarthNet(data_dir,
                                    args.label_split_dir,
                                    imgTransform=imgTransform,
                                    label_type=label_type,
                                    threshold=args.threshold,
                                    subset="train",
                                    use_s1=use_s1,
                                    use_s2=use_s2,
                                    use_RGB=args.use_RGB,
                                    CLC_s=args.simple_scheme,
                                    data_size=args.data_size,
                                    sensor_type=args.sensor_type,
                                    use_fusion=args.use_fusion)

        val_dataGen = BigEarthNet(data_dir,
                                  args.label_split_dir,
                                  imgTransform=imgTransform,
                                  label_type=label_type,
                                  threshold=args.threshold,
                                  subset="val",
                                  use_s1=use_s1,
                                  use_s2=use_s2,
                                  use_RGB=args.use_RGB,
                                  CLC_s=args.simple_scheme,
                                  data_size=args.data_size,
                                  sensor_type=args.sensor_type,
                                  use_fusion=args.use_fusion)

        if args.eval:
            test_dataGen = BigEarthNet(data_dir,
                                       args.label_split_dir,
                                       imgTransform=imgTransform,
                                       label_type=label_type,
                                       threshold=args.threshold,
                                       subset="test",
                                       use_s1=use_s1,
                                       use_s2=use_s2,
                                       use_RGB=args.use_RGB,
                                       CLC_s=args.simple_scheme,
                                       sensor_type=args.sensor_type,
                                       use_fusion=args.use_fusion)

    # number of input channels (depends on sensor/RGB selection above)
    n_inputs = train_dataGen.n_inputs
    print('input channels =', n_inputs)
    wandb.config.update({"input_channels": n_inputs})

    # set up dataloaders
    train_data_loader = DataLoader(train_dataGen,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   shuffle=True,
                                   pin_memory=True)
    val_data_loader = DataLoader(val_dataGen,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 pin_memory=True)

    if args.eval:
        test_data_loader = DataLoader(test_dataGen,
                                      batch_size=args.batch_size,
                                      num_workers=args.num_workers,
                                      shuffle=False,
                                      pin_memory=True)

    # -------------------------------- ML setup
    # cuda
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.backends.cudnn.enabled = True
        cudnn.benchmark = True

    # define number of classes (simplified vs. full label scheme per dataset)
    if dataset == 'sen12ms':
        if args.simple_scheme:
            numCls = 10
            ORG_LABELS = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
        else:
            numCls = 17
            ORG_LABELS = [
                '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',
                '13', '14', '15', '16', '17'
            ]
    else:
        if args.simple_scheme:
            numCls = 19
            ORG_LABELS = [
                '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',
                '13', '14', '15', '16', '17', '18', '19'
            ]
        else:
            numCls = 43
            ORG_LABELS = [
                '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',
                '13', '14', '15', '16', '17', '18', '19', '20', '21', '22',
                '23', '24', '25', '26', '27', '28', '29', '30', '31', '32',
                '33', '34', '35', '36', '37', '38', '39', '40', '41', '42',
                '43'
            ]

    print('num_class: ', numCls)
    wandb.config.update({"n_class": numCls})

    # define model
    if args.model == 'VGG16':
        model = VGG16(n_inputs, numCls)
    elif args.model == 'VGG19':
        model = VGG19(n_inputs, numCls)
    elif args.model == 'Supervised':
        model = ResNet50(n_inputs, numCls)
    elif args.model == 'Supervised_1x1':
        model = ResNet50_1x1(n_inputs, numCls)
    elif args.model == 'ResNet101':
        model = ResNet101(n_inputs, numCls)
    elif args.model == 'ResNet152':
        model = ResNet152(n_inputs, numCls)
    elif args.model == 'DenseNet121':
        model = DenseNet121(n_inputs, numCls)
    elif args.model == 'DenseNet161':
        model = DenseNet161(n_inputs, numCls)
    elif args.model == 'DenseNet169':
        model = DenseNet169(n_inputs, numCls)
    elif args.model == 'DenseNet201':
        model = DenseNet201(n_inputs, numCls)
    # finetune moco pre-trained model
    elif args.model.startswith("Moco"):
        pt_path = os.path.join(args.pt_dir, f"{args.pt_name}.pth")
        print(pt_path)
        assert os.path.exists(pt_path)
        if args.model == 'Moco':
            print("transfer backbone weights but no conv 1x1 input module")
            model = Moco(torch.load(pt_path), n_inputs, numCls)
        elif args.model == 'Moco_1x1':
            print("transfer backbone weights and input module weights")
            model = Moco_1x1(torch.load(pt_path), n_inputs, numCls)
        elif args.model == 'Moco_1x1RND':
            # NOTE(review): this branch constructs the same Moco_1x1 model as
            # the 'Moco_1x1' branch — presumably the random re-init of the
            # input module happens inside Moco_1x1; confirm.
            print(
                "transfer backbone weights but initialize input module random with random weights"
            )
            model = Moco_1x1(torch.load(pt_path), n_inputs, numCls)
        else:  # Assume Moco2 at present
            raise NameError("no model")
    else:
        raise NameError("no model")

    print(model)

    # move model to GPU if is available
    if use_cuda:
        model = model.cuda()

    # define loss function: BCE-with-logits for multi-label, CE otherwise
    if label_type == 'multi_label':
        lossfunc = torch.nn.BCEWithLogitsLoss()
    else:
        lossfunc = torch.nn.CrossEntropyLoss()

    # set up optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.decay)

    best_acc = 0
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            checkpoint_nm = os.path.basename(args.resume)
            # Recover the original run name (date_time prefix) so new
            # checkpoints keep overwriting the resumed run's files.
            sv_name = checkpoint_nm.split('_')[0] + '_' + checkpoint_nm.split(
                '_')[1]
            sv_name_eval = sv_name  # keep the exported eval name in sync
            print('saving file name is ', sv_name)

            if checkpoint['epoch'] > start_epoch:
                start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec']
            # save_checkpoint() below writes 'model_state_dict' /
            # 'optimizer_state_dict'; fall back to the legacy
            # 'state_dict' / 'optimizer' keys for old checkpoints.
            model.load_state_dict(
                checkpoint.get('model_state_dict',
                               checkpoint.get('state_dict')))
            optimizer.load_state_dict(
                checkpoint.get('optimizer_state_dict',
                               checkpoint.get('optimizer')))
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set up tensorboard logging
    # train_writer = SummaryWriter(os.path.join(logs_dir, 'runs', sv_name, 'training'))
    # val_writer = SummaryWriter(os.path.join(logs_dir, 'runs', sv_name, 'val'))

    # ----------------------------- executing Train/Val.
    # train network
    # wandb.watch(model, log="all")

    scheduler = None
    if args.use_lr_step:
        # Ex: If initial Lr is 0.0001, step size is 25, and gamma is 0.1, then lr will be changed for every 20 steps
        # 0.0001 - first 25 epochs
        # 0.00001 - 25 to 50 epochs
        # 0.000001 - 50 to 75 epochs
        # 0.0000001 - 75 to 100 epochs
        # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.lr_step_size,
                                              gamma=args.lr_step_gamma)

    for epoch in range(start_epoch, args.epochs):
        if args.use_lr_step:
            print('Epoch {}/{} lr: {}'.format(epoch, args.epochs - 1,
                                              optimizer.param_groups[0]['lr']))
        else:
            print('Epoch {}/{}'.format(epoch, args.epochs - 1))
        print('-' * 25)

        train(train_data_loader, model, optimizer, lossfunc, label_type, epoch,
              use_cuda)
        micro_f1 = val(val_data_loader, model, optimizer, label_type, epoch,
                       use_cuda)

        is_best_acc = micro_f1 > best_acc
        best_acc = max(best_acc, micro_f1)

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.model,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_prec': best_acc
            }, is_best_acc, sv_name)

        wandb.log({'epoch': epoch, 'micro_f1': micro_f1})

        # Step the LR schedule AFTER this epoch's optimizer updates (PyTorch
        # >= 1.1 API). Stepping before training would decay the LR one epoch
        # earlier than the schedule documented above.
        if args.use_lr_step:
            scheduler.step()

    print("=============")
    print("done training")
    print("=============")

    if args.eval:
        eval(test_data_loader, model, label_type, numCls, use_cuda, ORG_LABELS)
Beispiel #2
0
def choose_nets(nets_name, num_classes=100):
    """Instantiate a classification network by (case-insensitive) name.

    Args:
        nets_name: architecture identifier, e.g. ``'vgg16'`` or ``'ResNet50'``.
        num_classes: number of output classes passed to the model factory.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: if ``nets_name`` is not a known architecture.
    """
    import importlib

    # Map lowercase name -> (module path, factory attribute). Imports are
    # deferred until a name matches, so unused model files are never loaded
    # (same behaviour as the original per-branch imports).
    # Fixes vs. original: 'vgg16' was compared against 'VGG16' AFTER
    # lower-casing (unreachable), and 'densenet161' was guarded by a
    # duplicate 'densenet121' check (also unreachable).
    registry = {
        'vgg11': ('models.VGG', 'VGG11'),
        'vgg13': ('models.VGG', 'VGG13'),
        'vgg16': ('models.VGG', 'VGG16'),
        'vgg19': ('models.VGG', 'VGG19'),
        'resnet18': ('models.ResNet', 'ResNet18'),
        'resnet34': ('models.ResNet', 'ResNet34'),
        'resnet50': ('models.ResNet', 'ResNet50'),
        'resnet101': ('models.ResNet', 'ResNet101'),
        'resnet152': ('models.ResNet', 'ResNet152'),
        'googlenet': ('models.GoogLeNet', 'GoogLeNet'),
        'inceptionv3': ('models.InceptionV3', 'inceptionv3'),
        'mobilenet': ('models.MobileNet', 'mobilenet'),
        'mobilenetv2': ('models.MobileNetV2', 'mobilenetv2'),
        'seresnet18': ('models.SEResNet', 'seresnet18'),
        'seresnet34': ('models.SEResNet', 'seresnet34'),
        'seresnet50': ('models.SEResNet', 'seresnet50'),
        'seresnet101': ('models.SEResNet', 'seresnet101'),
        'seresnet152': ('models.SEResNet', 'seresnet152'),
        'densenet121': ('models.DenseNet', 'densenet121'),
        'densenet161': ('models.DenseNet', 'densenet161'),
        'densenet169': ('models.DenseNet', 'densenet169'),
        'densenet201': ('models.DenseNet', 'densenet201'),
        'squeezenet': ('models.SqueezeNet', 'squeezenet'),
        'inceptionv4': ('models.InceptionV4', 'inceptionv4'),
        'inception-resnet-v2': ('models.InceptionV4', 'inception_resnet_v2'),
    }

    nets_name = nets_name.lower()
    try:
        module_path, factory_name = registry[nets_name]
    except KeyError:
        raise NotImplementedError

    module = importlib.import_module(module_path)
    return getattr(module, factory_name)(num_classes)
Beispiel #3
0
def main():
    """Train a land-cover classifier on SEN12MS with TensorBoard logging.

    Reads all hyper-parameters from the module-level ``args`` namespace,
    builds train/val dataloaders, trains the model selected by
    ``args.model``, and checkpoints the best validation micro-F1.

    Side effects: writes an argument dump, TensorBoard event files and
    checkpoints under ``logs_dir``.
    """
    global args

    # Timestamped run name: used for the argument dump, checkpoint files
    # and TensorBoard run directories.
    sv_name = datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S')
    print('saving file name is ', sv_name)

    write_arguments_to_file(args, os.path.join(logs_dir, sv_name+'_arguments.txt'))

    # ----------------------------------- data
    # define mean/std of the training set (for data normalization)
    label_type = args.label_type

    bands_mean = {'s1_mean': [-11.76858, -18.294598],
                  's2_mean': [1226.4215, 1137.3799, 1139.6792, 1350.9973, 1932.9058,
                              2211.1584, 2154.9846, 2409.1128, 2001.8622, 1356.0801]}

    bands_std = {'s1_std': [4.525339, 4.3586307],
                 's2_std': [741.6254, 740.883, 960.1045, 946.76056, 985.52747,
                            1082.4341, 1057.7628, 1136.1942, 1132.7898, 991.48016]}

    # load datasets
    imgTransform = transforms.Compose([ToTensor(), Normalize(bands_mean, bands_std)])

    train_dataGen = SEN12MS(args.data_dir, args.label_split_dir,
                            imgTransform=imgTransform,
                            label_type=label_type, threshold=args.threshold, subset="train",
                            use_s1=args.use_s1, use_s2=args.use_s2, use_RGB=args.use_RGB,
                            IGBP_s=args.IGBP_simple)

    val_dataGen = SEN12MS(args.data_dir, args.label_split_dir,
                          imgTransform=imgTransform,
                          label_type=label_type, threshold=args.threshold, subset="val",
                          use_s1=args.use_s1, use_s2=args.use_s2, use_RGB=args.use_RGB,
                          IGBP_s=args.IGBP_simple)

    # number of input channels (depends on sensor/RGB selection above)
    n_inputs = train_dataGen.n_inputs
    print('input channels =', n_inputs)

    # set up dataloaders
    train_data_loader = DataLoader(train_dataGen,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   shuffle=True,
                                   pin_memory=True)
    val_data_loader = DataLoader(val_dataGen,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 pin_memory=True)

    # -------------------------------- ML setup
    # cuda
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.backends.cudnn.enabled = True
        cudnn.benchmark = True

    # define number of classes (simplified vs. full IGBP scheme)
    if args.IGBP_simple:
        numCls = 10
    else:
        numCls = 17

    print('num_class: ', numCls)

    # define model
    if args.model == 'VGG16':
        model = VGG16(n_inputs, numCls)
    elif args.model == 'VGG19':
        model = VGG19(n_inputs, numCls)
    elif args.model == 'ResNet50':
        model = ResNet50(n_inputs, numCls)
    elif args.model == 'ResNet101':
        model = ResNet101(n_inputs, numCls)
    elif args.model == 'ResNet152':
        model = ResNet152(n_inputs, numCls)
    elif args.model == 'DenseNet121':
        model = DenseNet121(n_inputs, numCls)
    elif args.model == 'DenseNet161':
        model = DenseNet161(n_inputs, numCls)
    elif args.model == 'DenseNet169':
        model = DenseNet169(n_inputs, numCls)
    elif args.model == 'DenseNet201':
        model = DenseNet201(n_inputs, numCls)
    else:
        raise NameError("no model")

    # move model to GPU if is available
    if use_cuda:
        model = model.cuda()

    # define loss function: BCE-with-logits for multi-label, CE otherwise
    if label_type == 'multi_label':
        lossfunc = torch.nn.BCEWithLogitsLoss()
    else:
        lossfunc = torch.nn.CrossEntropyLoss()

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)

    best_acc = 0
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            checkpoint_nm = os.path.basename(args.resume)
            # Recover the original run name (date_time prefix) so new
            # checkpoints keep overwriting the resumed run's files.
            sv_name = checkpoint_nm.split('_')[0] + '_' + checkpoint_nm.split('_')[1]
            print('saving file name is ', sv_name)

            if checkpoint['epoch'] > start_epoch:
                start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec']
            # save_checkpoint() below writes 'model_state_dict' /
            # 'optimizer_state_dict'; fall back to the legacy
            # 'state_dict' / 'optimizer' keys for old checkpoints.
            model.load_state_dict(
                checkpoint.get('model_state_dict', checkpoint.get('state_dict')))
            optimizer.load_state_dict(
                checkpoint.get('optimizer_state_dict', checkpoint.get('optimizer')))
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set up tensorboard logging
    train_writer = SummaryWriter(os.path.join(logs_dir, 'runs', sv_name, 'training'))
    val_writer = SummaryWriter(os.path.join(logs_dir, 'runs', sv_name, 'val'))

    # ----------------------------- executing Train/Val.
    # train network
    for epoch in range(start_epoch, args.epochs):

        print('Epoch {}/{}'.format(epoch, args.epochs - 1))
        print('-' * 10)

        train(train_data_loader, model, optimizer, lossfunc, label_type, epoch, use_cuda, train_writer)
        micro_f1 = val(val_data_loader, model, optimizer, label_type, epoch, use_cuda, val_writer)

        is_best_acc = micro_f1 > best_acc
        best_acc = max(best_acc, micro_f1)

        save_checkpoint({
            'epoch': epoch,
            'arch': args.model,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_prec': best_acc
            }, is_best_acc, sv_name)

    # flush and release the TensorBoard event files
    train_writer.close()
    val_writer.close()
Beispiel #4
0
def main():
    """Evaluate a trained SEN12MS classifier checkpoint on the test split.

    Reads the module-level ``args`` (config file path, checkpoint path,
    loader settings), rebuilds the model named in the config, runs inference
    over the test set, and computes single-/multi-label metrics.

    Side effects: logs the run to wandb and pickles the metric dict to
    ``test_scores.pkl`` in the working directory.
    """
    global args

    # -------------------------- load config from file
    # load config
    config_file = args.config_file

    # Each line is "key: value"; key[0:-1] strips the trailing ':'.
    # NOTE(review): line.split() assumes exactly one whitespace-separated
    # value per line — blank lines or values with spaces would raise; confirm
    # the config writer guarantees this format.
    config = {}
    with open(config_file, 'r') as f:
        for line in f:
            (key, val) = line.split()
            config[(key[0:-1])] = val

    # Convert string to boolean (config values are stored as text)
    boo_use_s1 = config['use_s1'] == 'True'
    boo_use_s2 = config['use_s2'] == 'True'
    boo_use_RGB = config['use_RGB'] == 'True'
    boo_IGBP_simple = config['IGBP_simple'] == 'True'

    # define label_type ("major_vote" is evaluated as single-label)
    cf_label_type = config['label_type']
    if cf_label_type == "major_vote":
        cf_label_type = "single_label"
    assert cf_label_type in label_choices

    wandb.init(config=config)
    wandb.config.update(args, allow_val_change=True)

    # define threshold
    cf_threshold = float(config['threshold'])

    # define labels used in cls_report
    if boo_IGBP_simple:
        ORG_LABELS = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    else:
        ORG_LABELS = [
            '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12',
            '13', '14', '15', '16', '17'
        ]

    # ----------------------------------- data
    # define mean/std of the training set (for data normalization)
    bands_mean = {
        's1_mean': [-11.76858, -18.294598],
        's2_mean': [
            1226.4215, 1137.3799, 1139.6792, 1350.9973, 1932.9058, 2211.1584,
            2154.9846, 2409.1128, 2001.8622, 1356.0801
        ]
    }

    bands_std = {
        's1_std': [4.525339, 4.3586307],
        's2_std': [
            741.6254, 740.883, 960.1045, 946.76056, 985.52747, 1082.4341,
            1057.7628, 1136.1942, 1132.7898, 991.48016
        ]
    }

    # load test dataset
    imgTransform = transforms.Compose(
        [ToTensor(), Normalize(bands_mean, bands_std)])

    test_dataGen = SEN12MS(args.data_dir,
                           args.label_split_dir,
                           imgTransform=imgTransform,
                           label_type=cf_label_type,
                           threshold=cf_threshold,
                           subset="test",
                           use_s1=boo_use_s1,
                           use_s2=boo_use_s2,
                           use_RGB=boo_use_RGB,
                           IGBP_s=boo_IGBP_simple)

    # number of input channels (depends on sensor/RGB selection above)
    n_inputs = test_dataGen.n_inputs
    print('input channels =', n_inputs)

    # set up dataloaders
    test_data_loader = DataLoader(test_dataGen,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=False,
                                  pin_memory=True)

    # -------------------------------- ML setup
    # cuda
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.backends.cudnn.enabled = True
        cudnn.benchmark = True

    # define number of classes (simplified vs. full IGBP scheme)
    if boo_IGBP_simple:
        numCls = 10
    else:
        numCls = 17

    print('num_class: ', numCls)

    # define model (must match the architecture recorded in the config)
    if config['model'] == 'VGG16':
        model = VGG16(n_inputs, numCls)
    elif config['model'] == 'VGG19':
        model = VGG19(n_inputs, numCls)

    elif config['model'] == 'ResNet50' or config['model'] == 'Moco':
        # Moco fine-tuning uses a ResNet50 backbone, so the same class loads it
        model = ResNet50(n_inputs, numCls)
    elif config['model'] == 'ResNet101':
        model = ResNet101(n_inputs, numCls)
    elif config['model'] == 'ResNet152':
        model = ResNet152(n_inputs, numCls)

    elif config['model'] == 'DenseNet121':
        model = DenseNet121(n_inputs, numCls)
    elif config['model'] == 'DenseNet161':
        model = DenseNet161(n_inputs, numCls)
    elif config['model'] == 'DenseNet169':
        model = DenseNet169(n_inputs, numCls)
    elif config['model'] == 'DenseNet201':
        model = DenseNet201(n_inputs, numCls)
    else:
        raise NameError("no model")

    # move model to GPU if is available
    if use_cuda:
        model = model.cuda()

    # import model weights
    checkpoint = torch.load(args.checkpoint_pth)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        args.checkpoint_pth, checkpoint['epoch']))

    print(model)

    # set model to evaluation mode
    model.eval()

    # define metrics
    prec_score_ = Precision_score()
    recal_score_ = Recall_score()
    f1_score_ = F1_score()
    f2_score_ = F2_score()
    hamming_loss_ = Hamming_loss()
    subset_acc_ = Subset_accuracy()
    acc_score_ = Accuracy_score(
    )  # from original script, not recommended, seems not correct
    one_err_ = One_error()
    coverage_err_ = Coverage_error()
    rank_loss_ = Ranking_loss()
    labelAvgPrec_score_ = LabelAvgPrec_score()

    calssification_report_ = calssification_report(ORG_LABELS)

    # -------------------------------- prediction
    y_true = []
    predicted_probs = []

    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(test_data_loader, desc="test")):

            # unpack sample
            bands = data["image"]
            labels = data["label"]

            # move data to gpu if model is on gpu
            if use_cuda:
                bands = bands.to(torch.device("cuda"))
                #labels = labels.to(torch.device("cuda"))

            # forward pass
            logits = model(bands)

            # convert logits to probabilies: independent sigmoids for
            # multi-label, softmax over classes otherwise
            if cf_label_type == 'multi_label':
                probs = torch.sigmoid(logits).cpu().numpy()
            else:
                sm = torch.nn.Softmax(dim=1)
                probs = sm(logits).cpu().numpy()

            labels = labels.cpu().numpy(
            )  # keep true & pred label at same loc.
            predicted_probs += list(probs)
            y_true += list(labels)

    predicted_probs = np.asarray(predicted_probs)
    # convert predicted probabilities into one/multi-hot labels
    if cf_label_type == 'multi_label':
        y_predicted = (predicted_probs >= 0.5).astype(np.float32)
    else:
        # one-hot encode the arg-max class per sample
        loc = np.argmax(predicted_probs, axis=-1)
        y_predicted = np.zeros_like(predicted_probs).astype(np.float32)
        for i in range(len(loc)):
            y_predicted[i, loc[i]] = 1

    y_true = np.asarray(y_true)

    # --------------------------- evaluation with metrics
    # general
    macro_f1, micro_f1, sample_f1 = f1_score_(y_predicted, y_true)
    macro_f2, micro_f2, sample_f2 = f2_score_(y_predicted, y_true)
    macro_prec, micro_prec, sample_prec = prec_score_(y_predicted, y_true)
    macro_rec, micro_rec, sample_rec = recal_score_(y_predicted, y_true)
    hamming_loss = hamming_loss_(y_predicted, y_true)
    subset_acc = subset_acc_(y_predicted, y_true)
    macro_acc, micro_acc, sample_acc = acc_score_(y_predicted, y_true)
    # ranking-based (computed on the raw probabilities, not the hard labels)
    one_error = one_err_(predicted_probs, y_true)
    coverage_error = coverage_err_(predicted_probs, y_true)
    rank_loss = rank_loss_(predicted_probs, y_true)
    labelAvgPrec = labelAvgPrec_score_(predicted_probs, y_true)

    cls_report = calssification_report_(y_predicted, y_true)

    if cf_label_type == 'multi_label':
        [conf_mat, cls_acc, aa] = multi_conf_mat(y_predicted,
                                                 y_true,
                                                 n_classes=numCls)
        # the results derived from multilabel confusion matrix are not recommended to use
        oa = OA_multi(y_predicted, y_true)
        # this oa can be Jaccard index

        info = {
            "macroPrec": macro_prec,
            "microPrec": micro_prec,
            "samplePrec": sample_prec,
            "macroRec": macro_rec,
            "microRec": micro_rec,
            "sampleRec": sample_rec,
            "macroF1": macro_f1,
            "microF1": micro_f1,
            "sampleF1": sample_f1,
            "macroF2": macro_f2,
            "microF2": micro_f2,
            "sampleF2": sample_f2,
            "HammingLoss": hamming_loss,
            "subsetAcc": subset_acc,
            "macroAcc": macro_acc,
            "microAcc": micro_acc,
            "sampleAcc": sample_acc,
            "oneError": one_error,
            "coverageError": coverage_error,
            "rankLoss": rank_loss,
            "labelAvgPrec": labelAvgPrec,
            "clsReport": cls_report,
            "multilabel_conf_mat": conf_mat,
            "class-wise Acc": cls_acc,
            "AverageAcc": aa,
            "OverallAcc": oa
        }

    else:
        conf_mat = conf_mat_nor(y_predicted, y_true, n_classes=numCls)
        aa = get_AA(y_predicted, y_true,
                    n_classes=numCls)  # average accuracy, \
        # zero-sample classes are not excluded

        info = {
            "macroPrec": macro_prec,
            "microPrec": micro_prec,
            "samplePrec": sample_prec,
            "macroRec": macro_rec,
            "microRec": micro_rec,
            "sampleRec": sample_rec,
            "macroF1": macro_f1,
            "microF1": micro_f1,
            "sampleF1": sample_f1,
            "macroF2": macro_f2,
            "microF2": micro_f2,
            "sampleF2": sample_f2,
            "HammingLoss": hamming_loss,
            "subsetAcc": subset_acc,
            "macroAcc": macro_acc,
            "microAcc": micro_acc,
            "sampleAcc": sample_acc,
            "oneError": one_error,
            "coverageError": coverage_error,
            "rankLoss": rank_loss,
            "labelAvgPrec": labelAvgPrec,
            "clsReport": cls_report,
            "conf_mat": conf_mat,
            "AverageAcc": aa
        }

    wandb.run.summary.update(info)
    print("saving metrics...")
    # use a context manager so the pickle file is flushed and closed
    # (the original `pkl.dump(info, open(...))` leaked the handle)
    with open("test_scores.pkl", "wb") as f:
        pkl.dump(info, f)
def choose_nets(nets_name, num_classes, operation):
    """Build and return the model architecture requested by name.

    Args:
        nets_name: Architecture identifier (case-insensitive), e.g.
            'resnet50', 'densenet161', 'bit-m-r50x1'.
        num_classes: Number of output classes for the classification head.
        operation: Only consulted by the 'bit-m-r50x1' branch; when equal
            to 'train', BiT pre-trained weights are loaded from disk.

    Returns:
        A model instance (Keras model for the BiT branch; whatever the
        per-architecture factory in `models.*` returns otherwise).

    Raises:
        NotImplementedError: If `nets_name` is not a known architecture.
    """
    nets_name = nets_name.lower()

    if nets_name == 'bit-m-r50x1':
        from models.big_transfer.BigTransfer import ResnetV2

        # 'r50x1' -> trailing '1' is the width multiplier; ResnetV2
        # expects it scaled by 4.
        filters_factor = int(nets_name[-1]) * 4
        model = ResnetV2(
            num_units=(3, 4, 6, 3),  # From line no. 273 in BigTransfer
            num_outputs=21843,  # head size of the ImageNet-21k checkpoint
            filters_factor=filters_factor,
            name="resnet",
            trainable=True,
            dtype=tf.float32)

        # Build with unknown spatial dims so any input resolution works.
        model.build((None, None, None, 3))

        if operation == 'train':
            bit_model_file = os.path.join('./models/big_transfer/pre-trained',
                                          f'{nets_name}.h5')
            print('BiT pre-trained model file location:', bit_model_file)
            model.load_weights(bit_model_file)

        # Replace the 21843-class pre-training head with a zero-initialized
        # head sized for the current task.
        model._head = tf.keras.layers.Dense(units=num_classes,
                                            use_bias=True,
                                            kernel_initializer="zeros",
                                            trainable=True,
                                            name="head/dense")

        return model

    # Dispatch table: name -> (module path, factory attribute).
    # Modules are imported lazily below so only the chosen backbone's
    # module is ever loaded, matching the original per-branch imports.
    factories = {
        'vgg11': ('models.VGG', 'VGG11'),
        'vgg13': ('models.VGG', 'VGG13'),
        'vgg16': ('models.VGG', 'VGG16'),
        'vgg19': ('models.VGG', 'VGG19'),
        'resnet18': ('models.ResNet', 'ResNet18'),
        'resnet34': ('models.ResNet', 'ResNet34'),
        'resnet50': ('models.ResNet', 'ResNet50'),
        'resnet101': ('models.ResNet', 'ResNet101'),
        'resnet152': ('models.ResNet', 'ResNet152'),
        'googlenet': ('models.GoogLeNet', 'GoogLeNet'),
        'inceptionv3': ('models.InceptionV3', 'inceptionv3'),
        'mobilenet': ('models.MobileNet', 'mobilenet'),
        'mobilenetv2': ('models.MobileNetV2', 'mobilenetv2'),
        'seresnet18': ('models.SEResNet', 'seresnet18'),
        'seresnet34': ('models.SEResNet', 'seresnet34'),
        'seresnet50': ('models.SEResNet', 'seresnet50'),
        'seresnet101': ('models.SEResNet', 'seresnet101'),
        'seresnet152': ('models.SEResNet', 'seresnet152'),
        'densenet121': ('models.DenseNet', 'densenet121'),
        'densenet169': ('models.DenseNet', 'densenet169'),
        'densenet201': ('models.DenseNet', 'densenet201'),
        # BUG FIX: the original guarded densenet161 behind a duplicated
        # 'densenet121' check, making densenet161 unreachable.
        'densenet161': ('models.DenseNet', 'densenet161'),
        'squeezenet': ('models.SqueezeNet', 'squeezenet'),
        'inceptionv4': ('models.InceptionV4', 'inceptionv4'),
        'inception-resnet-v2': ('models.InceptionV4', 'inception_resnet_v2'),
    }

    try:
        module_path, attr = factories[nets_name]
    except KeyError:
        raise NotImplementedError(f"Unknown network: {nets_name!r}") from None

    import importlib
    module = importlib.import_module(module_path)
    return getattr(module, attr)(num_classes)