Example #1
    def build_graph(self):
        super(OneStageTrainer, self).build_graph()
        self.archi = self.archi2

        with tf.device('/gpu:1'):
            
            variable_tag = '1stage_%s' % self.archi
            if self.archi == 'rdn':
                net = res_dense_net(self.plchdr_lr, factor=config.factor, reuse=False, bn=using_batch_norm, name=variable_tag)
                net_test = res_dense_net(self.plchdr_lr, factor=config.factor, reuse=True, bn=using_batch_norm, name=variable_tag)   

            elif self.archi == 'unet':
                self.plchdr_lr = tf.placeholder("float", [batch_size] + hr_size, name="LR")  # the U-Net variant does not upscale, so the LR placeholder is rebuilt at HR size
                # net = unet3d(self.plchdr_lr, upscale=False, reuse=False, name=variable_tag)
                # net_test = unet3d(self.plchdr_lr, upscale=False, reuse=True, name=variable_tag)
                net = unet_care(self.plchdr_lr, reuse=False, name=variable_tag)
                net_test = unet_care(self.plchdr_lr, reuse=True, name=variable_tag)
            elif self.archi == 'dbpn':
                net = DBPN(self.plchdr_lr, upscale=True, factor=config.factor, reuse=False, name=variable_tag)
                net_test = DBPN(self.plchdr_lr, upscale=True, factor=config.factor, reuse=True, name=variable_tag)
            else:
                raise Exception('unknown architecture: %s' % self.archi)

            #net = DBPN(self.plchdr_lr, upscale=True, reuse=False, name=variable_tag)

        #net_test = DBPN(self.plchdr_lr, upscale=True, reuse=True, name=variable_tag) 
        
        net.print_params(details=False)
        self.net = net
        op_out   = tf.identity(net.outputs, self.output_node_name)  # give the output tensor a stable name for later export
        net_vars = tl.layers.get_variables_with_name(variable_tag, train_only=True, printable=False)

        ln_loss = loss_fn(self.plchdr_hr, net.outputs)
        ln_loss_test = loss_fn(self.plchdr_hr, net_test.outputs)
        ln_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(ln_loss, var_list=net_vars)

        
        self.loss.update({'ln_loss' : ln_loss})
        self.loss_test.update({'ln_loss_test' : ln_loss_test})
        self.optim.update({'ln_optim' : ln_optim})

        if using_edge_loss:
            loss_edges = edges_loss(net.outputs, self.plchdr_hr)
            e_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_edges, var_list=net_vars)
            self.loss.update({'edge_loss' : loss_edges})
            self.optim.update({'e_optim' : e_optim})
        if using_grad_loss:
            loss_grad = img_gradient_loss(net.outputs, self.plchdr_hr)
            g_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_grad, var_list=net_vars)
            self.loss.update({'grad_loss' : loss_grad})
            self.optim.update({'g_optim' : g_optim})
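
The tf.identity call above exists so the network's output tensor carries a stable name (self.output_node_name) in the graph, which is what a .pb export needs. Below is a minimal sketch of how a graph built this way could be frozen, assuming `sess` is a session holding the trained variables; `freeze_to_pb` and `pb_path` are hypothetical names, not part of the project:

    import tensorflow as tf

    def freeze_to_pb(sess, output_node_name, pb_path):
        # Fold the trained variables into constants so the exported graph is
        # self-contained, keeping only the ops the named output depends on.
        const_graph = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), [output_node_name])
        with tf.gfile.GFile(pb_path, 'wb') as f:
            f.write(const_graph.SerializeToString())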
Example #2
    def build_graph(self):  
        super(DualStageTrainer, self).build_graph()
        variable_tag_res = 'Resolve'
        variable_tag_interp = 'Interp'

        # if self.archi1 == 'dbpn':
        #     net1 = DBPN 
        # elif self.archi1 == 'denoise'
        #     net1 = denoise_net
        # else:
        #     _raise(ValueError())   

        var_tag_n2 = variable_tag_interp
        self.plchdr_mr = tf.placeholder("float", [batch_size] + lr_size, name="MR")  # stage-1 target; LR-sized because stage 1 does not upscale
        with tf.device('/gpu:%d' % 1):
            if self.archi1 == 'dbpn':
                net_stage1      = DBPN(self.plchdr_lr, upscale=False, name=variable_tag_res)
                net_stage1_test = DBPN(self.plchdr_lr, upscale=False, reuse=True, name=variable_tag_res)
            elif self.archi1 == 'denoise':
                net_stage1      = denoise_net(self.plchdr_lr, reuse=False, name=variable_tag_res)
                net_stage1_test = denoise_net(self.plchdr_lr, reuse=True, name=variable_tag_res)
            elif self.archi1 == 'unet':
                net_stage1      = unet3d(self.plchdr_lr, reuse=False, name=variable_tag_res)
                net_stage1_test = unet3d(self.plchdr_lr, reuse=True, name=variable_tag_res)
            else:
                _raise(ValueError())   

        with tf.device('/gpu:%d' % 2):
            if self.archi2 == 'rdn':
                net_stage2      = res_dense_net(net_stage1.outputs, factor=config.factor, conv_kernel=conv_kernel, bn=using_batch_norm, is_train=True, name=variable_tag_interp)
                net_stage2_test = res_dense_net(net_stage1_test.outputs, factor=config.factor, conv_kernel=conv_kernel, bn=using_batch_norm, reuse=True, is_train=False, name=variable_tag_interp)
            else:
                _raise(ValueError())   

        self.resolver     = net_stage1
        self.interpolator = net_stage2
        op_out            = tf.identity(net_stage2.outputs, self.output_node_name)
            
        net_stage1.print_params(details=False)
        net_stage2.print_params(details=False)

        #vars_n1 = tl.layers.get_variables_with_name(variable_tag_res, train_only=True, printable=False)
        vars_n2 = tl.layers.get_variables_with_name(var_tag_n2, train_only=True, printable=False)
        
        loss_training_n1 = loss_fn(self.plchdr_mr, net_stage1.outputs)
        loss_training_n2 = loss_fn(self.plchdr_hr, net_stage2.outputs)
        
        loss_test_n1 = loss_fn(self.plchdr_mr, net_stage1_test.outputs)
        loss_test_n2 = loss_fn(self.plchdr_hr, net_stage2_test.outputs)

        loss_training = loss_training_n1 + loss_training_n2
        loss_test = loss_test_n2 + loss_test_n1
        # loss_training = loss_training_n2
        # loss_test = loss_test_n2

        #n1_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training, var_list=vars_n1)
        #n2_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training_n2, var_list=vars_n2)
        #n1_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training_n2)
        n1_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training_n1)  # stage-1-only loss, used as the pretraining optimizer below
        n_optim  = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_training)
        
        if self.pretrain:
            self.pretrain_op = {}
            self.pretrain_op.update({'loss_pretrain' : loss_training_n1, 'optim_pretrain' : n1_optim})

        self.loss.update({'loss_training' : loss_training, 'loss_training_n2' : loss_training_n2, 'loss_training_n1' : loss_training_n1})
        self.loss_test.update({'loss_test' : loss_test, 'loss_test_n2' : loss_test_n2, 'loss_test_n1' : loss_test_n1})
        #self.optim.update({'n1_optim' : n1_optim, 'n2_optim' : n2_optim, 'n_optim' : n_optim})
        self.optim.update({'n_optim' : n_optim})

        if using_edge_loss:
            loss_edges = edges_loss(net_stage2.outputs, self.plchdr_hr)
            e_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_edges, var_list=vars_n2)
            self.loss.update({'edge_loss' : loss_edges})
            self.optim.update({'e_optim' : e_optim})

        if using_grad_loss:
            loss_grad = img_gradient_loss(net_stage2.outputs, self.plchdr_hr)
            g_optim = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1).minimize(loss_grad, var_list=vars_n2)
            self.loss.update({'grad_loss' : loss_grad})
            self.optim.update({'g_optim' : g_optim})
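
The trainer exposes everything the training loop needs through the self.loss / self.optim dicts (plus self.pretrain_op when pretraining). A minimal sketch of one combined training step driven through those dicts; `trainer`, `sess`, and the three numpy batches are assumptions, not names from the project:

    def run_train_step(sess, trainer, lr_batch, mr_batch, hr_batch):
        # Feed all three placeholders and run every registered optimizer,
        # fetching the registered losses in the same sess.run call.
        feed = {trainer.plchdr_lr: lr_batch,
                trainer.plchdr_mr: mr_batch,
                trainer.plchdr_hr: hr_batch}
        fetches = dict(trainer.loss)
        fetches.update(trainer.optim)
        results = sess.run(fetches, feed_dict=feed)
        return {name: results[name] for name in trainer.loss}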
Example #3
    def build_graph(self):
        assert batch_size % gpu_num == 0
        tower_batch = batch_size // gpu_num
        
        with tf.device('/cpu:0'):
            self.learning_rate_var = tf.Variable(learning_rate_init, trainable=False)
            optimizer = tf.train.AdamOptimizer(self.learning_rate_var, beta1=beta1)
            tower_grads = []

            self.plchdr_lr = tf.placeholder("float", [batch_size] + lr_size, name="LR")       
            self.plchdr_hr = tf.placeholder("float", [batch_size] + hr_size, name="HR")
            if ('2stage' in self.archi):
                if ('resolve_first' in self.archi):
                    self.plchdr_mr = tf.placeholder("float", [batch_size] + lr_size, name="MR")  
                else:
                    self.plchdr_mr = tf.placeholder("float", [batch_size] + hr_size, name='MR')  

            with tf.variable_scope(tf.get_variable_scope()):
                for i in range(gpu_num):
                    with tf.device('/gpu:%d' % i):
                        with tf.name_scope('tower_%d' % i) as name_scope:
                            if ('2stage' in self.archi):
                                variable_tag_res = 'Resolve'
                                variable_tag_interp = 'Interp'
                                if ('resolve_first' in self.archi):
                                    var_tag_n2 = variable_tag_interp
                                    net_stage1 = DBPN(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], upscale=False, name=variable_tag_res)
                                    net_stage2 = res_dense_net(net_stage1.outputs, factor=config.factor, conv_kernel=conv_kernel, bn=using_batch_norm, is_train=True, name=variable_tag_interp)
                                    self.resolver = net_stage1
                                    self.interpolator = net_stage2
                                else :
                                    var_tag_n2 = variable_tag_res
                                    net_stage1 = res_dense_net(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], factor=config.factor, conv_kernel=conv_kernel, reuse=False, bn=using_batch_norm, is_train=True, name=variable_tag_interp)
                                    net_stage2 = DBPN(net_stage1.outputs, upscale=False, name=variable_tag_res)
                                    self.resolver = net_stage2
                                    self.interpolator = net_stage1
                                net_stage1.print_params(details=False)
                                net_stage2.print_params(details=False)

                                #vars_n1 = tl.layers.get_variables_with_name(variable_tag_res, train_only=True, printable=False)
                                vars_n2 = tl.layers.get_variables_with_name(var_tag_n2, train_only=True, printable=False)
                                
                                loss_training_n1 = l2_loss(self.plchdr_mr[i * tower_batch : (i + 1) * tower_batch], net_stage1.outputs)
                                loss_training_n2 = l2_loss(self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch], net_stage2.outputs)
                                
                                loss_training = loss_training_n1 + loss_training_n2
                                tf.add_to_collection('losses', loss_training)
                                loss_tower = tf.add_n(tf.get_collection('losses', name_scope)) # the total loss for the current tower

                                grads = optimizer.compute_gradients(loss_tower)
                                tower_grads.append(grads)

                                self.loss.update({'loss_training' : loss_training, 'loss_training_n2' : loss_training_n2, 'loss_training_n1' : loss_training_n1})
                                

                                if using_edge_loss:
                                    loss_edges = edges_loss(net_stage2.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                    e_optim = optimizer.minimize(loss_edges, var_list=vars_n2)
                                    self.loss.update({'edge_loss' : loss_edges})
                                    self.optim.update({'e_optim' : e_optim})

                                if using_grad_loss:
                                    loss_grad = img_gradient_loss(net_stage2.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                    g_optim = optimizer.minimize(loss_grad, var_list=vars_n2)
                                    self.loss.update({'grad_loss' : loss_grad})
                                    self.optim.update({'g_optim' : g_optim})

                            else : 
                                variable_tag = '1stage_%s' % self.archi
                                if self.archi == 'rdn':
                                    net = res_dense_net(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], factor=config.factor, reuse=i > 0, name=variable_tag)
                                elif self.archi == 'unet':
                                    net = unet3d(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], upscale=True, reuse=i > 0, is_train=True, name=variable_tag)
                                elif self.archi == 'dbpn':
                                    net = DBPN(self.plchdr_lr[i * tower_batch : (i + 1) * tower_batch], upscale=True, reuse=i > 0, name=variable_tag)
                                else:
                                    raise Exception('unknown architecture: %s' % self.archi)

                                
                                if i == 0:
                                    self.net = net
                                    
                                ln_loss = l2_loss(self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch], net.outputs)
                                tf.add_to_collection('losses', ln_loss)
                                loss_tower = tf.add_n(tf.get_collection('losses', name_scope)) # the total loss for the current tower

                                grads = optimizer.compute_gradients(loss_tower)
                                tower_grads.append(grads)
                                
                                self.loss.update({'ln_loss' : ln_loss})

                                '''
                                if using_edge_loss:
                                    loss_edges = edges_loss(net.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                    e_optim = optimizer.minimize(loss_edges, var_list=net_vars)
                                    self.loss.update({'edge_loss' : loss_edges})
                                    self.optim.update({'e_optim' : e_optim})
                                if using_grad_loss:
                                    loss_grad = img_gradient_loss(net.outputs, self.plchdr_hr[i * tower_batch : (i + 1) * tower_batch])
                                    g_optim = optimizer.minimize(loss_grad, var_list=net_vars)
                                    self.loss.update({'grad_loss' : loss_grad})
                                    self.optim.update({'g_optim' : g_optim})
                                '''

                            tf.get_variable_scope().reuse_variables()

            grads = self._average_gradient(tower_grads)
            n_optim = optimizer.apply_gradients(grads)
            self.optim.update({'n_optim' : n_optim})    
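
The per-tower gradients are averaged by self._average_gradient before a single apply_gradients step. That helper is not shown in this listing; a sketch of what it likely looks like, following the standard TF1 multi-tower pattern (an assumption, and it presumes every tower yields a gradient for every variable):

    def _average_gradient(self, tower_grads):
        # tower_grads holds one [(grad, var), ...] list per GPU tower;
        # zip(*tower_grads) regroups the entries variable by variable, so each
        # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...).
        average_grads = []
        for grad_and_vars in zip(*tower_grads):
            grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
            mean_grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
            average_grads.append((mean_grad, grad_and_vars[0][1]))
        return average_grads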
Example #4
File: eval.py Project: xinDW/DVSR
def build_model_and_load_npz(epoch, use_cpu=False, save_pb=False):
    
    epoch = 'best' if epoch == 0 else epoch
    # search for a .npz checkpoint file whose name contains all the given tags
    def _search_for_ckpt_npz(file_dir, tags):
        filelist = os.listdir(file_dir)
        for filename in filelist:
            if '.npz' in filename:
                if all(tag in filename for tag in tags):
                    return filename
        return None

    if (archi1 is not None):
        resolve_ckpt_file = _search_for_ckpt_npz(checkpoint_dir, ['resolve', str(epoch)])
        interp_ckpt_file  = _search_for_ckpt_npz(checkpoint_dir, ['interp', str(epoch)])
       
        (resolve_ckpt_file is not None and interp_ckpt_file is not None) or _raise(Exception('checkpoint file not found'))

    else:
        #checkpoint_dir = "checkpoint/" 
        #ckpt_file = "brain_conv3_epoch1000_rdn.npz"
        ckpt_file = _search_for_ckpt_npz(checkpoint_dir, [str(epoch)])
        
        ckpt_file is not None or _raise(Exception('checkpoint file not found'))
    

    #======================================
    # build the model
    #======================================
    
    if use_cpu:
        device_str = '/cpu:0'
    else:
        device_str = '/gpu:%d' % device_id

    LR = tf.placeholder(tf.float32, [1] + lr_size)
    if (archi1 is not None):
        # if ('resolve_first' in archi):        
        with tf.device(device_str):
            if archi1 == 'dbpn':
                resolver = DBPN(LR, upscale=False, name="net_s1")
            elif archi1 == 'denoise':
                resolver = denoise_net(LR, name="net_s1")
            elif archi1 == 'unet':
                resolver = unet3d(LR, name="net_s1")
            else:
                _raise(ValueError())

            if archi2 == 'rdn':
                interpolator = res_dense_net(resolver.outputs, factor=factor, conv_kernel=conv_kernel, bn=using_batch_norm, is_train=False, name="net_s2")
                net = interpolator
            else:
                _raise(ValueError())

    else : 
        archi = archi2
        with tf.device(device_str):
            if archi == 'rdn':
                net = res_dense_net(LR, factor=factor, bn=using_batch_norm, conv_kernel=conv_kernel, name="net_s2")
            elif archi == 'unet':
                # net = unet3d(LR, upscale=False)
                net = unet_care(LR)
            elif archi == 'dbpn':
                net = DBPN(LR, upscale=True)
            else:
                raise Exception('unknown architecture: %s' % archi)

    net.print_params(details=False)
    
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    if (archi1 is None):
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/' + ckpt_file, network=net)
    else:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/' + resolve_ckpt_file, network=resolver)
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/' + interp_ckpt_file, network=interpolator)

    return sess, net, LR
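
A minimal usage sketch for the function above: restore the 'best' checkpoint and run one inference pass. The random input is only a stand-in for a real low-resolution volume, and lr_size is assumed to be the [depth, height, width, channels] list used for the LR placeholder:

    import numpy as np

    sess, net, LR = build_model_and_load_npz(epoch=0)  # epoch == 0 selects the 'best' checkpoint
    lr_block = np.random.rand(1, *lr_size).astype(np.float32)
    sr_block = sess.run(net.outputs, feed_dict={LR: lr_block})
    print('output shape:', sr_block.shape)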