import os

import tensorflow as tf
import tensorlayer as tl

# Network builders (DBPN, res_dense_net, unet3d, unet_care, denoise_net), loss
# helpers (loss_fn, l2_loss, edges_loss, img_gradient_loss) and config globals
# (batch_size, lr_size, hr_size, beta1, learning_rate_init, gpu_num,
# conv_kernel, using_* flags, ...) are defined elsewhere in the repository.

def build_graph(self):
    super(OneStageTrainer, self).build_graph()
    self.archi = self.archi2

    with tf.device('/gpu:1'):
        variable_tag = '1stage_%s' % self.archi
        if self.archi == 'rdn':
            net = res_dense_net(self.plchdr_lr, factor=config.factor, reuse=False, bn=using_batch_norm, name=variable_tag)
            net_test = res_dense_net(self.plchdr_lr, factor=config.factor, reuse=True, bn=using_batch_norm, name=variable_tag)
        elif self.archi == 'unet':
            # this U-Net variant does not upscale, so its input placeholder uses the HR size
            self.plchdr_lr = tf.placeholder("float", [batch_size] + hr_size, name="LR")
            # net = unet3d(self.plchdr_lr, upscale=False, reuse=False, name=variable_tag)
            # net_test = unet3d(self.plchdr_lr, upscale=False, reuse=True, name=variable_tag)
            net = unet_care(self.plchdr_lr, reuse=False, name=variable_tag)
            net_test = unet_care(self.plchdr_lr, reuse=True, name=variable_tag)
        elif self.archi == 'dbpn':
            net = DBPN(self.plchdr_lr, upscale=True, factor=config.factor, reuse=False, name=variable_tag)
            net_test = DBPN(self.plchdr_lr, upscale=True, factor=config.factor, reuse=True, name=variable_tag)
        else:
            raise Exception('unknown architecture: %s' % self.archi)
        # net = DBPN(self.plchdr_lr, upscale=True, reuse=False, name=variable_tag)
        # net_test = DBPN(self.plchdr_lr, upscale=True, reuse=True, name=variable_tag)

        net.print_params(details=False)
        self.net = net
        # give the output tensor a stable name for later graph export
        op_out = tf.identity(net.outputs, self.output_node_name)

        net_vars = tl.layers.get_variables_with_name(variable_tag, train_only=True, printable=False)

        ln_loss = loss_fn(self.plchdr_hr, net.outputs)
        ln_loss_test = loss_fn(self.plchdr_hr, net_test.outputs)
        ln_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(ln_loss, var_list=net_vars)

        self.loss.update({'ln_loss': ln_loss})
        self.loss_test.update({'ln_loss_test': ln_loss_test})
        self.optim.update({'ln_optim': ln_optim})

        if using_edge_loss:
            loss_edges = edges_loss(net.outputs, self.plchdr_hr)
            e_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_edges, var_list=net_vars)
            self.loss.update({'edge_loss': loss_edges})
            self.optim.update({'e_optim': e_optim})
        if using_grad_loss:
            loss_grad = img_gradient_loss(net.outputs, self.plchdr_hr)
            g_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_grad, var_list=net_vars)
            self.loss.update({'grad_loss': loss_grad})
            self.optim.update({'g_optim': g_optim})
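# `loss_fn` above is supplied elsewhere in the repository; the trainer only
# needs it to map a (target, prediction) tensor pair to a scalar. A minimal
# MSE-style stand-in for reference (a sketch, not the project's definition):
def loss_fn_sketch(target, prediction):
    # mean squared error over batch entries and voxels
    return tf.reduce_mean(tf.squared_difference(target, prediction))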
def build_graph(self):
    super(DualStageTrainer, self).build_graph()
    variable_tag_res = 'Resolve'
    variable_tag_interp = 'Interp'
    # if self.archi1 == 'dbpn':
    #     net1 = DBPN
    # elif self.archi1 == 'denoise':
    #     net1 = denoise_net
    # else:
    #     _raise(ValueError())
    var_tag_n2 = variable_tag_interp

    # intermediate (mid-resolution) target for the first stage
    self.plchdr_mr = tf.placeholder("float", [batch_size] + lr_size, name="MR")

    # stage 1: resolve/denoise at the input resolution
    with tf.device('/gpu:1'):
        if self.archi1 == 'dbpn':
            net_stage1 = DBPN(self.plchdr_lr, upscale=False, name=variable_tag_res)
            net_stage1_test = DBPN(self.plchdr_lr, upscale=False, reuse=True, name=variable_tag_res)
        elif self.archi1 == 'denoise':
            net_stage1 = denoise_net(self.plchdr_lr, reuse=False, name=variable_tag_res)
            net_stage1_test = denoise_net(self.plchdr_lr, reuse=True, name=variable_tag_res)
        elif self.archi1 == 'unet':
            net_stage1 = unet3d(self.plchdr_lr, reuse=False, name=variable_tag_res)
            net_stage1_test = unet3d(self.plchdr_lr, reuse=True, name=variable_tag_res)
        else:
            _raise(ValueError())

    # stage 2: interpolate (super-resolve) the stage-1 output
    with tf.device('/gpu:2'):
        if self.archi2 == 'rdn':
            net_stage2 = res_dense_net(net_stage1.outputs, factor=config.factor, conv_kernel=conv_kernel, bn=using_batch_norm, is_train=True, name=variable_tag_interp)
            net_stage2_test = res_dense_net(net_stage1_test.outputs, factor=config.factor, conv_kernel=conv_kernel, bn=using_batch_norm, reuse=True, is_train=False, name=variable_tag_interp)
        else:
            _raise(ValueError())

    self.resolver = net_stage1
    self.interpolator = net_stage2
    # give the output tensor a stable name for later graph export
    op_out = tf.identity(net_stage2.outputs, self.output_node_name)

    net_stage1.print_params(details=False)
    net_stage2.print_params(details=False)

    # vars_n1 = tl.layers.get_variables_with_name(variable_tag_res, train_only=True, printable=False)
    vars_n2 = tl.layers.get_variables_with_name(var_tag_n2, train_only=True, printable=False)

    loss_training_n1 = loss_fn(self.plchdr_mr, net_stage1.outputs)
    loss_training_n2 = loss_fn(self.plchdr_hr, net_stage2.outputs)
    loss_test_n1 = loss_fn(self.plchdr_mr, net_stage1_test.outputs)
    loss_test_n2 = loss_fn(self.plchdr_hr, net_stage2_test.outputs)

    loss_training = loss_training_n1 + loss_training_n2
    loss_test = loss_test_n2 + loss_test_n1
    # loss_training = loss_training_n2
    # loss_test = loss_test_n2

    # n1_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training, var_list=vars_n1)
    # n2_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training_n2, var_list=vars_n2)
    # n1_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training_n2)
    n1_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training_n1)
    n_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training)

    if self.pretrain:
        # pre-train the first stage alone before joint optimization
        self.pretrain_op = {}
        self.pretrain_op.update({'loss_pretrain': loss_training_n1, 'optim_pretrain': n1_optim})

    self.loss.update({'loss_training': loss_training, 'loss_training_n2': loss_training_n2, 'loss_training_n1': loss_training_n1})
    self.loss_test.update({'loss_test': loss_test, 'loss_test_n2': loss_test_n2, 'loss_test_n1': loss_test_n1})
    # self.optim.update({'n1_optim': n1_optim, 'n2_optim': n2_optim, 'n_optim': n_optim})
    self.optim.update({'n_optim': n_optim})

    if using_edge_loss:
        loss_edges = edges_loss(net_stage2.outputs, self.plchdr_hr)
        e_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_edges, var_list=vars_n2)
        self.loss.update({'edge_loss': loss_edges})
        self.optim.update({'e_optim': e_optim})
    if using_grad_loss:
        loss_grad = img_gradient_loss(net_stage2.outputs, self.plchdr_hr)
        g_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_grad, var_list=vars_n2)
        self.loss.update({'grad_loss': loss_grad})
        self.optim.update({'g_optim': g_optim})
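# The `cond or _raise(...)` idiom used in this codebase only works if `_raise`
# actually raises its argument; a one-line helper in that spirit (the real
# definition lives elsewhere in the repository):
def _raise(e):
    raise e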
def build_graph(self):
    assert batch_size % gpu_num == 0
    tower_batch = batch_size // gpu_num

    with tf.device('/cpu:0'):
        self.learning_rate_var = tf.Variable(learning_rate_init, trainable=False)
        optimizer = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1)
        tower_grads = []

        self.plchdr_lr = tf.placeholder("float", [batch_size] + lr_size, name="LR")
        self.plchdr_hr = tf.placeholder("float", [batch_size] + hr_size, name="HR")
        if '2stage' in self.archi:
            if 'resolve_first' in self.archi:
                self.plchdr_mr = tf.placeholder("float", [batch_size] + lr_size, name="MR")
            else:
                self.plchdr_mr = tf.placeholder("float", [batch_size] + hr_size, name="MR")

        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(gpu_num):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('tower_%d' % i) as name_scope:
                        if '2stage' in self.archi:
                            variable_tag_res = 'Resolve'
                            variable_tag_interp = 'Interp'
                            if 'resolve_first' in self.archi:
                                var_tag_n2 = variable_tag_interp
                                net_stage1 = DBPN(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], upscale=False, name=variable_tag_res)
                                net_stage2 = res_dense_net(net_stage1.outputs, factor=config.factor, conv_kernel=conv_kernel, bn=using_batch_norm, is_train=True, name=variable_tag_interp)
                                self.resolver = net_stage1
                                self.interpolator = net_stage2
                            else:
                                var_tag_n2 = variable_tag_res
                                net_stage1 = res_dense_net(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], factor=config.factor, conv_kernel=conv_kernel, reuse=False, bn=using_batch_norm, is_train=True, name=variable_tag_interp)
                                net_stage2 = DBPN(net_stage1.outputs, upscale=False, name=variable_tag_res)
                                self.resolver = net_stage2
                                self.interpolator = net_stage1

                            net_stage1.print_params(details=False)
                            net_stage2.print_params(details=False)

                            # vars_n1 = tl.layers.get_variables_with_name(variable_tag_res, train_only=True, printable=False)
                            vars_n2 = tl.layers.get_variables_with_name(var_tag_n2, train_only=True, printable=False)

                            loss_training_n1 = l2_loss(self.plchdr_mr[i * tower_batch : (i + 1) * tower_batch], net_stage1.outputs)
                            loss_training_n2 = l2_loss(self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch], net_stage2.outputs)
                            loss_training = loss_training_n1 + loss_training_n2
                            tf.add_to_collection('losses', loss_training)
                            loss_tower = tf.add_n(tf.get_collection('losses', name_scope))  # the total loss for the current tower
                            grads = optimizer.compute_gradients(loss_tower)
                            tower_grads.append(grads)

                            self.loss.update({'loss_training': loss_training, 'loss_training_n2': loss_training_n2, 'loss_training_n1': loss_training_n1})

                            if using_edge_loss:
                                loss_edges = edges_loss(net_stage2.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                e_optim = optimizer.minimize(loss_edges, var_list=vars_n2)
                                self.loss.update({'edge_loss': loss_edges})
                                self.optim.update({'e_optim': e_optim})
                            if using_grad_loss:
                                loss_grad = img_gradient_loss(net_stage2.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                g_optim = optimizer.minimize(loss_grad, var_list=vars_n2)
                                self.loss.update({'grad_loss': loss_grad})
                                self.optim.update({'g_optim': g_optim})
                        else:
                            variable_tag = '1stage_%s' % self.archi
                            if self.archi == 'rdn':
                                net = res_dense_net(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], factor=config.factor, reuse=i > 0, name=variable_tag)
                            elif self.archi == 'unet':
                                net = unet3d(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], upscale=True, reuse=i > 0, is_train=True, name=variable_tag)
                            elif self.archi == 'dbpn':
                                net = DBPN(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], upscale=True, reuse=i > 0, name=variable_tag)
                            else:
                                raise Exception('unknown architecture: %s' % self.archi)

                            if i == 0:
                                self.net = net
                            ln_loss = l2_loss(self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch], net.outputs)
                            tf.add_to_collection('losses', ln_loss)
                            loss_tower = tf.add_n(tf.get_collection('losses', name_scope))  # the total loss for the current tower
                            grads = optimizer.compute_gradients(loss_tower)
                            tower_grads.append(grads)
                            self.loss.update({'ln_loss': ln_loss})
                            '''
                            if using_edge_loss:
                                loss_edges = edges_loss(net.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                e_optim = optimizer.minimize(loss_edges, var_list=net_vars)
                                self.loss.update({'edge_loss': loss_edges})
                                self.optim.update({'e_optim': e_optim})
                            if using_grad_loss:
                                loss_grad = img_gradient_loss(net.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                g_optim = optimizer.minimize(loss_grad, var_list=net_vars)
                                self.loss.update({'grad_loss': loss_grad})
                                self.optim.update({'g_optim': g_optim})
                            '''
                        # share variables with the towers built on subsequent GPUs
                        tf.get_variable_scope().reuse_variables()

        # average the per-tower gradients and apply one synchronous update
        grads = self._average_gradient(tower_grads)
        n_optim = optimizer.apply_gradients(grads)
        self.optim.update({'n_optim': n_optim})
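# `self._average_gradient(tower_grads)` is called above but not defined in this
# excerpt. A sketch of the standard tower-gradient averaging pattern (as in the
# classic TensorFlow multi-GPU examples); the project's actual method may
# differ. It would live on the trainer class:
def _average_gradient(self, tower_grads):
    average_grads = []
    # tower_grads: one list of (gradient, variable) pairs per GPU,
    # all lists in the same variable order
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # variables are shared across towers, so the first tower's variable suffices
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads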
def build_model_and_load_npz(epoch, use_cpu=False, save_pb=False):
    epoch = 'best' if epoch == 0 else epoch

    # search for checkpoint .npz files whose names contain all the given tags
    def _search_for_ckpt_npz(file_dir, tags):
        for filename in os.listdir(file_dir):
            if '.npz' in filename and all(tag in filename for tag in tags):
                return filename
        return None

    if archi1 is not None:
        resolve_ckpt_file = _search_for_ckpt_npz(checkpoint_dir, ['resolve', str(epoch)])
        interp_ckpt_file = _search_for_ckpt_npz(checkpoint_dir, ['interp', str(epoch)])
        (resolve_ckpt_file is not None and interp_ckpt_file is not None) or _raise(Exception('checkpoint file not found'))
    else:
        # checkpoint_dir = "checkpoint/"
        # ckpt_file = "brain_conv3_epoch1000_rdn.npz"
        ckpt_file = _search_for_ckpt_npz(checkpoint_dir, [str(epoch)])
        ckpt_file is not None or _raise(Exception('checkpoint file not found'))

    #======================================
    # build the model
    #======================================
    if not use_cpu:
        device_str = '/gpu:%d' % device_id
    else:
        device_str = '/cpu:0'

    LR = tf.placeholder(tf.float32, [1] + lr_size)
    if archi1 is not None:
        # if ('resolve_first' in archi):
        with tf.device(device_str):
            if archi1 == 'dbpn':
                resolver = DBPN(LR, upscale=False, name="net_s1")
            elif archi1 == 'denoise':
                resolver = denoise_net(LR, name="net_s1")
            elif archi1 == 'unet':
                resolver = unet3d(LR, name="net_s1")
            else:
                _raise(ValueError())

            if archi2 == 'rdn':
                interpolator = res_dense_net(resolver.outputs, factor=factor, conv_kernel=conv_kernel, bn=using_batch_norm, is_train=False, name="net_s2")
                net = interpolator
            else:
                _raise(ValueError())
    else:
        archi = archi2
        with tf.device(device_str):
            if archi == 'rdn':
                net = res_dense_net(LR, factor=factor, bn=using_batch_norm, conv_kernel=conv_kernel, name="net_s2")
            elif archi == 'unet':
                # net = unet3d(LR, upscale=False)
                net = unet_care(LR)
            elif archi == 'dbpn':
                net = DBPN(LR, upscale=True)
            else:
                raise Exception('unknown architecture: %s' % archi)

    net.print_params(details=False)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    if archi1 is None:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/' + ckpt_file, network=net)
    else:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/' + resolve_ckpt_file, network=resolver)
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/' + interp_ckpt_file, network=interpolator)

    return sess, net, LR
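# A hypothetical usage sketch for inference with the restored model; the zero
# input and the numpy import are illustrative assumptions, not part of the
# module. The fed array must match the [1] + lr_size placeholder shape.
if __name__ == '__main__':
    import numpy as np
    sess, net, LR = build_model_and_load_npz(epoch=0, use_cpu=True)
    lr_block = np.zeros([1] + lr_size, dtype=np.float32)  # stand-in for a real LR volume
    sr_block = sess.run(net.outputs, feed_dict={LR: lr_block})
    print(sr_block.shape)
    sess.close()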