# rand_bool: return a random Python bool (True or False) drawn from NumPy's
# global random state. np.random.randint(0, 2) yields 0 or 1 (the upper bound
# is exclusive), and bool() converts that to False/True.
# Assumes `np` is numpy, imported at module level (not visible here).
# NOTE(review): this line appears to be several source lines collapsed into one;
# the self-test section that follows the `#` is truncated in this view and is
# intentionally left untouched.
def rand_bool(): return bool(np.random.randint(0, 2)) #Test the batch-functions if __name__ == '__main__': from data import readSEGY, readLabels, get_slice import tensorboard import numpy as np data, data_info = readSEGY('F3/data.segy') train_coordinates = {'1': np.expand_dims(np.array([50, 50, 50]), 1)} logger = tensorboard.TBLogger('log', 'batch test') [batch, labels] = get_random_batch(data, train_coordinates, 65, 32) logger.log_images('normal', batch) [batch, labels] = get_random_batch(data, train_coordinates, 65, 32, random_flip=True) logger.log_images('flipping', batch) [batch, labels] = get_random_batch(data, train_coordinates, 65, 32,
import tensorboard
import numpy as np
from utils import *

# Set-up for training: parameters, optional TensorBoard logger, the network,
# the loss function and the optimizer.
# The network definition proposed in the paper lives in texture_net.py.

# Parameters
dataset_name = 'F3'
im_size = 65
batch_size = 32  # If you have a GPU with little memory, try reducing this to 16 (may degrade results)
use_gpu = True  # Switch to toggle the use of GPU or not
log_tensorboard = True  # Log progress on TensorBoard

if log_tensorboard:
    logger = tensorboard.TBLogger('log', 'Train')

# See the texture_net.py file for the network configuration
from texture_net import TextureNet

network = TextureNet()

# Transfer the model to the GPU *before* building the optimizer. The PyTorch
# docs for Module.cuda() note that it may make the parameters different
# objects, so it should be called before constructing the optimizer;
# otherwise the optimizer could hold references to the old CPU tensors.
if use_gpu:
    network = network.cuda()

# Loss function; nn.CrossEntropyLoss applies the (log-)softmax internally
cross_entropy = nn.CrossEntropyLoss()

# Optimizer to control step size in gradient descent
optimizer = torch.optim.Adam(network.parameters())
# Load the trained model (run train.py first to create it).
network = TextureNet()
# map_location='cpu' lets the checkpoint load even on a machine without CUDA
# (e.g. when use_gpu is False) in case it was saved from GPU tensors; the
# model is moved to the GPU afterwards only if requested.
network.load_state_dict(torch.load('F3/saved_model.pt', map_location='cpu'))
if use_gpu:
    network = network.cuda()
# Switch to evaluation mode (affects layers such as dropout/batch-norm, if any).
network.eval()

# We can set the interpretation resolution to save time.
# The interpretation is then conducted over every n-th sample and
# then resized to the full size of the input data
resolution = 16

##########################################################################
# NOTE(review): `slice` shadows the built-in slice(); the name is kept because
# code further down in this script may still refer to it.
slice = 'inline'  # Inline, crossline, timeslice or full
slice_no = 339

# Log the selected slice of the input data to TensorBoard
logger = tensorboard.TBLogger('log', 'Test')
logger.log_images(slice + '_' + str(slice_no), get_slice(data, data_info, slice, slice_no), cm='gray')

# Plot extracted features, class probabilities and salt-predictions for slice.
# NOTE(review): the tags below contain a space before '_f5'/'_f4' — possibly
# unintended; kept unchanged so existing TensorBoard tag names stay stable.

# Features (attributes) from layer 5
im = interpret(network.f5, data, data_info, slice, slice_no, im_size, resolution)
logger.log_images(slice + '_' + str(slice_no) + ' _f5', im)

# Features from layer 4
im = interpret(network.f4, data, data_info, slice, slice_no, im_size, resolution)
logger.log_images(slice + '_' + str(slice_no) + ' _f4', im)

#Class "probabilities"