コード例 #1
0
ファイル: trainer.py プロジェクト: jay1009/DFN
    def build_model(self):
        """Create the network, optimizer, LR scheduler and loss modules.

        The architecture is selected by ``self.cfg.model``. Results are stored
        on the instance as ``self.model``, ``self.optim``, ``self.scheduler``,
        ``self.c_loss``, ``self.softmax`` and ``self.n_gpu`` (plus ``self.viz``
        when the module-level ``USE_NSML`` flag is truthy).

        Raises:
            ValueError: If ``self.cfg.model`` is not a supported model name.
        """
        model_name = self.cfg.model
        if model_name == 'unet':
            self.model = unet.UNet(num_classes=2, in_dim=3, conv_dim=64)
        elif model_name == 'fcn8':
            self.model = fcn.FCN8(num_classes=2)
        elif model_name == 'pspnet_avg':
            self.model = pspnet.PSPNet(num_classes=2, pool_type='avg')
        elif model_name == 'pspnet_max':
            self.model = pspnet.PSPNet(num_classes=2, pool_type='max')
        elif model_name == 'dfnet':
            self.model = dfn.SmoothNet(num_classes=2,
                                       h_image_size=self.cfg.h_image_size,
                                       w_image_size=self.cfg.w_image_size)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # ``self.model`` a few lines further down.
            raise ValueError(
                'Unsupported model {!r}; expected one of: unet, fcn8, '
                'pspnet_avg, pspnet_max, dfnet'.format(model_name))

        self.optim = optim.Adam(self.model.parameters(),
                                lr=self.cfg.lr,
                                betas=[self.cfg.beta1, self.cfg.beta2])

        # Poly learning rate policy. The base is clamped at zero: once the
        # scheduler is stepped past ``n_iters`` a negative base raised to a
        # fractional exponent would produce a complex number.
        def lr_lambda(n_iter):
            remaining = max(0.0, 1 - n_iter / self.cfg.n_iters)
            return remaining**self.cfg.lr_exp

        self.scheduler = LambdaLR(self.optim, lr_lambda=lr_lambda)
        #weight = torch.tensor([0.04,0.07])
        #criterion = nn.CrossEntropyLoss(weight=weight)
        self.c_loss = nn.CrossEntropyLoss().to(self.device)
        self.softmax = nn.Softmax(dim=1).to(
            self.device)  # channel-wise softmax

        self.n_gpu = torch.cuda.device_count()
        if self.n_gpu > 1:
            print('Use data parallel model(# gpu: {})'.format(self.n_gpu))
            self.model = nn.DataParallel(self.model)
        self.model = self.model.to(self.device)

        if USE_NSML:
            self.viz = Visdom(visdom=visdom)
コード例 #2
0
    def build_model(self):
        """Create the UNet, its optimizer, LR scheduler and loss modules.

        Stores ``self.model``, ``self.optim``, ``self.scheduler``,
        ``self.c_loss``, ``self.softmax`` and ``self.n_gpu`` on the instance.
        When ``self.cfg.continue_train`` is set, ``self.load_network`` is
        called before the model is wrapped or moved to ``self.device``.
        """
        self.model = unet.UNet(num_classes=21, in_dim=3, conv_dim=64)

        # Adam over all model parameters. The betas are the coefficients used
        # for computing running averages of the gradient and its square.
        self.optim = optim.Adam(self.model.parameters(),
                                lr=self.cfg.lr,
                                betas=[self.cfg.beta1, self.cfg.beta2])

        # Poly learning rate policy. The base is clamped at zero: once the
        # scheduler is stepped past ``n_iters`` a negative base raised to a
        # fractional exponent would produce a complex number.
        def lr_lambda(n_iter):
            remaining = max(0.0, 1 - n_iter / self.cfg.n_iters)
            return remaining**self.cfg.lr_exp

        self.scheduler = LambdaLR(self.optim, lr_lambda=lr_lambda)

        # Loss and softmax modules are placed on the target device.
        self.c_loss = nn.CrossEntropyLoss().to(self.device)
        # Channel-wise softmax: values along dim 1 sum to 1.
        self.softmax = nn.Softmax(dim=1).to(self.device)

        # Number of GPUs visible to this process.
        self.n_gpu = torch.cuda.device_count()
        if self.cfg.continue_train:
            self.load_network(self.model, "UNET_VOC", self.cfg.which_epoch,
                              self.start_epoch, self.optim, self.scheduler)
        if self.n_gpu > 1:
            print('Use data parallel model(# gpu: {})'.format(self.n_gpu))
            # Replicate the model across the available GPUs.
            self.model = nn.DataParallel(self.model)
        self.model = self.model.to(self.device)
        if self.n_gpu > 0:
            torch.backends.cudnn.benchmark = True
            # Move any tensors held in the optimizer state onto the GPU
            # (the state is non-empty only if something populated it above).
            for state in self.optim.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()
コード例 #3
0
def load_model(opt, checkpoint_dir):
    """Load the lowest-loss ``*.pth`` checkpoint found in ``checkpoint_dir``.

    The loss of each checkpoint is parsed from its file name: the fifth
    ``_``-separated field of the base name with the last four characters
    (the ``.pth`` extension) removed.

    Args:
        opt: Options object; ``opt.model`` selects the architecture and
            ``opt.num_class + 1`` is passed to the model constructor.
        checkpoint_dir: Directory that is searched for ``*.pth`` files.

    Returns:
        Tuple ``(next_epoch, net)`` where ``next_epoch`` is the checkpoint's
        stored epoch plus one (``1`` if the chosen path is not a regular file).

    Raises:
        ValueError: If ``opt.model`` is unsupported, if no checkpoint exists,
            or if a file name does not carry a parsable loss.
    """
    # Validate up front; previously an unsupported model only surfaced later
    # as an UnboundLocalError on ``net``.
    if opt.model != 'unet':
        raise ValueError(
            "Unsupported model '{}'; only 'unet' is available".format(
                opt.model))

    checkpoint_list = sorted(glob.glob(os.path.join(checkpoint_dir, "*.pth")))
    if not checkpoint_list:
        # ``min()`` on an empty list raised an unhelpful ValueError before.
        raise ValueError(
            "no '*.pth' checkpoint found in '{}'".format(checkpoint_dir))

    # Default so that ``n_epoch + 1`` is defined even when nothing is loaded.
    n_epoch = 0

    # ``min`` returns the first minimal element, i.e. the same checkpoint the
    # index-of-minimum lookup on the sorted list would select.
    checkpoint_path = min(
        checkpoint_list,
        key=lambda path: float(os.path.basename(path).split('_')[4][:-4]))

    net = unet.UNet(opt.num_class + 1)

    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)

        n_epoch = checkpoint['epoch']
        # The checkpoint stores a whole module under 'net'; copy its weights.
        net.load_state_dict(checkpoint['net'].state_dict())
        print("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_path, n_epoch))
    else:
        # ``n_epoch`` deliberately keeps its default here.
        print("=> no checkpoint found at '{}'".format(checkpoint_path))

    return n_epoch + 1, net
コード例 #4
0
ファイル: train_RV_adv.py プロジェクト: XiaowenK/UNet_Family
def getModel(num_class, pretrained):
    """Build the generator selected by the global ``M_G`` and a discriminator.

    Both models are moved to the module-level ``device``.

    Args:
        num_class: Forwarded as ``num_class`` to the generator constructor.
        pretrained: Forwarded to the ``ResUNet`` and ``RexUNet`` constructors;
            ignored by the other generators.

    Returns:
        Tuple ``(model_g, model_d)``.

    Raises:
        ValueError: If ``M_G`` does not name a supported generator.
    """
    if M_G == "UNet":
        model_g = unet.UNet(num_class=num_class).to(device)
    elif M_G == "ResUNet":
        model_g = resunet.ResUNet(num_class=num_class, pretrained=pretrained).to(device)
    elif M_G == "RexUNet":
        model_g = rexunet.RexUNet(num_class=num_class, pretrained=pretrained).to(device)
    elif M_G == "Att_UNet":
        model_g = att_unet.Att_UNet(num_class=num_class).to(device)
    elif M_G == "UNet_PP":
        model_g = unet_pp.UNet_PP(num_class=num_class).to(device)
    else:
        # Previously this fell through and failed with UnboundLocalError at
        # the return statement.
        raise ValueError(
            "Unsupported generator M_G={!r}; expected one of: UNet, ResUNet, "
            "RexUNet, Att_UNet, UNet_PP".format(M_G))
    model_d = model_discriminative.Discriminator().to(device)
    return model_g, model_d
コード例 #5
0
def predict(config):
    """Run the saved UNet over a list of images and write masks to disk.

    Reads ``config['num_classes']``, ``config['save_model']['save_path']``
    (directory holding ``unet.pth``), ``config['pre_dir']`` (output root) and
    ``config['img_txt']`` (text file whose lines are two tab-separated
    fields; the first is the image path). Relies on the module-level
    ``Height``, ``Width`` and ``Img_channel`` values for reshaping.

    For every image a class-index mask is written to
    ``<pre_dir>/predict_unet/mask`` and ``translabeltovisual`` is called to
    produce the matching file under ``<pre_dir>/predict_unet/vis``.
    """
    model = unet.UNet(num_classes=config['num_classes'])

    check_point = os.path.join(config['save_model']['save_path'], 'unet.pth')
    transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor()
            # transforms.Normalize(mean=[0.485, 0.456, 0.406],
            #                      std=[0.229, 0.224, 0.225])
        ]
    )
    # Second positional argument is ``strict``: mismatching keys are tolerated.
    model.load_state_dict(torch.load(check_point), False)
    model.cuda()
    model.eval()

    # Output folders. ``makedirs(exist_ok=True)`` creates missing parents and
    # avoids the check-then-create race of exists() + mkdir().
    pre_base_path = os.path.join(config['pre_dir'], 'predict_unet')
    pre_mask_path = os.path.join(pre_base_path, 'mask')
    pre_vis_path = os.path.join(pre_base_path, 'vis')
    for out_dir in (pre_mask_path, pre_vis_path):
        os.makedirs(out_dir, exist_ok=True)

    # no_grad: inference only, so no autograd graph needs to be kept.
    with open(config['img_txt'], 'r', encoding='utf-8') as f, torch.no_grad():
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines instead of failing on the unpack below.
                continue
            image_name, _ = line.split('\t')
            im = np.asarray(Image.open(image_name))
            im = im.reshape((Height, Width, Img_channel))
            im = transform(im).float().cuda()
            im = im.reshape((1, Img_channel, Height, Width))

            output = model(im)
            # Index of the highest-scoring class along dim 1.
            _, pred = output.max(1)
            pred = pred.view(Height, Width)
            mask_im = pred.cpu().numpy().astype(np.uint8)

            file_name = os.path.basename(image_name)
            save_label = os.path.join(pre_mask_path, file_name)
            cv2.imwrite(save_label, mask_im)
            print("写入{}成功".format(save_label))
            save_visual = os.path.join(pre_vis_path, file_name)
            print("开始写入{}".format(save_visual))
            translabeltovisual(save_label, save_visual)
            print("写入{}成功".format(save_visual))
コード例 #6
0
def main():
    """Split the tiled dataset 80/20, build a UNet and run the trainer.

    The dataset location and normalisation statistics are hard-coded; loader
    batch sizes and the trainer settings come from ``config.json`` in the
    current working directory.
    """
    # Raw string: the previous literal relied on invalid escape sequences
    # (``\S``, ``\h``) being passed through unchanged. The value is identical.
    data_dir = r'C:\Scripts\hubmap\train\tiled_thresholded_512'

    mean = [0.68912, 0.47454, 0.6486]
    std_dev = [0.13275, 0.23647, 0.15536]

    # Full dataset with training images and masks.
    dataset = dataloader.Dataset_Image_mask(data_dir, mean, std_dev)

    n_tot = dataset.len

    # Split the full dataset into a train set and a test set.
    train_test_split = 0.8
    train_count = int(train_test_split * n_tot)

    train_idx = list(np.random.choice(range(n_tot), train_count,
                                      replace=False))
    test_idx = list(set(range(n_tot)) - set(train_idx))

    # Sanity check: the third number is 0 when the split covers every sample.
    print(len(train_idx), len(test_idx),
          n_tot - len(train_idx) - len(test_idx))

    train_ds = torch.utils.data.Subset(dataset, train_idx)
    test_ds = torch.utils.data.Subset(dataset, test_idx)

    model = unet.UNet()

    # Context manager so the config file handle is closed deterministically.
    with open('config.json') as config_file:
        config = json.load(config_file)

    b_size = config["train_loader"]["args"]["batch_size"]
    train_loader = DataLoader(train_ds,
                              batch_size=b_size,
                              shuffle=True,
                              num_workers=0)
    b_size = config["val_loader"]["args"]["batch_size"]
    val_loader = DataLoader(test_ds,
                            batch_size=b_size,
                            shuffle=True,
                            num_workers=0)

    trainer = Trainer(model, loss.loss_fn, config, train_loader, val_loader)
    print(f"Trainining on device: {trainer.device}")

    trainer.train()
コード例 #7
0
def get_model_fn(pipeline_config, result_folder, dataset_info, eval_split_name,
                 num_gpu, eval_dir):
    """Return ``_general_model_fn`` bound to a UNet feature extractor.

    When ``dataset_info`` is given, one file per patient of the
    ``eval_split_name`` split is picked (after a random shuffle) for
    visualization, limited by
    ``pipeline_config.eval_config.num_images_to_visualize`` (``None`` or
    ``-1`` means "all").

    Returns:
        A ``functools.partial`` of ``_general_model_fn`` with the pipeline
        configuration, feature extractor and visualization file names bound.

    Raises:
        ValueError: If the configured model type is not ``'unet'``.
    """
    if dataset_info is None:
        visualization_file_names = None
    else:
        file_names = dataset_info[
            standard_fields.PickledDatasetInfo.file_names][eval_split_name]
        # NOTE: this shuffles the list stored inside ``dataset_info`` in place.
        np.random.shuffle(file_names)

        patient_ids = dataset_info[
            standard_fields.PickledDatasetInfo.patient_ids][eval_split_name]

        # Select one image per patient: the first file seen after shuffling.
        selected_files = dict()
        for file_name in file_names:
            patient_id = _extract_patient_id(file_name)
            assert (patient_id in patient_ids)
            selected_files.setdefault(patient_id, file_name)

        num_visualizations = pipeline_config.eval_config.num_images_to_visualize
        if num_visualizations is None or num_visualizations == -1:
            num_visualizations = len(selected_files)
        else:
            num_visualizations = min(num_visualizations, len(selected_files))

        visualization_file_names = list(
            selected_files.values())[:num_visualizations]

    model_name = pipeline_config.model.WhichOneof('model_type')
    if model_name != 'unet':
        # An explicit exception survives ``python -O``; the former
        # ``assert (False)`` would be stripped and ``None`` returned instead.
        raise ValueError(
            'Unsupported model type {!r}; only \'unet\' is implemented'.format(
                model_name))

    feature_extractor = unet.UNet(
        weight_decay=pipeline_config.train_config.weight_decay,
        conv_padding=pipeline_config.model.conv_padding,
        filter_sizes=pipeline_config.model.unet.filter_sizes)
    return functools.partial(
        _general_model_fn,
        pipeline_config=pipeline_config,
        result_folder=result_folder,
        dataset_info=dataset_info,
        feature_extractor=feature_extractor,
        num_gpu=num_gpu,
        visualization_file_names=visualization_file_names,
        eval_dir=eval_dir)
コード例 #8
0
def net(net_params, rtn_level=False):
    """Instantiate the network named by ``net_params['global']['model_name']``.

    Args:
        net_params: Nested configuration mapping; the ``'global'`` section
            supplies the model name, class count and band count, the
            ``'models'`` section the per-model options.
        rtn_level: When truthy, additionally return the ``'MaxPoolCount'``
            entry of ``maxpool_level(model, number_of_bands, 256)``.

    Returns:
        ``(model, state_dict_path)`` or, with ``rtn_level``,
        ``(model, state_dict_path, max_pool_count)``. ``state_dict_path`` is
        ``''`` unless the model's ``'pretrained'`` option is truthy.

    Raises:
        ValueError: If the model name is not recognised.
    """
    global_cfg = net_params['global']
    model_name = global_cfg['model_name'].lower()
    state_dict_path = ''

    def build_with_dropout(factory, section):
        # Shared recipe for the UNet variants: (classes, bands, dropout,
        # probability) constructor plus an optional pretrained path.
        n_classes = global_cfg['num_classes']
        n_bands = global_cfg['number_of_bands']
        section_cfg = net_params['models'][section]
        built = factory(n_classes, n_bands, section_cfg['dropout'],
                        section_cfg['probability'])
        return built, section_cfg['pretrained'] or ''

    if model_name == 'unetsmall':
        model, state_dict_path = build_with_dropout(unet.UNetSmall,
                                                    'unetsmall')
    elif model_name == 'unet':
        model, state_dict_path = build_with_dropout(unet.UNet, 'unet')
    elif model_name == 'ternausnet':
        # Here 'pretrained' is handed to the constructor, not returned.
        model = TernausNet.ternausnet(
            global_cfg['num_classes'],
            net_params['models']['ternausnet']['pretrained'])
    elif model_name == 'checkpointed_unet':
        # Reads its options from the 'unetsmall' section.
        model, state_dict_path = build_with_dropout(
            checkpointed_unet.UNetSmall, 'unetsmall')
    elif model_name == 'inception':
        model = inception.Inception3(global_cfg['num_classes'],
                                     global_cfg['number_of_bands'])
        inception_weights = net_params['models']['inception']['pretrained']
        if inception_weights:
            state_dict_path = inception_weights
    else:
        raise ValueError('The model name in the config.yaml is not defined.')

    if not rtn_level:
        return model, state_dict_path

    lvl = maxpool_level(model, global_cfg['number_of_bands'], 256)
    return model, state_dict_path, lvl['MaxPoolCount']
コード例 #9
0
ファイル: test.py プロジェクト: RJ2019/BuildingFootprints
def main(hyperparameters, options):
    """Build a UNet, load the test data and hand both to ``best_pred``.

    Args:
        hyperparameters: Mapping with ``testing_batch_size``, ``depth``,
            ``wf``, ``pad``, ``batch_norm`` and ``up_mode``.
        options: Mapping with ``dataset``, ``in_channels``, ``n_classes``,
            ``predictions`` and ``saved_model``.
    """
    # Run options.
    dataset_path = options['dataset']
    n_input_channels = options['in_channels']
    n_output_classes = options['n_classes']
    predictions_out = options['predictions']
    weights_path = options['saved_model']

    # Network / loader hyperparameters.
    test_batch = hyperparameters['testing_batch_size']
    net_depth = hyperparameters['depth']
    wf_value = hyperparameters['wf']
    pad_flag = hyperparameters['pad']
    use_batch_norm = hyperparameters['batch_norm']
    upsampling = hyperparameters['up_mode']

    # UNet implementation from the models dir
    # (https://github.com/jvanvugt/pytorch-unet).
    unet_kwargs = dict(in_channels=n_input_channels,
                       n_classes=n_output_classes,
                       depth=net_depth,
                       wf=wf_value,
                       padding=pad_flag,
                       batch_norm=use_batch_norm,
                       up_mode=upsampling)
    model = unet.UNet(**unet_kwargs)

    # Test images together with the bookkeeping best_pred needs.
    test_img, orig_dim, num_test, pad_size = test_data(dataset_path,
                                                       net_depth, pad_flag)

    # Dataset wrapper in test mode: no labels, no shuffling.
    test_set = dataset_class.CustomDatasetFromTif(test_img, [],
                                                  num_test,
                                                  test_set=True,
                                                  shuffle_data=False)
    test_loader = DataLoader(dataset=test_set, batch_size=test_batch)

    # Use the GPU when one is available
    # (https://pytorch.org/docs/stable/notes/cuda.html).
    if torch.cuda.is_available():
        model = model.cuda()

    best_pred(model, weights_path, predictions_out, dataset_path, test_batch,
              pad_size, orig_dim, num_test, test_loader)
コード例 #10
0
ファイル: engine.py プロジェクト: DVLP-CMATERJU/SegFast
def load_model(model_name, noc):
    """Instantiate the segmentation network registered under ``model_name``.

    Args:
        model_name: One of ``fcn``, ``segnet``, ``pspnet``, ``unet``,
            ``segfast``, ``segfast_basic``, ``segfast_mobile``,
            ``segfast_v2_3`` or ``segfast_v2_5``.
        noc: Forwarded unchanged to the selected model constructor.

    Returns:
        The freshly constructed model.

    Raises:
        ValueError: If ``model_name`` is not a known model. (Previously this
            surfaced as an ``UnboundLocalError`` at the return statement.)
    """
    # Lambdas keep construction lazy: only the requested model is built and
    # only its module is touched.
    builders = {
        'fcn': lambda: fcn.FCN8s(noc),
        'segnet': lambda: segnet.SegNet(3, noc),
        'pspnet': lambda: pspnet.PSPNet(noc),
        'unet': lambda: unet.UNet(noc),
        'segfast': lambda: segfast.SegFast(64, noc),
        'segfast_basic': lambda: segfast_basic.SegFast_Basic(64, noc),
        'segfast_mobile': lambda: segfast_mobile.SegFast_Mobile(noc),
        'segfast_v2_3': lambda: segfast_v2.SegFast_V2(64, noc, 3),
        'segfast_v2_5': lambda: segfast_v2.SegFast_V2(64, noc, 5),
    }
    build = builders.get(model_name)
    if build is None:
        raise ValueError('Unknown model name {!r}; expected one of: {}'.format(
            model_name, ', '.join(sorted(builders))))
    return build()
コード例 #11
0
def train(config):
    """Train a UNet and keep the weights with the best validation accuracy.

    ``config`` is read for ``num_classes``, ``train_list``, ``test_list``,
    ``batch_size``, ``lr``, ``momentum`` (used as Adam's first beta),
    ``weight_decay``, ``num_epoch`` and ``save_model.save_path``.

    After every epoch the validation set is evaluated; whenever the running
    pixel accuracy exceeds the best value so far, the model's ``state_dict``
    is written to ``<save_path>/unet.pth``.
    """
    # Train configuration. Fall back to the CPU so the loop also runs on
    # machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = unet.UNet(num_classes=config['num_classes'])

    # model = nn.DataParallel(model, device_ids=[0, 1])
    model.to(device)

    logger = initLogger("unet")

    # Loss
    criterion = nn.CrossEntropyLoss()

    # Train data (with light augmentation).
    train_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),

            transforms.ToTensor()
            # Input images are single-channel.
            #transforms.Normalize((0.5, ), (0.5, ))
        ]
    )
    dst_train = unet_dataset.UnetDataset(config['train_list'], transform=train_transform)
    dataloader_train = DataLoader(dst_train, shuffle=True, batch_size=config['batch_size'])

    # Validation data (no augmentation).
    valid_transform = transforms.Compose(
        [
            #transforms.ToPILImage(),
            transforms.ToTensor()
            # Input images are single-channel.
            #transforms.Normalize((0.5, ), (0.5, ))
        ]
    )
    dst_valid = unet_dataset.UnetDataset(config['test_list'], transform=valid_transform)
    dataloader_valid = DataLoader(dst_valid, shuffle=False, batch_size=config['batch_size'])

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], betas=[config['momentum'], 0.999], weight_decay=config['weight_decay'])

    max_pixACC = 0.0
    for epoch in range(config['num_epoch']):
        epoch_start = time.time()

        # ---- training pass ----
        model.train()
        loss_sum = 0.0
        correct_sum = 0.0
        labeled_sum = 0.0
        inter_sum = 0.0
        unoin_sum = 0.0
        pixelAcc = 0.0
        IoU = 0.0
        tbar = tqdm(dataloader_train, ncols=100)
        for batch_idx, (data, target) in enumerate(tbar):
            tic = time.time()

            # The model lives on ``device``; the batch has to be there too,
            # otherwise the forward pass fails with a device mismatch.
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss_sum += loss.item()
            loss.backward()
            optimizer.step()

            correct, labeled, inter, unoin = eval_metrics(output, target, config['num_classes'])
            correct_sum += correct
            labeled_sum += labeled
            inter_sum += inter
            unoin_sum += unoin
            # np.spacing(1) keeps the denominators non-zero.
            pixelAcc = 1.0 * correct_sum / (np.spacing(1) + labeled_sum)
            IoU = 1.0 * inter_sum / (np.spacing(1) + unoin_sum)
            tbar.set_description('TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.4f} | bt {:.2f} et {:.2f}|'.format(
                epoch, loss_sum / ((batch_idx + 1) * config['batch_size']),
                pixelAcc, IoU.mean(),
                time.time() - tic, time.time() - epoch_start))

        logger.info('TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} IOU {}  mIoU {:.4f} '.format(
            epoch, loss_sum / ((batch_idx + 1) * config['batch_size']),
            pixelAcc, toString(IoU), IoU.mean()))

        # ---- validation pass ----
        test_start = time.time()

        model.eval()
        loss_sum = 0.0
        correct_sum = 0.0
        labeled_sum = 0.0
        inter_sum = 0.0
        unoin_sum = 0.0
        pixelAcc = 0.0
        mIoU = 0.0
        tbar = tqdm(dataloader_valid, ncols=100)
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(tbar):
                tic = time.time()

                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                loss_sum += loss.item()

                correct, labeled, inter, unoin = eval_metrics(output, target, config['num_classes'])
                correct_sum += correct
                labeled_sum += labeled
                inter_sum += inter
                unoin_sum += unoin
                pixelAcc = 1.0 * correct_sum / (np.spacing(1) + labeled_sum)
                mIoU = 1.0 * inter_sum / (np.spacing(1) + unoin_sum)
                tbar.set_description('VAL ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.4f} | bt {:.2f} et {:.2f}|'.format(
                    epoch, loss_sum / ((batch_idx + 1) * config['batch_size']),
                    pixelAcc, mIoU.mean(),
                    time.time() - tic, time.time() - test_start))

        # Checkpoint only when the validation pixel accuracy improves.
        if pixelAcc > max_pixACC:
            max_pixACC = pixelAcc
            save_dir = config['save_model']['save_path']
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(save_dir, 'unet.pth'))
        logger.info('VAL ({}) | Loss: {:.3f} | Acc {:.2f} IOU {} mIoU {:.4f} |'.format(
            epoch, loss_sum / ((batch_idx + 1) * config['batch_size']),
            pixelAcc, toString(mIoU), mIoU.mean()))
コード例 #12
0
def net(net_params, num_channels, inference=False):
    """Define the neural net and pick the checkpoint to initialise it from.

    Args:
        net_params: Nested configuration mapping with ``'global'`` and
            ``'training'`` sections (and ``'inference'`` when ``inference``
            is true).
        num_channels: Passed to the model constructors as the number of
            output classes.
        inference: When true, the checkpoint at
            ``net_params['inference']['state_dict_path']`` is loaded and
            pretrained weights are never fetched.

    Returns:
        Tuple ``(model, checkpoint, model_name)``; ``checkpoint`` is ``None``
        when no checkpoint applies.

    Raises:
        ValueError: If the model name is not recognised.
        AssertionError: If a 3-band-only model gets another band count, or a
            configured checkpoint file does not exist.
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    train_state_dict_path = get_key_def('state_dict_path', net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'], True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(pretrained=False, progress=True, num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        try:
            # Only a patched torchvision accepts ``in_channels`` here.
            model = models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True, in_channels=num_bands,
                                                            num_classes=num_channels, aux_loss=None)
        except (TypeError, ValueError):
            # Narrowed from a bare ``except`` so that KeyboardInterrupt and
            # unrelated failures are no longer swallowed.
            assert num_bands == 3, 'Edit torchvision scripts segmentation.py and resnet.py to build deeplabv3_resnet ' \
                                   'with more or less than 3 bands'
            model = models.segmentation.deeplabv3_resnet101(pretrained=False, progress=True,
                                                            num_classes=num_channels, aux_loss=None)
    else:
        raise ValueError(f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'], False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'], True)
        normalized = get_key_def('coordconv_normalized', net_params['global'], True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel', net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, centered=centered, normalized=normalized, noise=noise,
                                                radius_channel=radius_channel, scale=scale)

    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(net_params['inference']['state_dict_path']).is_file(), f"Could not locate {net_params['inference']['state_dict_path']}"
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        assert Path(train_state_dict_path).is_file(), f'Could not locate checkpoint at {train_state_dict_path}'
        checkpoint = load_checkpoint(train_state_dict_path)
    elif pretrained and model_name in ('deeplabv3_resnet101', 'fcn_resnet101'):
        # Membership test: the former ``== ('a' or 'b')`` only ever compared
        # against 'deeplabv3_resnet101', so fcn_resnet101 never got here.
        print(f'Retrieving coco checkpoint for {model_name}...\n')
        if model_name == 'deeplabv3_resnet101':  # default to pretrained on coco (21 classes)
            coco_model = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=True, num_classes=21, aux_loss=None)
        else:
            coco_model = models.segmentation.fcn_resnet101(pretrained=True, progress=True, num_classes=21, aux_loss=None)
        checkpoint = coco_model.state_dict()
        # Place entire state_dict inside 'model' key for compatibility with the rest of GDL workflow
        temp_checkpoint = {}
        temp_checkpoint['model'] = {k: v for k, v in checkpoint.items()}
        del coco_model, checkpoint
        checkpoint = temp_checkpoint
    elif pretrained:
        warnings.warn(f'No pretrained checkpoint found for {model_name}.')
        checkpoint = None
    else:
        checkpoint = None

    return model, checkpoint, model_name
コード例 #13
0
def _replace_deeplab_classifier_head(model, num_channels):
    """Swap the last ``model.classifier`` layer for a 1x1 conv with ``num_channels`` outputs."""
    classifier = list(model.classifier.children())
    model.classifier = nn.Sequential(*classifier[:-1])
    # The new layer is registered under the name '4'.
    model.classifier.add_module(
        '4',
        nn.Conv2d(classifier[-1].in_channels,
                  num_channels,
                  kernel_size=(1, 1)))


def net(model_name: str,
        num_bands: int,
        num_channels: int,
        dontcare_val: int,
        num_devices: int,
        train_state_dict_path: str = None,
        pretrained: bool = True,
        dropout_prob: float = False,
        loss_fn: str = None,
        optimizer: str = None,
        class_weights: Sequence = None,
        net_params=None,
        conc_point: str = None,
        coordconv_params=None,
        inference_state_dict: str = None):
    """Define the neural net and, for training, its device and hyperparameters.

    Args:
        model_name: Name of the architecture to build.
        num_bands: Number of input bands handed to the model.
        num_channels: Passed to the model constructors as the number of
            output classes.
        dontcare_val: Forwarded to ``set_hyperparameters``.
        num_devices: Number of CUDA devices requested via ``get_device_ids``.
        train_state_dict_path: Optional checkpoint to resume training from.
        pretrained: Use pretrained weights; forced to ``False`` when either
            state-dict path is given.
        dropout_prob: Dropout probability; any falsy value disables dropout.
        loss_fn, optimizer, class_weights, net_params: Forwarded to
            ``set_hyperparameters``.
        conc_point: Concatenation point for 4-band deeplabv3; ``'baseline'``
            extends ``conv1`` instead of wrapping in ``LayersEnsemble``.
        coordconv_params: Options for the optional CoordConv layer swap.
        inference_state_dict: Checkpoint path for inference.

    Returns:
        With ``inference_state_dict``: ``(model, checkpoint, model_name)``.
        Otherwise: ``(model, model_name, criterion, optimizer, lr_scheduler,
        device, gpu_devices_dict)``.

    Raises:
        NotImplementedError: If ``num_bands`` is unsupported by the model.
        ValueError: If ``model_name`` is not recognised.
    """
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    pretrained = False if train_state_dict_path or inference_state_dict else pretrained
    dropout = bool(dropout_prob)
    model = None

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        if num_bands != 3:
            raise NotImplementedError(msg)
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout,
                                            dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        if num_bands != 3:
            raise NotImplementedError(msg)
        model = models.segmentation.fcn_resnet101(pretrained=False,
                                                  progress=True,
                                                  num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        if num_bands not in (3, 4):
            # This branch accepts 4 bands too, so say so in the message.
            raise NotImplementedError(
                'Number of bands specified incompatible with this model. '
                'Requires 3 or 4 band data.')
        model = models.segmentation.deeplabv3_resnet101(
            pretrained=pretrained, progress=True)
        if num_bands == 3:
            _replace_deeplab_classifier_head(model, num_channels)
        elif conc_point == 'baseline':
            logging.info(
                'Testing with 4 bands, concatenating at {}.'.format(
                    conc_point))
            # Widen conv1 to four input channels by appending a copy of the
            # weights at channel index 1 (reuse green weights for infrared).
            conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
            depth = np.expand_dims(conv1[:, 1, ...], axis=1)
            conv1 = np.append(conv1, depth, axis=1)
            conv1 = torch.from_numpy(conv1).float()
            model.backbone._modules['conv1'].weight = nn.Parameter(
                conv1, requires_grad=True)
            _replace_deeplab_classifier_head(model, num_channels)
        else:
            _replace_deeplab_classifier_head(model, num_channels)
            conc_point = 'conv1' if not conc_point else conc_point
            model = LayersEnsemble(model, conc_point=conc_point)

        logging.info(
            f'Finetuning pretrained deeplabv3 with {num_bands} input channels (imagery bands). '
            f'Concatenation point: "{conc_point}"')

    elif model_name in lm_smp.keys():
        lsmp = lm_smp[model_name]
        # TODO: add possibility of our own weights
        # NOTE: this updates the shared ``lm_smp`` entry in place.
        lsmp['params'][
            'encoder_weights'] = "imagenet" if 'pretrained' in model_name.split(
                "_") else None
        lsmp['params']['in_channels'] = num_bands
        lsmp['params']['classes'] = num_channels
        lsmp['params']['activation'] = None

        model = lsmp['fct'](**lsmp['params'])

    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', coordconv_params,
                                    False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', coordconv_params, True)
        normalized = get_key_def('coordconv_normalized', coordconv_params,
                                 True)
        noise = get_key_def('coordconv_noise', coordconv_params, None)
        radius_channel = get_key_def('coordconv_radius_channel',
                                     coordconv_params, False)
        scale = get_key_def('coordconv_scale', coordconv_params, 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    if inference_state_dict:
        # Inference needs neither device placement nor hyperparameters.
        checkpoint = load_checkpoint(inference_state_dict)
        return model, checkpoint, model_name

    if train_state_dict_path is not None:
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    gpu_devices_dict = get_device_ids(num_devices)
    # Keep the requested count for the log line before it is replaced by
    # the number of devices actually obtained.
    num_requested = num_devices
    num_devices = len(gpu_devices_dict.keys())
    logging.info(
        f"Number of cuda devices requested: {num_requested}. "
        f"Cuda devices available: {list(gpu_devices_dict.keys())}\n")
    if num_devices == 1:
        logging.info(
            f"Using Cuda device 'cuda:{list(gpu_devices_dict.keys())[0]}'")
    elif num_devices > 1:
        logging.info(
            f"Using data parallel on devices: {list(gpu_devices_dict.keys())[1:]}. "
            f"Main device: 'cuda:{list(gpu_devices_dict.keys())[0]}'")
        try:  # For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            # DataParallel adds prefix 'module.' to state_dict keys
            model = nn.DataParallel(model,
                                    device_ids=list(gpu_devices_dict.keys()))
        except AssertionError:
            logging.warning(
                f"Unable to use devices with ids {gpu_devices_dict.keys()}. "
                f"Trying devices with ids {list(range(len(gpu_devices_dict.keys())))}"
            )
            model = nn.DataParallel(
                model,
                device_ids=list(range(len(gpu_devices_dict.keys()))))
    else:
        logging.warning(
            "No Cuda device available. This process will only run on CPU\n")
    logging.info(
        'Setting model, criterion, optimizer and learning rate scheduler...\n')
    device = torch.device(
        f'cuda:{list(range(len(gpu_devices_dict.keys())))[0]}'
        if gpu_devices_dict else 'cpu')
    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except AssertionError:
        logging.exception("Unable to use device. Trying device 0...\n")
        device = torch.device('cuda' if gpu_devices_dict else 'cpu')
        model.to(device)

    model, criterion, optimizer, lr_scheduler = set_hyperparameters(
        params=net_params,
        num_classes=num_channels,
        model=model,
        checkpoint=checkpoint,
        dontcare_val=dontcare_val,
        loss_fn=loss_fn,
        optimizer=optimizer,
        class_weights=class_weights,
        inference=inference_state_dict)
    criterion = criterion.to(device)

    return model, model_name, criterion, optimizer, lr_scheduler, device, gpu_devices_dict
コード例 #14
0
ファイル: sampler.py プロジェクト: light-dawn/Active_Learning
    def sample(self, dataloader):
        """Placeholder: always raises ``NotImplementedError``.

        ``dataloader`` is accepted but not used by this stub.
        """
        raise NotImplementedError


class EntropySampler:
    """Sampler stub that records a budget; ``sample`` is not implemented.

    NOTE(review): the name suggests entropy-based selection, but no selection
    logic exists in this class yet - confirm the intended behaviour.
    """

    def __init__(self, budget):
        """Store ``budget`` unchanged on the instance."""
        self.budget = budget

    def sample(self, dataloader):
        """Placeholder: always raises ``NotImplementedError``."""
        raise NotImplementedError



if __name__ == "__main__":
    # Demo entry point: build the UNet described by the "model" section of
    # demo_cfg.json and print what the sampler returns as its embedding layer.
    with open("demo_cfg.json", "r") as cfg_file:
        config = json.load(cfg_file)
    model_cfg = config["model"]
    model = unet.UNet(n_channels=model_cfg["n_channels"],
                      n_classes=model_cfg["n_classes"])
    sampler = OMedALSampler(100, model)
    print(sampler.get_embedding_layer())
    # print(len(layer))


        
        


            


コード例 #15
0
def net(net_params, inference=False):
    """Build the configured model and load the checkpoint that applies to it.

    Args:
        net_params: Nested config mapping. Reads ``global.model_name``,
            ``global.num_classes``, ``global.number_of_bands`` (for the
            architectures that take or check it), ``training.dropout`` and
            ``training.dropout_prob`` (U-Net variants), the optional
            ``global.coordconv_*`` keys, ``training.state_dict_path`` and,
            in inference, ``inference.state_dict_path``.
        inference: Only consulted when ``training.state_dict_path`` is falsy;
            the checkpoint is then read from ``inference.state_dict_path``.

    Returns:
        ``(model, checkpoint, model_name)``. ``checkpoint`` is ``None`` when
        no state-dict path applies; ``model_name`` is lower-cased.

    Raises:
        AssertionError: A 3-band-only architecture was requested with a
            different ``number_of_bands``.
        ValueError: ``model_name`` is not a supported architecture.
    """
    model_name = net_params['global']['model_name'].lower()
    num_classes = net_params['global']['num_classes']
    if num_classes == 1:
        warnings.warn(
            "config specified that number of classes is 1, but model will be instantiated"
            " with a minimum of two regardless (will assume that 'background' exists)"
        )
        num_classes = 2
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'

    if model_name in ('unetsmall', 'unet', 'checkpointed_unet'):
        # The three U-Net flavours share one constructor signature.
        if model_name == 'unetsmall':
            factory = unet.UNetSmall
        elif model_name == 'unet':
            factory = unet.UNet
        else:
            factory = checkpointed_unet.UNetSmall
        bands = net_params['global']['number_of_bands']
        training_cfg = net_params['training']
        model = factory(num_classes, bands, training_cfg['dropout'],
                        training_cfg['dropout_prob'])
    elif model_name == 'ternausnet':
        assert net_params['global']['number_of_bands'] == 3, msg
        model = TernausNet.ternausnet(num_classes)
    elif model_name == 'inception':
        model = inception.Inception3(num_classes,
                                     net_params['global']['number_of_bands'])
    elif model_name in ('fcn_resnet101', 'deeplabv3_resnet101'):
        assert net_params['global']['number_of_bands'] == 3, msg
        # model_name equals the torchvision constructor name for both nets.
        build = getattr(models.segmentation, model_name)
        # Pretrained on COCO (21 classes); used only as a weight donor.
        coco_model = build(pretrained=True,
                           progress=True,
                           num_classes=21,
                           aux_loss=None)
        model = build(pretrained=False,
                      progress=True,
                      num_classes=num_classes,
                      aux_loss=None)
        chopped_dict = chop_layer(coco_model.state_dict(),
                                  layer_names=['classifier.4'])
        del coco_model
        # strict=False loads only the entries the two state dicts share, so
        # the chopped classifier layer keeps its fresh initialisation.
        model.load_state_dict(chopped_dict, strict=False)
    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    global_cfg = net_params['global']
    if get_key_def('coordconv_convert', global_cfg, False):
        coordconv_kwargs = {
            'centered': get_key_def('coordconv_centered', global_cfg, True),
            'normalized': get_key_def('coordconv_normalized', global_cfg,
                                      True),
            'noise': get_key_def('coordconv_noise', global_cfg, None),
            'radius_channel': get_key_def('coordconv_radius_channel',
                                          global_cfg, False),
            'scale': get_key_def('coordconv_scale', global_cfg, 1.0),
        }
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, **coordconv_kwargs)

    # A non-empty training path wins over the inference path.
    training_path = net_params['training']['state_dict_path']
    if training_path:
        checkpoint = load_checkpoint(training_path)
    elif inference:
        checkpoint = load_checkpoint(
            net_params['inference']['state_dict_path'])
    else:
        checkpoint = None

    return model, checkpoint, model_name
コード例 #16
0
ファイル: test.py プロジェクト: chitwansaharia/EyeInTheSky
def main():
	"""Segment every ``.tif`` under ``args.test_data`` with a saved model.

	For each image: normalise every channel with the module-level ``mean`` /
	``std``, optionally reflect-pad by ``crop_dim // 10`` (``args.crop_end``),
	predict with ``return_pred_image``, colour the per-pixel argmax through
	``map_dict`` and save the result under ``args.out_dir`` using the input
	file name. When ``args.label_data`` is set, predictions are scored against
	``<stem>.npy`` label files; when ``args.pkl_dir`` is set, the raw
	prediction arrays are pickled to ``<args.model>.pkl``.
	"""
	args = parser.parse_args()
	mask_dim = args.crop_dim

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	model_path = os.path.join(args.data_dir, 'saved_models', args.model)
	# map_location lets a checkpoint saved on GPU load on a CPU-only host,
	# which is the fallback `device` above already allows for.
	saved_data = torch.load(model_path, map_location=device)
	if args.ternaus:
		model = TernausNetV2(num_classes = num_classes).to(device)
	else:
		model = unet.UNet(args.num_channels, num_classes).to(device)
	model.load_state_dict(saved_data["model_state_dict"])

	predicted_labels = []
	predicted_labels_images = {}
	true_labels = []

	# Also creates missing parent directories and tolerates an existing one.
	os.makedirs(args.out_dir, exist_ok=True)

	for tif_name in os.listdir(args.test_data):
		if not tif_name.endswith('.tif'):
			continue
		print("Processing file: {}".format(tif_name))

		base_image = tiff.imread(os.path.join(args.test_data, tif_name))
		base_image = base_image.astype(float)
		num_channels = base_image.shape[2]
		for i in range(num_channels):
			base_image[:,:,i] = (base_image[:,:,i]-mean[i])/std[i]

		margin = mask_dim // 10
		if args.crop_end:
			# Reflect-pad the two spatial axes only; channels are untouched.
			base_image = np.pad(
				base_image,
				((margin, margin), (margin, margin), (0, 0)),
				'reflect')

		pred_image = return_pred_image(args, base_image, model, device)
		predicted_labels_images[tif_name] = pred_image

		# Collapse the last axis to one label index per pixel.
		pred_image = np.argmax(pred_image, 2)

		color_image = np.zeros([base_image.shape[0], base_image.shape[1], 3])
		for ix,iy in np.ndindex(pred_image.shape):
			color_image[ix, iy, :] = map_dict[pred_image[ix, iy]]
		im = Image.fromarray(np.uint8(color_image))

		im.save(os.path.join(args.out_dir, tif_name))

		if args.label_data is not None:
			# Label files are named after the text before the first dot.
			stem = tif_name.split('.')[0]
			label_path = os.path.join(args.label_data, '{}.npy'.format(stem))
			predicted_labels.extend(pred_image.reshape(-1))
			true_labels.extend(np.load(label_path).reshape(-1))

	if args.label_data is not None:
		predicted_labels = np.array(predicted_labels)
		true_labels = np.array(true_labels)
		print("Accuracy on these images : {}".format(
			sklearn.metrics.accuracy_score(predicted_labels, true_labels)))
		print("Cohen kappa score on these images : {}".format(
			sklearn.metrics.cohen_kappa_score(predicted_labels, true_labels)))
		print("Confusion matrix on these images : {}".format(
			sklearn.metrics.confusion_matrix(true_labels, predicted_labels)))
		print("Precision recall class F1 : {}".format(
			sklearn.metrics.precision_recall_fscore_support(
				true_labels, predicted_labels)))

	if args.pkl_dir is not None:
		with open(os.path.join(args.pkl_dir, args.model+'.pkl'),'wb') as f:
			pkl.dump(predicted_labels_images, f)
コード例 #17
0
ファイル: utils.py プロジェクト: light-dawn/Active_Learning
 def unet(conf):
     """Return ``unet.UNet(conf["n_channels"], conf["n_classes"])``.

     NOTE(review): this function is itself named ``unet``. If it is defined
     in the same namespace that holds the imported ``unet`` module, the name
     lookup in the body resolves to this function and ``.UNet`` raises
     ``AttributeError`` - confirm how the module is imported here.
     """
     return unet.UNet(conf["n_channels"], conf["n_classes"])
コード例 #18
0
ファイル: train_RV.py プロジェクト: XiaowenK/UNet_Family
def train():
    """Train one model per cross-validation fold, keeping each fold's best weights.

    Parses the command line with the module-level ``parser``. For every fold
    ``k`` in ``1..args.fold`` it builds the network named by ``args.arch`` and
    the optimizer named by ``args.optm``, optionally enables apex mixed
    precision, wraps the model in ``DataParallel`` and trains for
    ``args.epoch`` epochs with the soft Dice loss on sigmoid outputs. After
    each epoch ``validation(model, k, args)`` provides the loss that drives
    ``ReduceLROnPlateau`` and decides whether the state dict is written to
    ``./checkpoints/fold_{k}.pth.tar``.

    Raises:
        ValueError: ``args.arch`` or ``args.optm`` is not a supported name.
    """
    args = parser.parse_args()

    for k in range(1, args.fold + 1):
        print("========== Fold {} ==========".format(k))
        # ---- setting model & optimizer ----
        if args.arch == "UNet":
            model = unet.UNet(args.num_class).to(device)
        elif args.arch == "Att_UNet":
            model = att_unet.Att_UNet(args.num_class).to(device)
        elif args.arch == "UNet_PP":
            model = unet_pp.UNet_PP(args.num_class).to(device)
        elif args.arch == "ResUNet50":
            model = resunet.ResUNet50(args.num_class,
                                      args.pretrained).to(device)
        elif args.arch == "ResUNet101":
            model = resunet.ResUNet101(args.num_class,
                                       args.pretrained).to(device)
        elif args.arch == "ResUNext101":
            model = resunext.ResUNext101(args.num_class,
                                         args.pretrained).to(device)
        else:
            # Fail with a clear message instead of an unbound `model` below.
            raise ValueError("Unsupported architecture: {!r}".format(
                args.arch))

        if args.optm == "SGD":
            optimizer = optim.SGD(model.parameters(), lr=args.lr)
        elif args.optm == "Adam":
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(0.5, 0.999))
        else:
            raise ValueError("Unsupported optimizer: {!r}".format(args.optm))
        scheduler = ReduceLROnPlateau(optimizer,
                                      'min',
                                      factor=0.5,
                                      patience=5,
                                      verbose=True,
                                      eps=1e-6)

        if args.apex:
            # Keep sigmoid/softmax in float32 under mixed precision.
            amp.register_float_function(torch, 'sigmoid')
            amp.register_float_function(F, 'softmax')
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

        model = torch.nn.DataParallel(model)

        # ---- start training ----
        path_txt_train = r"{}/fold_{}/train.txt".format(args.txt, k)
        # Any finite first validation loss counts as an improvement.
        loss_min = float("inf")
        for epoch in range(1, args.epoch + 1):
            # Re-enter training mode every epoch: if validation() switches the
            # model to eval mode, a single call before the loop would leave
            # every epoch after the first training in eval mode.
            model.train()
            # ---- timer ----
            starttime = datetime.datetime.now()
            # ---- loading data ----
            dataset = RetinaVesselDataset(args.data, path_txt_train,
                                          args.height, args.width,
                                          args.pretrained)
            data = DataLoader(dataset=dataset,
                              batch_size=args.bs,
                              shuffle=True,
                              num_workers=12,
                              pin_memory=True)
            # ---- loop for all train data ----
            loss_train_sum = 0
            for batch_data in data:
                # ---- inputs & masks ----
                inputs = batch_data['image'].to(device, dtype=torch.float)
                masks = batch_data['mask'].to(device, dtype=torch.float)
                # ---- fp ----
                outputs = model(inputs)
                outputs = torch.sigmoid(outputs)
                # ---- bp ----
                loss = soft_dice_coef_loss(outputs, masks)
                loss_train_sum += loss.item()
                if args.apex:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            # ---- train loss ----
            loss_train = loss_train_sum / len(data)
            # ---- validation ----
            loss_val = validation(model, k, args)
            scheduler.step(loss_val)
            # ---- saving ckpt ----
            if loss_val < loss_min:
                loss_min = loss_val
                print("Best model saved at epoch {}!".format(epoch))
                # Saved from the DataParallel wrapper, so keys carry the
                # 'module.' prefix.
                torch.save(model.state_dict(),
                           r"./checkpoints/fold_{}.pth.tar".format(k))
            # ---- timer ----
            endtime = datetime.datetime.now()
            elapsed = (endtime - starttime).seconds
            # ---- printing ----
            print(
                "fold #{}, epoch #{}, train: {:.4f}, val: {:.4f}, elapsed: {}s"
                .format(k, epoch, loss_train, loss_val, elapsed))
            print('-' * 60)
コード例 #19
0
    #IMG_CHANNELS = 3
    TRAIN_PATH = '../data/stage1_train/'
    TEST_PATH = '../data/stage1_test/'

    #seed = 42
    #random.seed = seed
    #np.random.seed = seed

    # Get train and test IDs
    X_train, Y_train, X_test = data.load_data(TRAIN_PATH, TEST_PATH)

    img_size = 128

    train_dataset = augment.NucleusDataset(X_train, Y_train)
    train_dataloader = DataLoader(train_dataset, batch_size=12, shuffle=True)
    model = unet.UNet(3, 2)
    use_cuda = False
    gpu_ids = [0, 1]
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    mean = float(np.mean([np.mean(np.array(x)[:, :, :3]) for x in X_test]))
    std = float(np.std([np.std(np.array(x)[:, :, :3]) for x in X_test]))

    # Predict on train, val and test
    n_epochs = 1
    for _ in tqdm(range(n_epochs)):
        for data in train_dataloader:
            train_dataset.gen_params(img_size)
            optimizer.zero_grad()
            if use_cuda:
                X = Variable(data['X']).cuda()
コード例 #20
0
ファイル: train.py プロジェクト: parkchaesong/DeepCaffeine

if __name__ == "__main__":

    # `args` comes from module scope; the rest of the script calls it `opt`.
    opt = args
    print(opt)

    train_data_loader, valid_data_loader = get_data_loader(opt)

    # Make sure the log directory exists before building the log path.
    if not os.path.exists(opt.log_dir):
        os.makedirs(opt.log_dir)

    # One CSV log per model name, e.g. "<log_dir>/unet_log.csv".
    log_file = os.path.join(opt.log_dir, '%s_log.csv' % (opt.model))

    # NOTE(review): `net` is only bound here for opt.model == 'unet'; unless
    # it is defined elsewhere, print(net) below raises NameError for any
    # other model name - confirm.
    if opt.model == 'unet':
        net = unet.UNet(opt.num_class + 1)

    loss_criterion = set_loss(opt)

    print(net)

    print('===> Setting GPU')
    print("CUDA Available", torch.cuda.is_available())

    # Use CUDA only when it is both requested and available; record the
    # decision back on `opt`.
    if opt.use_cuda and torch.cuda.is_available():
        opt.use_cuda = True
        opt.device = 'cuda'
    else:
        opt.use_cuda = False
        opt.device = 'cpu'
コード例 #21
0
def test():
    """Evaluate each fold's saved checkpoint and print mean/std of loss and DSC.

    Parses the command line with the module-level ``parser``. For every fold
    ``k`` in ``1..args.fold`` it rebuilds the network named by ``args.arch``,
    wraps it in ``DataParallel``, loads ``./checkpoints/fold_{k}.pth.tar``,
    and averages the soft Dice loss and Dice coefficient of the sigmoid
    outputs over the fold's test loader. The per-fold averages are summarised
    with ``calMeanStdDev`` and printed.

    Raises:
        ValueError: ``args.arch`` is not a supported architecture name.
    """
    args = parser.parse_args()

    # ---- start testing ----
    list_loss_fold = []
    list_dsc_fold = []
    for k in range(1, args.fold + 1):
        print("========== Fold {} ==========".format(k))
        # ---- setting model ----
        if args.arch == "UNet":
            model = unet.UNet(args.num_class)
        elif args.arch == "Att_UNet":
            model = att_unet.Att_UNet(args.num_class)
        elif args.arch == "UNet_PP":
            model = unet_pp.UNet_PP(args.num_class)
        elif args.arch == "ResUNet50":
            model = resunet.ResUNet50(args.num_class, args.pretrained)
        elif args.arch == "ResUNet101":
            model = resunet.ResUNet101(args.num_class, args.pretrained)
        elif args.arch == "ResUNext101":
            model = resunext.ResUNext101(args.num_class, args.pretrained)
        else:
            # Fail with a clear message instead of an unbound `model` below.
            raise ValueError("Unsupported architecture: {!r}".format(
                args.arch))

        # Wrapped before loading so the state-dict keys match this model's.
        model = torch.nn.DataParallel(model).to(device)

        # map_location keeps a GPU-saved checkpoint loadable when `device`
        # is the CPU.
        model_ckpt = torch.load(r"./checkpoints/fold_{}.pth.tar".format(k),
                                map_location=device)
        model.load_state_dict(model_ckpt)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        # ---- loading data ----
        path_txt_test = r"{}/fold_{}/test.txt".format(args.txt, k)
        dataset = RetinaVesselDataset(args.data, path_txt_test, args.height,
                                      args.width, args.pretrained)
        data = DataLoader(dataset=dataset,
                          batch_size=args.bs,
                          num_workers=12,
                          pin_memory=True)
        # ---- init loss ----
        loss_sum = 0
        dsc_sum = 0
        # ---- start validating ----
        for batch_data in tqdm(data):
            # ---- inputs & masks ----
            inputs = batch_data['image'].to(device, dtype=torch.float)
            masks = batch_data['mask'].to(device, dtype=torch.float)
            # ---- fp ----
            outputs = model(inputs)
            outputs = torch.sigmoid(outputs)
            # ---- metrics ----
            loss = soft_dice_coef_loss(outputs, masks)
            dsc = dice_coef(outputs, masks)
            loss_sum += loss.item()
            dsc_sum += dsc.item()
        # ---- average loss ----
        loss_test = loss_sum / len(data)
        list_loss_fold.append(loss_test)
        # ---- average DSC ----
        dsc_test = dsc_sum / len(data)
        list_dsc_fold.append(dsc_test)

    # ---- Dice Coefficient (DSC) and Standard Deviation ----
    mean_loss, std_dev_loss = calMeanStdDev(list_loss_fold)
    mean_dsc, std_dev_dsc = calMeanStdDev(list_dsc_fold)
    print("Loss: Mean({:.3f}), Standard Deviation({:.3f}) ".format(
        mean_loss, std_dev_loss))
    print("DSC: Mean({:.3f}), Standard Deviation({:.3f}) ".format(
        mean_dsc, std_dev_dsc))
コード例 #22
0
 def __init__(self, args):
     """Create ``self.model`` as ``unet.UNet(args.in_ch, args.out_ch, args.base_kernel)``."""
     self.model = unet.UNet(args.in_ch, args.out_ch, args.base_kernel)
コード例 #23
0
ファイル: train.py プロジェクト: chitwansaharia/EyeInTheSky
 # Training and Validation Dataset Loader
 train_loader = torch.utils.data.DataLoader(SatelliteDataset(
     train_x_dir, train_y_dir, root_dir, args.crop_dim, args.num_channels,
     args.contrast_enhance, args.gaussian_blur, args.rescale_intensity),
                                            batch_size=args.batch_size,
                                            shuffle=True)
 val_loader = torch.utils.data.DataLoader(SatelliteDataset(
     val_x_dir, val_y_dir, root_dir, args.crop_dim, args.num_channels,
     args.contrast_enhance, args.gaussian_blur, args.rescale_intensity),
                                          batch_size=args.batch_size,
                                          shuffle=False)
 # Training for a single class
 if args.train_per_class:
     loss_criterion = nn.CrossEntropyLoss(
         torch.tensor([1, args.class_weight]).to(device))
     model = unet.UNet(args.num_channels, 2)
 # Using Ternaus Net Architecture
 elif args.ternaus:
     model = ternausnet2.TernausNetV2(num_classes)
     # Using pretrained weights
     state = torch.load('./deepglobe_buildings.pt')
     state = {
         key.replace('module.', '').replace('bn.', ''): value
         for key, value in state['model'].items()
         if key.startswith('module.conv')
     }
     model_dict = model.state_dict()
     model_dict.update(state)
     model.load_state_dict(model_dict)
     model.conv1_correct.conv1_correct.weight.data[:, :
                                                   3, :, :] = model.conv1.conv1.weight.data[:, :
コード例 #24
0
ファイル: train.py プロジェクト: bispl-lab/k-space-recon
# Validation dataset built from the files under val_file_dir.
valdata = mri_dataset(val_file_dir)

batch_size = 1

# Both loaders use the same batch size and shuffle their dataset.
train_loader = torch.utils.data.DataLoader(dataset=traindata,
                                           batch_size=batch_size,
                                           shuffle=True)

val_loader = torch.utils.data.DataLoader(dataset=valdata,
                                         batch_size=batch_size,
                                         shuffle=True)

# Put the U-Net on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = unet.UNet(2, 2)
model.to(device)

# NOTE(review): if the imported `summary` prints its own report and returns
# None, this outer print() also emits "None" - confirm which summary is used.
print(summary(model, input_size=(2, 720, 720)))


def train_net(model,
              epochs=5,
              lr=0.001,
              save_cp=True,
              trainloader=None,
              valloader=None,
              sample_k=None):
    dir_checkpoint = '/home/harry/fastmri/CheckPoint/'

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
コード例 #25
0
def net(net_params, inference=False):
    """Instantiate the configured model and resolve its state-dict path.

    Args:
        net_params: Nested config mapping. Reads ``global.model_name``,
            ``global.num_classes``, ``global.number_of_bands`` (for the
            architectures that take or check it), the per-architecture
            ``models.<name>`` section and, in inference,
            ``inference.state_dict_path``.
        inference: When true, the returned path is always
            ``inference.state_dict_path``.

    Returns:
        ``(model, state_dict_path, model_name)``. ``state_dict_path`` is
        ``''`` when no pretrained path is configured and ``inference`` is
        false; ``model_name`` is lower-cased.

    Raises:
        AssertionError: ``fcn_resnet101`` or ``deeplabv3_resnet101`` was
            requested with ``number_of_bands`` other than 3.
        ValueError: ``model_name`` is not a supported architecture.
    """
    model_name = net_params['global']['model_name'].lower()
    num_classes = net_params['global']['num_classes']
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    state_dict_path = ''
    if model_name == 'unetsmall':
        model = unet.UNetSmall(
            num_classes, net_params['global']['number_of_bands'],
            net_params['models']['unetsmall']['dropout'],
            net_params['models']['unetsmall']['probability'])
        if net_params['models']['unetsmall']['pretrained']:
            state_dict_path = net_params['models']['unetsmall']['pretrained']
    elif model_name == 'unet':
        model = unet.UNet(num_classes, net_params['global']['number_of_bands'],
                          net_params['models']['unet']['dropout'],
                          net_params['models']['unet']['probability'])
        if net_params['models']['unet']['pretrained']:
            state_dict_path = net_params['models']['unet']['pretrained']
    elif model_name == 'ternausnet':
        # The pretrained setting goes to the constructor; no path is returned.
        model = TernausNet.ternausnet(
            num_classes, net_params['models']['ternausnet']['pretrained'])
    elif model_name == 'checkpointed_unet':
        # Configured through the 'unetsmall' section of the config.
        model = checkpointed_unet.UNetSmall(
            num_classes, net_params['global']['number_of_bands'],
            net_params['models']['unetsmall']['dropout'],
            net_params['models']['unetsmall']['probability'])
        if net_params['models']['unetsmall']['pretrained']:
            state_dict_path = net_params['models']['unetsmall']['pretrained']
    elif model_name == 'inception':
        model = inception.Inception3(num_classes,
                                     net_params['global']['number_of_bands'])
        if net_params['models']['inception']['pretrained']:
            state_dict_path = net_params['models']['inception']['pretrained']
    elif model_name == 'fcn_resnet101':
        # Compare against 3: a bare truthiness check accepts any non-zero
        # band count and contradicts the message.
        assert net_params['global']['number_of_bands'] == 3, msg
        # Pretrained on COCO (21 classes); used only as a weight donor.
        coco_model = models.segmentation.fcn_resnet101(pretrained=True,
                                                       progress=True,
                                                       num_classes=21,
                                                       aux_loss=None)
        model = models.segmentation.fcn_resnet101(pretrained=False,
                                                  progress=True,
                                                  num_classes=num_classes,
                                                  aux_loss=None)
        chopped_dict = chop_layer(coco_model.state_dict(),
                                  layer_name='classifier.4')
        del coco_model
        # strict=False loads only the entries the two state dicts share.
        model.load_state_dict(chopped_dict, strict=False)
        if net_params['models']['fcn_resnet101']['pretrained']:
            state_dict_path = net_params['models']['fcn_resnet101'][
                'pretrained']
    elif model_name == 'deeplabv3_resnet101':
        # Same 3-band requirement as fcn_resnet101.
        assert net_params['global']['number_of_bands'] == 3, msg
        # Pretrained on COCO (21 classes); used only as a weight donor.
        coco_model = models.segmentation.deeplabv3_resnet101(pretrained=True,
                                                             progress=True,
                                                             num_classes=21,
                                                             aux_loss=None)
        model = models.segmentation.deeplabv3_resnet101(
            pretrained=False,
            progress=True,
            num_classes=num_classes,
            aux_loss=None)
        chopped_dict = chop_layer(coco_model.state_dict(),
                                  layer_name='classifier.4')
        del coco_model
        # strict=False loads only the entries the two state dicts share,
        # irrespective of whether one is a subset/superset of the other.
        model.load_state_dict(chopped_dict, strict=False)

        if net_params['models']['deeplabv3_resnet101']['pretrained']:
            state_dict_path = net_params['models']['deeplabv3_resnet101'][
                'pretrained']
    else:
        raise ValueError('The model name in the config.yaml is not defined.')
    if inference:
        # Inference always uses its own checkpoint path.
        state_dict_path = net_params['inference']['state_dict_path']

    return model, state_dict_path, model_name
コード例 #26
0
def net(net_params, num_channels, inference=False):
    """Build the configured model and load the checkpoint that applies to it.

    Args:
        net_params: Nested config mapping. Reads ``global.model_name``,
            ``global.number_of_bands``, the optional
            ``global.concatenate_depth`` and ``global.coordconv_*`` keys, the
            optional ``training.state_dict_path`` / ``pretrained`` /
            ``dropout`` / ``dropout_prob`` keys and, in inference,
            ``inference.state_dict_path``.
        num_channels: Number of output channels passed to every model
            constructor (the ``num_classes`` / ``classes`` argument).
        inference: When true, pretrained weights are not requested and the
            checkpoint is read from ``inference.state_dict_path``.

    Returns:
        ``(model, checkpoint, model_name)``. ``checkpoint`` is ``None`` when
        neither an inference nor a training state-dict path applies;
        ``model_name`` is lower-cased.

    Raises:
        AssertionError: The band count is not supported by the chosen
            architecture, or a configured checkpoint file does not exist.
        ValueError: ``model_name`` is unknown, or 4-band
            ``deeplabv3_resnet101`` is requested without
            ``global.concatenate_depth``.
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    train_state_dict_path = get_key_def('state_dict_path',
                                        net_params['training'], None)
    pretrained = get_key_def('pretrained', net_params['training'],
                             True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)

    # TODO: find a way to maybe implement it in classification one day
    # Track presence separately so an explicitly configured value, whatever
    # it is, still reaches the 4-band branch unchanged.
    has_conc_point = 'concatenate_depth' in net_params['global']
    conc_point = net_params['global']['concatenate_depth'] if has_conc_point else None

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout,
                                            dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(pretrained=pretrained,
                                                  progress=True,
                                                  num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        assert num_bands in (3, 4), (
            'Number of bands specified incompatible with this model. '
            'Requires 3 or 4 band data.')
        four_bands = num_bands == 4
        if four_bands and not has_conc_point:
            # Previously surfaced as an unbound-local error further down.
            raise ValueError(
                "4-band deeplabv3_resnet101 requires 'concatenate_depth' in "
                "the 'global' section of the config.")
        print('Finetuning pretrained deeplabv3 with {} bands'.format(
            num_bands))
        if four_bands:
            print('Testing with 4 bands, concatenating at {}.'.format(
                conc_point))

        model = models.segmentation.deeplabv3_resnet101(
            pretrained=pretrained, progress=True)

        if four_bands and conc_point == 'baseline':
            # Widen the first conv to 4 input channels by appending a copy of
            # the green-channel filters for the infrared band.
            conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
            depth = np.expand_dims(conv1[:, 1, ...], axis=1)
            conv1 = np.append(conv1, depth, axis=1)
            conv1 = torch.from_numpy(conv1).float()
            model.backbone._modules['conv1'].weight = nn.Parameter(
                conv1, requires_grad=True)

        # Swap the last classifier layer for a 1x1 conv with num_channels
        # outputs; the preceding layers are kept.
        classifier = list(model.classifier.children())
        model.classifier = nn.Sequential(*classifier[:-1])
        model.classifier.add_module(
            '4',
            nn.Conv2d(classifier[-1].in_channels,
                      num_channels,
                      kernel_size=(1, 1)))

        if four_bands and conc_point != 'baseline':
            model = LayersEnsemble(model, conc_point=conc_point)

    elif model_name in lm_smp:
        lsmp = lm_smp[model_name]
        # Copy the parameters so the shared lm_smp entry is not modified.
        # TODO: add possibility of our own weights
        smp_params = dict(lsmp['params'])
        smp_params['encoder_weights'] = (
            "imagenet" if 'pretrained' in model_name.split("_") else None)
        smp_params['in_channels'] = num_bands
        smp_params['classes'] = num_channels
        smp_params['activation'] = None

        model = lsmp['fct'](**smp_params)

    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'],
                                    False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'],
                               True)
        normalized = get_key_def('coordconv_normalized', net_params['global'],
                                 True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel',
                                     net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model,
                                                centered=centered,
                                                normalized=normalized,
                                                noise=noise,
                                                radius_channel=radius_channel,
                                                scale=scale)

    # In inference the inference checkpoint is used, even if a training path
    # is also configured.
    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(state_dict_path).is_file(
        ), f"Could not locate {state_dict_path}"
        checkpoint = load_checkpoint(state_dict_path)
    elif train_state_dict_path is not None:
        assert Path(train_state_dict_path).is_file(
        ), f'Could not locate checkpoint at {train_state_dict_path}'
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None

    return model, checkpoint, model_name
コード例 #27
0
def net(net_params, num_channels, inference=False):
    """Define the neural net.

    Builds the model named by ``net_params['global']['model_name']``, optionally
    converts its layers with ``coordconv.swap_coordconv_layers`` and loads a
    checkpoint when one is configured.

    Args:
        net_params: nested configuration mapping; the 'global', 'training' and
            (in inference mode) 'inference' sections are read. When the
            configured ``ignore_index`` is 0, ``net_params['training']['ignore_index']``
            is overwritten with -1 (a warning is emitted).
        num_channels: number of output channels of the model (passed as the
            class count to the model constructors).
        inference: when True, only the model and the inference checkpoint are
            prepared; no device placement or hyperparameter setup is done.

    Returns:
        ``(model, checkpoint, model_name)`` when ``inference`` is True, otherwise
        ``(model, model_name, criterion, optimizer, lr_scheduler)``.

    Raises:
        ValueError: if the model name is unknown, or if deeplabv3_resnet101 is
            requested with 4 bands and 'concatenate_depth' is not configured.
        AssertionError: if the band count does not suit the model, a configured
            checkpoint file does not exist, or 'num_gpus' is missing/negative.
    """
    model_name = net_params['global']['model_name'].lower()
    num_bands = int(net_params['global']['number_of_bands'])
    msg = 'Number of bands specified incompatible with this model. Requires 3 band data.'
    train_state_dict_path = get_key_def('state_dict_path', net_params['training'], None)
    # Never fetch pretrained weights at inference time: a checkpoint is loaded instead.
    pretrained = get_key_def('pretrained', net_params['training'], True) if not inference else False
    dropout = get_key_def('dropout', net_params['training'], False)
    dropout_prob = get_key_def('dropout_prob', net_params['training'], 0.5)
    dontcare_val = get_key_def("ignore_index", net_params["training"], -1)
    num_devices = net_params['global']['num_gpus']

    if dontcare_val == 0:
        warnings.warn("The 'dontcare' value (or 'ignore_index') used in the loss function cannot be zero;"
                      " all valid class indices should be consecutive, and start at 0. The 'dontcare' value"
                      " will be remapped to -1 while loading the dataset, and inside the config from now on.")
        net_params["training"]["ignore_index"] = -1
        # Keep the local value in sync with the config: it is what gets handed to
        # set_hyperparameters() below, so leaving it at 0 would defeat the remap.
        dontcare_val = -1

    # TODO: find a way to maybe implement it in classification one day
    # Read the concatenation point. It is only required by the 4-band deeplabv3
    # branch, so a missing key is not an error at this stage.
    conc_point = None
    if 'concatenate_depth' in net_params['global']:
        conc_point = net_params['global']['concatenate_depth']

    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(pretrained=False, progress=True, num_classes=num_channels,
                                                  aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        assert (num_bands == 3 or num_bands == 4), msg
        if num_bands == 3:
            print('Finetuning pretrained deeplabv3 with 3 bands')
            model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained, progress=True)
            _replace_deeplab_head(model, num_channels)
        elif num_bands == 4:
            if 'concatenate_depth' not in net_params['global']:
                # Fail with an explicit message instead of an UnboundLocalError further down.
                raise ValueError("The 'concatenate_depth' key must be set in the 'global' section of the "
                                 "config to use deeplabv3_resnet101 with 4 bands.")
            print('Finetuning pretrained deeplabv3 with 4 bands')
            print('Testing with 4 bands, concatenating at {}.'.format(conc_point))

            model = models.segmentation.deeplabv3_resnet101(pretrained=pretrained, progress=True)

            if conc_point == 'baseline':
                # Widen the first convolution from 3 to 4 input channels.
                conv1 = model.backbone._modules['conv1'].weight.detach().numpy()
                depth = np.expand_dims(conv1[:, 1, ...], axis=1)  # reuse green weights for infrared.
                conv1 = np.append(conv1, depth, axis=1)
                conv1 = torch.from_numpy(conv1).float()
                model.backbone._modules['conv1'].weight = nn.Parameter(conv1, requires_grad=True)
                _replace_deeplab_head(model, num_channels)
            else:
                _replace_deeplab_head(model, num_channels)
                model = LayersEnsemble(model, conc_point=conc_point)

    elif model_name in lm_smp.keys():
        lsmp = lm_smp[model_name]
        # TODO: add possibility of our own weights
        lsmp['params']['encoder_weights'] = "imagenet" if 'pretrained' in model_name.split("_") else None
        lsmp['params']['in_channels'] = num_bands
        lsmp['params']['classes'] = num_channels
        lsmp['params']['activation'] = None

        model = lsmp['fct'](**lsmp['params'])

    else:
        raise ValueError(f'The model name {model_name} in the config.yaml is not defined.')

    coordconv_convert = get_key_def('coordconv_convert', net_params['global'], False)
    if coordconv_convert:
        centered = get_key_def('coordconv_centered', net_params['global'], True)
        normalized = get_key_def('coordconv_normalized', net_params['global'], True)
        noise = get_key_def('coordconv_noise', net_params['global'], None)
        radius_channel = get_key_def('coordconv_radius_channel', net_params['global'], False)
        scale = get_key_def('coordconv_scale', net_params['global'], 1.0)
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, centered=centered, normalized=normalized, noise=noise,
                                                radius_channel=radius_channel, scale=scale)

    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(state_dict_path).is_file(), f"Could not locate {net_params['inference']['state_dict_path']}"
        checkpoint = load_checkpoint(state_dict_path)

        return model, checkpoint, model_name

    if train_state_dict_path is not None:
        assert Path(train_state_dict_path).is_file(), f'Could not locate checkpoint at {train_state_dict_path}'
        checkpoint = load_checkpoint(train_state_dict_path)
    else:
        checkpoint = None
    assert num_devices is not None and num_devices >= 0, "missing mandatory num gpus parameter"
    # list of GPU devices that are available and unused. If no GPUs, returns empty list
    lst_device_ids = get_device_ids(num_devices) if torch.cuda.is_available() else []
    num_devices = len(lst_device_ids) if lst_device_ids else 0
    device = torch.device(f'cuda:{lst_device_ids[0]}' if torch.cuda.is_available() and lst_device_ids else 'cpu')
    print(f"Number of cuda devices requested: {net_params['global']['num_gpus']}. Cuda devices available: {lst_device_ids}\n")
    if num_devices == 1:
        print(f"Using Cuda device {lst_device_ids[0]}\n")
    elif num_devices > 1:
        # str(list)[1:-1] drops the surrounding brackets so the ids print as "0, 1, 2".
        print(f"Using data parallel on devices: {str(lst_device_ids)[1:-1]}. Main device: {lst_device_ids[0]}\n")
        try:  # For HPC when device 0 not available. Error: Invalid device id (in torch/cuda/__init__.py).
            model = nn.DataParallel(model,
                                    device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
        except AssertionError:
            warnings.warn(f"Unable to use devices {lst_device_ids}. Trying devices {list(range(len(lst_device_ids)))}")
            device = torch.device('cuda:0')
            # Keep this a real list, like the value returned by get_device_ids().
            lst_device_ids = list(range(len(lst_device_ids)))
            model = nn.DataParallel(model,
                                    device_ids=lst_device_ids)  # DataParallel adds prefix 'module.' to state_dict keys
    else:
        warnings.warn("No Cuda device available. This process will only run on CPU\n")
    tqdm.write('Setting model, criterion, optimizer and learning rate scheduler...\n')
    try:  # For HPC when device 0 not available. Error: Cuda invalid device ordinal.
        model.to(device)
    except RuntimeError:
        warnings.warn("Unable to use device. Trying device 0...\n")
        device = torch.device('cuda:0' if torch.cuda.is_available() and lst_device_ids else 'cpu')
        model.to(device)

    model, criterion, optimizer, lr_scheduler = set_hyperparameters(net_params, num_channels, model, checkpoint, dontcare_val)
    criterion = criterion.to(device)

    return model, model_name, criterion, optimizer, lr_scheduler


def _replace_deeplab_head(model, num_channels):
    """Swap the last layer of ``model.classifier`` for a 1x1 conv with ``num_channels`` outputs."""
    classifier = list(model.classifier.children())
    model.classifier = nn.Sequential(*classifier[:-1])
    # '4' is the module name the original code registers the new final layer under.
    model.classifier.add_module('4', nn.Conv2d(classifier[-1].in_channels, num_channels, kernel_size=(1, 1)))
コード例 #28
0
def net(net_params, num_channels, inference=False):
    """Build the network selected in the configuration.

    Args:
        net_params: nested configuration mapping; the 'global' and 'training'
            sections are always read, 'inference' only when ``inference`` is True.
        num_channels: number of output channels handed to the model constructors.
        inference: when True the checkpoint comes from
            ``net_params['inference']['state_dict_path']``; otherwise from the
            optional ``state_dict_path`` entry of the 'training' section.

    Returns:
        ``(model, checkpoint, model_name)``; ``checkpoint`` is None when training
        without a configured state dict.

    Raises:
        ValueError: if the configured model name is not handled here.
        AssertionError: if the band count does not suit the model or a configured
            checkpoint file cannot be found.
    """
    global_cfg = net_params['global']
    model_name = global_cfg['model_name'].lower()
    num_bands = int(global_cfg['number_of_bands'])
    msg = f'Number of bands specified incompatible with this model. Requires 3 band data.'

    train_cfg = net_params['training']
    train_state_dict_path = get_key_def('state_dict_path', train_cfg, None)
    # Looked up only outside inference mode; the value is not consumed further down.
    if inference:
        pretrained = False
    else:
        pretrained = get_key_def('pretrained', train_cfg, True)
    dropout = get_key_def('dropout', train_cfg, False)
    dropout_prob = get_key_def('dropout_prob', train_cfg, 0.5)

    # --- model construction -------------------------------------------------
    if model_name == 'unetsmall':
        model = unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'unet':
        model = unet.UNet(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'ternausnet':
        assert num_bands == 3, msg
        model = TernausNet.ternausnet(num_channels)
    elif model_name == 'checkpointed_unet':
        model = checkpointed_unet.UNetSmall(num_channels, num_bands, dropout, dropout_prob)
    elif model_name == 'inception':
        model = inception.Inception3(num_channels, num_bands)
    elif model_name == 'fcn_resnet101':
        assert num_bands == 3, msg
        model = models.segmentation.fcn_resnet101(
            pretrained=False, progress=True, num_classes=num_channels, aux_loss=None)
    elif model_name == 'deeplabv3_resnet101':
        assert num_bands in (3, 4), msg
        if num_bands == 3:
            print('Finetuning pretrained deeplabv3 with 3 bands')
            model = models.segmentation.deeplabv3_resnet101(
                pretrained=True, progress=True, aux_loss=None)
            model.classifier = common.DeepLabHead(2048, num_channels)
        elif num_bands == 4:
            print('Finetuning pretrained deeplabv3 with 4 bands')
            model = models.segmentation.deeplabv3_resnet101(
                pretrained=True, progress=True, aux_loss=None)
            # Widen the first convolution by one input channel filled with
            # uniform random weights in [-1, 1).
            first_conv = model.backbone._modules['conv1']
            existing_weights = first_conv.weight.detach().numpy()
            extra_band = np.random.uniform(low=-1, high=1, size=(64, 1, 7, 7))
            widened = np.concatenate((existing_weights, extra_band), axis=1)
            first_conv.weight = nn.Parameter(torch.from_numpy(widened).float(), requires_grad=True)
            model.classifier = common.DeepLabHead(2048, num_channels)
    else:
        raise ValueError(
            f'The model name {model_name} in the config.yaml is not defined.')

    # --- optional CoordConv conversion ---------------------------------------
    if get_key_def('coordconv_convert', global_cfg, False):
        coordconv_options = {
            'centered': get_key_def('coordconv_centered', global_cfg, True),
            'normalized': get_key_def('coordconv_normalized', global_cfg, True),
            'noise': get_key_def('coordconv_noise', global_cfg, None),
            'radius_channel': get_key_def('coordconv_radius_channel', global_cfg, False),
            'scale': get_key_def('coordconv_scale', global_cfg, 1.0),
        }
        # note: this operation will not attempt to preserve already-loaded model parameters!
        model = coordconv.swap_coordconv_layers(model, **coordconv_options)

    # --- checkpoint selection ------------------------------------------------
    if inference:
        state_dict_path = net_params['inference']['state_dict_path']
        assert Path(state_dict_path).is_file(), \
            f"Could not locate {net_params['inference']['state_dict_path']}"
        return model, load_checkpoint(state_dict_path), model_name

    if train_state_dict_path is None:
        return model, None, model_name

    assert Path(train_state_dict_path).is_file(), \
        f'Could not locate checkpoint at {train_state_dict_path}'
    return model, load_checkpoint(train_state_dict_path), model_name
コード例 #29
0
ファイル: train.py プロジェクト: RJ2019/BuildingFootprints
def main(hyperparameters, options):
    """Train and validate a UNet with the given hyperparameters and options.

    Args:
        hyperparameters: mapping read for the keys 'class_weights', 'epochs',
            'learn_rate', 'lr_change', 'training_batch_size',
            'testing_batch_size', 'depth', 'wf', 'pad', 'batch_norm' and
            'up_mode'.
        options: mapping read for the keys 'dataset', 'in_channels',
            'n_classes' and 'augment'.

    The function builds the model, an Adam optimizer and a class-weighted
    cross-entropy loss, loads the training and validation sets through
    ``train_data``, then runs ``train`` and ``valid`` once per epoch.
    """
    # grab the hyperparameters and options for training
    data_set = options['dataset']
    in_channels = options['in_channels']
    n_classes = options['n_classes']
    augment = options['augment']
    class_weights = hyperparameters['class_weights']
    num_epochs = hyperparameters['epochs']
    learning_rate = hyperparameters['learn_rate']
    lr_change = hyperparameters['lr_change']
    training_batch_size = hyperparameters['training_batch_size']
    testing_batch_size = hyperparameters['testing_batch_size']
    depth = hyperparameters['depth']
    wf = hyperparameters['wf']
    padding = hyperparameters['pad']
    batch_norm = hyperparameters['batch_norm']
    up_mode = hyperparameters['up_mode']

    # echo the full run configuration; flush so it shows up immediately in logs
    print(
        """Running model with epochs={}, learning_rate={}, training_batch_size={},
testing_batch_size={}, in_channels={}, n_classes={}, depth={}, wf={}, padding={},
batch_norm={}, up_Mode={}, augment={}""".format(num_epochs, learning_rate,
                                                training_batch_size,
                                                testing_batch_size,
                                                in_channels, n_classes, depth,
                                                wf, padding, batch_norm,
                                                up_mode, augment),
        flush=True)

    # use the UNet model in models dir
    # https://github.com/jvanvugt/pytorch-unet
    model = unet.UNet(in_channels=in_channels,
                      n_classes=n_classes,
                      depth=depth,
                      wf=wf,
                      padding=padding,
                      batch_norm=batch_norm,
                      up_mode=up_mode)

    # use GPU if available, https://pytorch.org/docs/stable/notes/cuda.html
    if torch.cuda.is_available():
        model = model.cuda()

    # set up the optimizer for our model (Adam with the AMSGrad variant enabled)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 amsgrad=True)

    # give more weight to building predictions
    loss_weights = torch.tensor(class_weights)
    # use GPU if available
    if torch.cuda.is_available():
        loss_weights = loss_weights.cuda()

    # set up the loss with the weights that were specified
    loss = torch.nn.CrossEntropyLoss(weight=loss_weights)

    # starting epoch
    start_epoch = 1

    # best validation loss so far; starts high so the first call to valid()
    # can replace it (presumably valid() returns the lower of the two -- TODO confirm)
    best_loss = 10000

    # load in the training dataset
    train_img, train_label, orig_dim, num_train, pad_size = train_data(
        data_set, depth, padding, augment, current_set='training')

    # set up the custom training class
    custom_training_class = dataset_class.CustomDatasetFromTif(
        train_img, train_label, num_train)

    # set up the training data loader
    training_loader = DataLoader(dataset=custom_training_class,
                                 batch_size=training_batch_size)

    # load in the validation dataset
    # NOTE(review): this rebinds orig_dim and pad_size, so the train() call
    # below receives the validation values -- confirm both sets share them.
    valid_img, valid_label, orig_dim, num_valid, pad_size = train_data(
        data_set, depth, padding, augment, current_set='validation')

    # set up the custom validation class
    custom_validation_class = dataset_class.CustomDatasetFromTif(
        valid_img, valid_label, num_valid)

    # set up the validation data loader
    validation_loader = DataLoader(dataset=custom_validation_class,
                                   batch_size=testing_batch_size)

    # loop through all of the epochs (1-based, inclusive of num_epochs)
    for epoch in range(start_epoch, num_epochs + 1):
        print("Epoch " + str(epoch), flush=True)

        # adjust the learning rate every lr_change epochs
        if epoch % lr_change == 0:
            adjust_lr(epoch, learning_rate, optimizer, lr_change)

        # run train and valid
        train(model, training_loader, optimizer, loss, pad_size, orig_dim,
              num_train)
        best_loss = valid(model, validation_loader, loss, best_loss, pad_size,
                          orig_dim, num_valid)

        # rebuild the training dataset and loader after every epoch
        # NOTE(review): the original comment says this shuffles the data, but
        # the DataLoader is created without shuffle=True; any shuffling would
        # have to happen inside CustomDatasetFromTif -- verify.
        custom_training_class = dataset_class.CustomDatasetFromTif(
            train_img, train_label, num_train)

        training_loader = DataLoader(dataset=custom_training_class,
                                     batch_size=training_batch_size)