예제 #1
0
    _, _, test_part = get_indices()

    # Load the trained model
    load_configuration_classifier = 'Classifier-FT-15e'

    # Initialize the network
    encoder = DenoisingNet(verbose=True)
    encoder_state_dict = torch.load('data/' + load_configuration_classifier +
                                    '_net_parameters.torch',
                                    map_location='cpu')
    # Set the loaded parameters to the network
    encoder.load_state_dict(encoder_state_dict['encoder_tuned'])

    # Load the dataset
    dataset = Waterfalls(fpath='../../DATASETS/Waterfalls/Waterfalls_fish.mat',
                         transform=utils.NormalizeSignal((1200, 20)))

    test_part = utils.filter_indices(test_part,
                                     parameters=dataset[:],
                                     min_snr=2)

    # Load test data efficiently
    test_samples = DataLoader(dataset,
                              batch_size=32,
                              shuffle=False,
                              sampler=SubsetRandomSampler(test_part),
                              collate_fn=utils.waterfalls_collate)

    classes, config = utils.get_classes()
    classifier = FishClassifier(config,
                                estimate_parameters=False,
예제 #2
0
##########################
# TODO Set all the following parameters by args. Improve this section
# Parameters
verbose_flag = True
plot_flag = False
lr = 1e-3  # Learning rate
batchsize = 32  # Mini-batch size used for the training DataLoader
# max_sample = 60000
train_blocks = [1, 2, 3, 4, 5, 6]  # presumably the network blocks trained in turn — TODO confirm
block_epochs = 5  # presumably number of epochs per entry of train_blocks — TODO confirm
final_train = True  # NOTE(review): consumer of this flag is not visible here; confirm meaning
open_config = None  # NOTE(review): presumably a saved configuration to resume from (None = fresh) — confirm
##########################

print('[Loading data...]')
# Every sample goes through utils.NormalizeSignal, parameterized by WATERFALLS_SIZE.
dataset = Waterfalls(transform=utils.NormalizeSignal(WATERFALLS_SIZE))
# Disabled SNR-based filtering of the dataset, kept for reference:
# parameters = dataset[0:max_sample]  # Get all the parameters
# snr_cond = parameters[0] >= min_SNR  # Check where data satisfies the SNR condition
# dataset_indices = np.nonzero(snr_cond)[0]  # Get indices of the valid part of the dataset

# Write a log file containing useful information within 'logs/$reference_time$.log'.
# NOTE: reference_time only has hour/minute/second resolution, so a run started
# at the same clock time on a different day maps to the same file name and,
# because the file is opened in 'w' mode, truncates the earlier log.
reference_time = time.strftime("%H_%M_%S")
log_name = 'logs/' + reference_time + '.log'
import os  # local import: the file's top-level import block is not visible here

# open() does not create missing parent directories; make sure 'logs/' exists.
os.makedirs('logs', exist_ok=True)
with open(log_name, 'w') as logfile:  # Write mode: creates (or truncates) the log file
    logfile.write('LOG FILE - ' + reference_time + '\n\tCUDA Support\t' +
                  str(torch.cuda.is_available()))

# Load train data efficiently
train_dataloader = DataLoader(
    dataset,
    batch_size=batchsize,
예제 #3
0
    mode = 'a'
    load_configuration = 'DenoisingNet-5e'
    configuration = torch.load('data/' + load_configuration +
                               '_net_parameters.torch',
                               map_location='cpu')

    # Initialize the network
    net = DenoisingNet(verbose=True)

    # Set the loaded parameters to the network
    net.load_state_dict(configuration['DenoisingNet'])

    # Load the dataset
    if mode == 'tracking':
        transform = transforms.Compose([
            utils.NormalizeSignal(WATERFALLS_SIZE),
            utils.Paths2D(WATERFALLS_SIZE)
        ])
    else:
        transform = utils.NormalizeSignal((1200, 20))

    dataset = Waterfalls(fpath='../../DATASETS/Waterfalls/Waterfalls_fish.mat',
                         transform=transform)

    # %% Evaluate the accuracy metrics
    test_samples = DataLoader(dataset,
                              batch_size=32,
                              shuffle=False,
                              sampler=SubsetRandomSampler(test_part),
                              collate_fn=utils.waterfalls_collate)
예제 #4
0
    # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in ae.state_dict()}
    # # 2. overwrite entries in the existing state dict
    # ae.state_dict().update(pretrained_dict)
    # # 3. load the new state dict
    # ae.load_state_dict(pretrained_dict)
    #
    # Inizialize the tracker
    tracer = FishTracker()
    pretrained_dict = torch.load('data/' + load_configuration_tracker +
                                 '_net_parameters.torch',
                                 map_location='cpu')
    tracer.load_state_dict(pretrained_dict['tracer'])

    dataset = Waterfalls(fpath='../../DATASETS/Waterfalls/Waterfalls_fish.mat',
                         transform=transforms.Compose([
                             utils.NormalizeSignal((1200, 20)),
                             utils.Paths2D((1200, 20))
                         ]))

    # %% Evaluate the accuracy metrics
    test_samples = DataLoader(dataset,
                              batch_size=32,
                              shuffle=False,
                              sampler=SubsetRandomSampler(test_part),
                              collate_fn=utils.waterfalls_collate)

    # %% View the reconstruction performance with random test sample

    # Select randomly one sample
    idx = test_part[np.random.randint(len(test_part))]