def train_boots(dataset_name: str, network_architecture: str,
                learning_rate: float, epochs: int, batch_size: int,
                horizontal_flip: float, vertical_flip: float,
                unet_filters: int, convolutions: int, plot: bool):
    """Train a bootstrap ensemble of the chosen model on the selected dataset.

    Twenty networks are trained from scratch. Each one sees a random 63%
    subset of the training set (drawn without replacement via
    ``torch.randperm``). For every ensemble member the weights with the
    lowest value returned by the training looper are checkpointed, reloaded,
    and evaluated on the full training and validation sets.

    All outputs are written to ``boots_results_<dataset_name>/``:
    per-member checkpoints (``.pth``), per-member histories (``.csv``),
    optional status plots (``.png``) and two CSV files holding the true
    values followed by one column of predictions per ensemble member.

    Args:
        dataset_name: directory containing ``train.h5`` and ``valid.h5``;
            ``'ucsd'`` selects single-channel input, anything else 3.
        network_architecture: one of ``'UNet'``, ``'UNet2'``, ``'FCRN_A'``.
        learning_rate: initial SGD learning rate (decayed x0.1 every 20
            scheduler steps).
        epochs: number of training epochs per ensemble member.
        batch_size: batch size used for the training loaders.
        horizontal_flip: forwarded to ``H5Dataset`` for the training split
            only (presumably a flip probability -- confirm in H5Dataset).
        vertical_flip: same as ``horizontal_flip`` for vertical flips.
        unet_filters: forwarded to the model as ``filters``.
        convolutions: forwarded to the model as ``N``.
        plot: if true, create a live matplotlib figure and save it after
            every epoch.

    Raises:
        KeyError: if ``network_architecture`` is not a supported name.
    """
    n_boots = 20  # number of ensemble members
    sample_fraction = 0.63  # share of the training set each member sees

    # use GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    dataset = {}  # training and validation HDF5-based datasets

    # only UCSD dataset provides greyscale images instead of RGB
    input_channels = 1 if dataset_name == 'ucsd' else 3

    # hyper-parameter tag shared by every output file name
    hparams = (f'epochs={epochs}_batch={batch_size}_hf={horizontal_flip}'
               f'_vf={vertical_flip}_uf={unet_filters}_conv{convolutions}')

    dirr = 'boots_results_' + dataset_name
    Path(dirr).mkdir(parents=True, exist_ok=True)

    # if plot flag is on, create a live plot (to be updated by Looper)
    if plot:
        pyplot.ion()
        fig, plots = pyplot.subplots(nrows=2, ncols=2)
    else:
        plots = [None] * 2

    for mode in ['train', 'valid']:
        # expected HDF5 files in dataset_name/(train | valid).h5
        data_path = os.path.join(dataset_name, f"{mode}.h5")
        # turn on flips only for training dataset
        dataset[mode] = H5Dataset(data_path,
                                  horizontal_flip if mode == 'train' else 0,
                                  vertical_flip if mode == 'train' else 0)

    n_samples = len(dataset['train'])
    subset_size = int(sample_fraction * n_samples)

    # first column: true values; then one column of predictions per member
    results_train = []
    results_test = []

    # resolve the model class once; an unknown name raises KeyError here
    model_cls = {
        'UNet': UNet,
        'UNet2': UNet2,
        'FCRN_A': FCRN_A,
    }[network_architecture]

    # evaluation loaders do not depend on the bootstrap draw
    full_train_loader = torch.utils.data.DataLoader(dataset['train'],
                                                    batch_size=batch_size)
    valid_loader = torch.utils.data.DataLoader(dataset['valid'],
                                               batch_size=1)

    for i in range(n_boots):
        # initialize a fresh model for this ensemble member
        network = model_cls(input_filters=input_channels,
                            filters=unet_filters,
                            N=convolutions,
                            p=0).to(device)
        network = torch.nn.DataParallel(network)

        # initialize loss, optimizer and learning rate scheduler
        loss = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(network.parameters(),
                                    lr=learning_rate,
                                    momentum=0.9,
                                    weight_decay=1e-5)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       step_size=20,
                                                       gamma=0.1)

        # draw this member's training subset (without replacement)
        subset = torch.randperm(n_samples)[:subset_size]
        sampler = torch.utils.data.SubsetRandomSampler(subset)
        train_loader = torch.utils.data.DataLoader(dataset['train'],
                                                   batch_size=batch_size,
                                                   sampler=sampler)

        # Looper that handles a single training epoch
        # NOTE(review): the size passed here is the full training-set length
        # although only `subset_size` samples are iterated per epoch --
        # confirm how Looper uses it before changing.
        train_looper = Looper(network, device, loss,
                              optimizer, train_loader,
                              len(dataset['train']), plots[0], False)
        train_looper.LOG = True

        checkpoint_path = os.path.join(
            dirr,
            f'{dataset_name}_boot_i={i}_{network_architecture}_{hparams}.pth')
        checkpoint_saved = False

        # current best result (lowest value returned by the training looper)
        # np.inf instead of np.infty: the alias was removed in NumPy 2.0
        current_best = np.inf

        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}\n")

            # run training epoch and update learning rate
            result = train_looper.run()
            lr_scheduler.step()

            # update checkpoint if new best is reached
            if result < current_best:
                current_best = result
                torch.save(network.state_dict(), checkpoint_path)
                checkpoint_saved = True
                np.savetxt(os.path.join(
                    dirr,
                    f'hist_best_boot_{dataset_name}_{network_architecture}'
                    f'_i={i}_{hparams}.csv'),
                           [train_looper.history[-1]],
                           delimiter=',')

                print(f"\nNew best result: {result}")

            print("\n", "-" * 80, "\n", sep='')

            if plot:
                fig.savefig(
                    os.path.join(
                        dirr,
                        f'status_boot_i={i}_{dataset_name}'
                        f'_{network_architecture}_{hparams}.png'))

        # Restore the best weights of THIS member. If no checkpoint was
        # written (epochs == 0 or every result was NaN) keep the in-memory
        # weights rather than crashing or loading a stale file from an
        # earlier run.
        if checkpoint_saved:
            network.load_state_dict(torch.load(checkpoint_path))
        else:
            print(f"No checkpoint written for bootstrap i={i}; "
                  "evaluating current weights")

        # evaluation Loopers over the validation set and the full training set
        valid_looper = Looper(network,
                              device,
                              loss,
                              optimizer,
                              valid_loader,
                              len(dataset['valid']),
                              None,
                              False,
                              validation=True)

        # size is the number of samples (as for every other Looper), not the
        # number of batches in the loader
        train_looper2 = Looper(network,
                               device,
                               loss,
                               optimizer,
                               full_train_loader,
                               len(dataset['train']),
                               None,
                               False,
                               validation=True)

        valid_looper.LOG = False
        train_looper2.LOG = False
        valid_looper.MC = False

        with torch.no_grad():
            valid_looper.run()
            train_looper2.run()

        # true values are stored once, ahead of the prediction columns
        if i == 0:
            results_train.append(train_looper2.true_values)
            results_test.append(valid_looper.true_values)

        results_train.append(train_looper2.predicted_values)
        results_test.append(valid_looper.predicted_values)

        print(f"[Training done] Best result: {current_best}")

        np.savetxt(os.path.join(
            dirr,
            f'hist_train_boot_i={i}_{dataset_name}_{network_architecture}'
            f'_{hparams}.csv'),
                   np.array(train_looper.history),
                   delimiter=',')
        np.savetxt(os.path.join(
            dirr,
            f'hist_test_boot_i={i}_{dataset_name}_{network_architecture}'
            f'_{hparams}.csv'),
                   np.array(valid_looper.history),
                   delimiter=',')

    # one row per sample, one column per entry collected above
    np.savetxt(os.path.join(
        dirr,
        f'predicted_train_best_boot_{dataset_name}_{network_architecture}'
        f'_{hparams}.csv'),
               np.array(results_train).transpose(),
               delimiter=',')

    np.savetxt(os.path.join(
        dirr,
        f'predicted_test_best_boot_{dataset_name}_{network_architecture}'
        f'_{hparams}.csv'),
               np.array(results_test).transpose(),
               delimiter=',')
# Example #2
0
def train(dataset_name: str, network_architecture: str, learning_rate: float,
          weight_decay: float, epochs: int, batch_size: int,
          horizontal_flip: float, vertical_flip: float, unet_filters: int,
          convolutions: int, dropout_prob: float, plot: bool):
    """Train chosen model on selected dataset.

    Runs ``epochs`` training epochs, validating after each one. Whenever the
    validation looper returns a new lowest value, the model weights and the
    latest validation/training history entries are written to the current
    working directory. After training, the full histories are saved and
    ``Looper.MCdropOut(100, ...)`` is invoked on both the training and the
    validation looper.

    Args:
        dataset_name: directory containing ``train.h5`` and ``valid.h5``;
            ``'ucsd'`` selects single-channel input, anything else 3.
        network_architecture: one of ``'UNet'``, ``'UNet2'``, ``'UNet2_MC'``,
            ``'UNet_MC'``, ``'FCRN_A'``, ``'FCRN_A_MC'``.
        learning_rate: initial SGD learning rate (decayed x0.1 every 20
            scheduler steps).
        weight_decay: SGD weight decay.
        epochs: number of training epochs.
        batch_size: batch size for both training and validation loaders.
        horizontal_flip: forwarded to ``H5Dataset`` for the training split
            only (presumably a flip probability -- confirm in H5Dataset).
        vertical_flip: same as ``horizontal_flip`` for vertical flips.
        unet_filters: forwarded to the model as ``filters``.
        convolutions: forwarded to the model as ``N``.
        dropout_prob: forwarded to the model as ``p``.
        plot: if true, create a live matplotlib figure and save it after
            every epoch.

    Raises:
        KeyError: if ``network_architecture`` is not a supported name.
    """
    # use GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    dataset = {}  # training and validation HDF5-based datasets
    dataloader = {}  # training and validation dataloaders

    for mode in ['train', 'valid']:
        # expected HDF5 files in dataset_name/(train | valid).h5
        data_path = os.path.join(dataset_name, f"{mode}.h5")
        # turn on flips only for training dataset
        dataset[mode] = H5Dataset(data_path,
                                  horizontal_flip if mode == 'train' else 0,
                                  vertical_flip if mode == 'train' else 0)
        dataloader[mode] = torch.utils.data.DataLoader(dataset[mode],
                                                       batch_size=batch_size)

    # only UCSD dataset provides greyscale images instead of RGB
    input_channels = 1 if dataset_name == 'ucsd' else 3

    # initialize a model based on chosen network_architecture
    model_cls = {
        'UNet': UNet,
        'UNet2': UNet2,
        'UNet2_MC': UNet2_MC,
        'UNet_MC': UNet_MC,
        'FCRN_A': FCRN_A,
        'FCRN_A_MC': FCRN_A_MC
    }[network_architecture]
    network = model_cls(input_filters=input_channels,
                        filters=unet_filters,
                        N=convolutions,
                        p=dropout_prob).to(device)
    network = torch.nn.DataParallel(network)

    # initialize loss, optimizer and learning rate scheduler
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(network.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=20,
                                                   gamma=0.1)

    # labels embedded in every output file name
    prob_label = 'p=' + str(dropout_prob)
    weightdecay__ = "wd=" + str(weight_decay)
    run_tag = (f'{dataset_name}_{network_architecture}_epochs={epochs}'
               f'_batch={batch_size}_hf={horizontal_flip}'
               f'_vf={vertical_flip}_uf={unet_filters}_conv{convolutions}'
               f'_{prob_label}_{weightdecay__}')

    # if plot flag is on, create a live plot (to be updated by Looper)
    if plot:
        pyplot.ion()
        fig, plots = pyplot.subplots(nrows=2, ncols=2)
    else:
        plots = [None] * 2

    # create training and validation Loopers to handle a single epoch
    train_looper = Looper(network, device,
                          loss, optimizer, dataloader['train'],
                          len(dataset['train']), plots[0], False)

    valid_looper = Looper(network,
                          device,
                          loss,
                          optimizer,
                          dataloader['valid'],
                          len(dataset['valid']),
                          plots[1],
                          False,
                          validation=True)

    train_looper.LOG = True
    valid_looper.LOG = False

    # current best result (lowest value returned by the validation looper)
    # np.inf instead of np.infty: the alias was removed in NumPy 2.0
    current_best = np.inf

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}\n")

        # run training epoch and update learning rate
        train_looper.run()
        lr_scheduler.step()

        # run validation epoch
        with torch.no_grad():
            result = valid_looper.run()

        # update checkpoint if new best is reached
        if result < current_best:
            current_best = result
            torch.save(network.state_dict(), f'{run_tag}.pth')
            # latest validation entry first, then latest training entry
            np.savetxt(f'hist_best_{run_tag}.csv',
                       [valid_looper.history[-1], train_looper.history[-1]],
                       delimiter=',')

            print(f"\nNew best result: {result}")

        print("\n", "-" * 80, "\n", sep='')

        if plot:
            fig.savefig(f'status_{run_tag}.png')

    print(f"[Training done] Best result: {current_best}")

    np.savetxt(f'hist_train_{run_tag}.csv',
               np.array(train_looper.history),
               delimiter=',')
    np.savetxt(f'hist_test_{run_tag}.csv',
               np.array(valid_looper.history),
               delimiter=',')

    # switch both loopers to plot-free, non-logging evaluation with MC enabled
    # NOTE(review): the network still holds the last-epoch weights here, not
    # the best checkpoint saved above -- confirm this is intended.
    for looper in (train_looper, valid_looper):
        looper.plots = None
        looper.validation = True
        looper.LOG = False
        looper.MC = True

    # presumably Monte Carlo dropout with 100 passes -- see Looper.MCdropOut
    NETname = network_architecture + '_' + prob_label + '_' + weightdecay__
    train_looper.MCdropOut(100, NETname, dataset_name + '_train_')
    valid_looper.MCdropOut(100, NETname, dataset_name + '_test_')