def dataset(slice_axis):
    """Build the single-slice 'Set_A' validation dataset for ``slice_axis``.

    Args:
        slice_axis: forwarded to ``slice_dataset`` as its ``slice_axis``
            argument. (Previously this parameter was silently ignored and
            axis 0 was always used.)

    Returns:
        Whatever ``slice_dataset`` returns for the 'Set_A' sub-dataset with
        mean subtraction enabled, the 'valid' split, and one slice, using
        'conf/cremi_datasets.toml'.
    """
    return slice_dataset(
        sub_dataset='Set_A',
        subtract_mean=True,
        split='valid',
        slices=1,
        slice_axis=slice_axis,
        data_config='conf/cremi_datasets.toml',
    )
def build_network_and_dataset(net_conf, label_conf, slice_axis):
    """Create a slice dataset and a network with pre-trained weights loaded.

    NOTE(review): this function reads ``self``, ``split``, ``data_config``
    and ``NETWORKS`` without defining them, so it presumably lives inside a
    scope that provides them (e.g. as a closure) -- confirm against the
    enclosing code.

    Args:
        net_conf: mapping providing 'slices', 'model' and 'trained_file'.
        label_conf: mapping providing 'label_catin_net2'.
        slice_axis: forwarded to ``slice_dataset`` as ``slice_axis``.

    Returns:
        Tuple ``(network, data_set)``.
    """
    data_set = slice_dataset(sub_dataset=self.dataset_conf['sub_dataset'],
                             subtract_mean=True,
                             split=split,
                             slices=net_conf['slices'],
                             slice_axis=slice_axis,
                             data_config=data_config)
    data_out_labels = data_set.output_labels()
    # Read unconditionally (as before) so a missing key fails for every model.
    input_lbCHs_cat_for_net2 = label_conf['label_catin_net2']

    model_name = net_conf['model']
    net_model = NETWORKS[model_name]
    if model_name == 'M2DUnet_withDilatConv':
        network = net_model(freeze_net1=True,
                            target_label=data_out_labels,
                            label_catin_net2=input_lbCHs_cat_for_net2,
                            in_ch=net_conf['slices'])
        print(net_model)
    else:
        network = net_model(target_label=data_out_labels,
                            in_ch=net_conf['slices'])

    pre_trained_file = net_conf['trained_file']
    print('load weights from {}'.format(pre_trained_file))
    # Deserialize onto the CPU first: the network is not moved to a GPU in
    # this function, and a checkpoint saved from a CUDA device would
    # otherwise fail to load on a machine without CUDA.
    state_dict = torch.load(pre_trained_file,
                            map_location=lambda storage, loc: storage)
    network.load_state_dict(state_dict)
    return network, data_set
def get_data(set_name='Set_C'):
    """Return ``(data, label)`` for one sub-dataset of the validation split.

    An 'All' slice dataset is created with mean subtraction enabled, the
    current sub-dataset is switched to ``set_name``, and the results of its
    ``get_data()`` and ``get_label()`` calls are returned as a pair.

    Args:
        set_name: name passed to ``set_current_subDataset``.
    """
    volumes = slice_dataset(sub_dataset='All',
                            subtract_mean=True,
                            split='valid')
    volumes.set_current_subDataset(set_name)
    return volumes.get_data(), volumes.get_label()
def get_rawim_or_labels(task_set, subset_name, data_set):
    """Return the image data or the labels of one sub-dataset.

    Args:
        task_set: 'valid' selects 'conf/cremi_datasets.toml'; any other
            value selects 'conf/cremi_datasets_test.toml'.
        subset_name: passed to ``slice_dataset`` as ``sub_dataset``.
        data_set: 'image' to return ``get_data()``, 'label' to return
            ``get_label()``.

    Returns:
        The result of ``get_data()`` or ``get_label()`` on the dataset.

    Raises:
        ValueError: if ``data_set`` is neither 'image' nor 'label'.
    """
    # Validate before building the dataset: previously an unknown value fell
    # through both branches and crashed with UnboundLocalError on return.
    if data_set not in ('image', 'label'):
        raise ValueError(
            "data_set must be 'image' or 'label', got {!r}".format(data_set))

    if task_set == 'valid':
        data_config = 'conf/cremi_datasets.toml'
    else:
        data_config = 'conf/cremi_datasets_test.toml'

    orig_dataset = slice_dataset(sub_dataset=subset_name,
                                 data_config=data_config)
    if data_set == 'image':
        return orig_dataset.get_data()
    return orig_dataset.get_label()
def __init__(self, config_file):
    """Build the dataset and network from ``config_file`` and load weights.

    Args:
        config_file: passed to ``self.parse_toml``. That method is defined
            outside this block; it is expected to populate
            ``self.dataset_conf``, ``self.net_conf`` and ``self.label_conf``,
            which are read below -- confirm.

    Raises:
        ValueError: if the 'M2DUnet_withDilatConv' model is selected but the
            label config contains neither 'final_labels' nor 'final_label'.
        KeyError: if a required config key is missing or the configured
            model name is not in the network registry.
    """
    self.parse_toml(config_file)

    # Registry mapping the configured model name to its network class.
    NETWORKS = {
        'Unet': Unet,
        'DUnet': DUnet,
        'MDUnet': MdecoderUnet,
        'MDUnetDilat': MdecoderUnet_withDilatConv,
        'M2DUnet': Mdecoder2Unet,
        'M2DUnet_withDilatConv': Mdecoder2Unet_withDilatConv,
        'MDUnet_FullDilat': MdecoderUnet_withFullDilatConv,
    }

    # Only 'predict' uses the test-set configuration; every other value
    # (including 'valid') uses the validation configuration.
    if self.dataset_conf['dataset'] == 'predict':
        data_config = 'conf/cremi_datasets_test.toml'
        split = 'predict'
    else:
        data_config = 'conf/cremi_datasets.toml'
        split = 'valid'

    # Create a dataset which is able to iteratively obtain slices of the
    # image (3 or 5 slices with stride 1) from either the z-direction or the
    # xy-direction.
    self.dataset = slice_dataset(sub_dataset=self.dataset_conf['sub_dataset'],
                                 subtract_mean=True,
                                 split=split,
                                 slices=self.net_conf['z_slices'],
                                 data_config=data_config)

    data_out_labels = self.dataset.output_labels()
    input_lbCHs_cat_for_net2 = self.label_conf['label_catin_net2']

    # Keep only the dataset labels that the first network should predict.
    # ``.items()`` (instead of the Python-2-only ``.iteritems()``) works on
    # both Python 2 and Python 3 dictionaries.
    net1_out_put_label = self.label_conf['labels']
    net1_target_label_ch_dict = {
        lb: ch
        for lb, ch in data_out_labels.items()
        if lb in net1_out_put_label
    }

    # Targets of the second network: either a filtered set of labels
    # ('final_labels') or a single label stored under the key 'final'
    # ('final_label'). Stays None when neither is configured.
    net2_target_label_ch_dict = None
    if 'final_labels' in self.label_conf:
        net2_out_put_label = self.label_conf['final_labels']
        net2_target_label_ch_dict = {
            lb: ch
            for lb, ch in data_out_labels.items()
            if lb in net2_out_put_label
        }
    elif 'final_label' in self.label_conf:
        net2_target_label_ch_dict = {
            'final': data_out_labels[self.label_conf['final_label']],
        }

    print(data_out_labels)

    # Create the network.
    model_name = self.net_conf['model']
    net_model = NETWORKS[model_name]
    if model_name == 'M2DUnet_withDilatConv':
        if net2_target_label_ch_dict is None:
            # Previously this surfaced as a NameError on an unbound local.
            raise ValueError(
                "model 'M2DUnet_withDilatConv' requires 'final_labels' or "
                "'final_label' in the label configuration")
        self.network = net_model(freeze_net1=True,
                                 target_label=net1_target_label_ch_dict,
                                 net2_target_label=net2_target_label_ch_dict,
                                 label_catin_net2=input_lbCHs_cat_for_net2,
                                 in_ch=self.net_conf['z_slices'])
        print(net_model)
    else:
        print(model_name)
        self.network = net_model(target_label=net1_target_label_ch_dict,
                                 in_ch=self.net_conf['z_slices'],
                                 BatchNorm_final=False)

    self.use_gpu = True
    if self.use_gpu and torch.cuda.is_available():
        print('model_set_cuda')
        self.network = self.network.cuda()

    # Load the pre-trained weights. Deserializing onto the CPU first lets a
    # checkpoint saved from a CUDA device load on a machine without CUDA;
    # load_state_dict then copies the values into the network's parameters
    # on whichever device they live.
    pre_trained_file = self.net_conf['trained_file']
    print('load weights from {}'.format(pre_trained_file))
    state_dict = torch.load(pre_trained_file,
                            map_location=lambda storage, loc: storage)
    self.network.load_state_dict(state_dict)