Example #1
File: test.py Project: vghost2008/wml
 def test_dropblock(self):
     with self.test_session() as sess:
         data = tf.ones(shape=[1,32,32,1],dtype=tf.float32)
         data = wnnl.dropblock(data, keep_prob=0.8, block_size=4, is_training=True)
         data = tf.reshape(data,shape=[32,32])
         data = tf.cast(data,tf.int32)
         wmlu.show_list(sess.run(data).tolist())
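
DropBlock zeroes out contiguous block_size x block_size regions instead of independent units, so neighbouring activations cannot trivially compensate for the dropped ones; keep_prob controls the expected fraction of surviving units. A rough NumPy sketch of the masking idea (an illustration only; wnnl.dropblock's exact implementation is not shown here):

import numpy as np

def dropblock_mask(h, w, keep_prob=0.8, block_size=4, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    # seed rate from the DropBlock paper, adjusted for the number of valid block centers
    valid = max((h - block_size + 1) * (w - block_size + 1), 1)
    gamma = (1.0 - keep_prob) / (block_size ** 2) * (h * w) / valid
    mask = np.ones((h, w), dtype=np.float32)
    half = block_size // 2
    for y, x in zip(*np.nonzero(rng.random((h, w)) < gamma)):
        mask[max(y - half, 0):y + half, max(x - half, 0):x + half] = 0.0
    return mask

print(dropblock_mask(32, 32).mean())  # roughly keep_prob of the units survive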
Example #2
def get_data(data_dir,
             batch_size=4,
             num_samples=1,
             num_classes=3,
             id_to_label=None):
    dataset = get_database(dataset_dir=data_dir,
                           num_classes=num_classes,
                           num_samples=num_samples)
    files = wmlu.recurse_get_filepath_in_dir(data_dir, suffix=".tfrecord")
    print("tfrecords:")
    wmlu.show_list(files)
    with tf.name_scope('data_provider'):
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=2,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size,
            seed=int(time.time()),
            shuffle=True)
        [image, glabels,
         bboxes] = provider.get(["image", "object/label", "object/bbox"])
    if id_to_label is None:
        id_to_label = {}
    label = 1
    if len(id_to_label) == 0:
        for i in range(1, 80):
            id_to_label[i] = label
            label += 1
    table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.KeyValueTensorInitializer(
            np.array(list(id_to_label.keys()), dtype=np.int64),
            np.array(list(id_to_label.values()), dtype=np.int64)), -1)
    glabels = table.lookup(glabels)
    wmlt.add_to_hash_table_collection(table.init)

    return image, glabels, bboxes
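
The HashTable above only remaps the original class ids to consecutive training labels and returns -1 for any id that is not a key. A plain-Python sketch of the same lookup semantics (a hypothetical helper, shown only to make the default_value=-1 behaviour explicit):

id_to_label = {i: i for i in range(1, 80)}   # the identity mapping built by the loop above
remap = lambda ids: [id_to_label.get(int(x), -1) for x in ids]
print(remap([1, 42, 99]))                    # -> [1, 42, -1]; 99 is unmapped, so it falls back to -1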
Example #3

    def multi_thread_to_tfrecords_by_files(self,files, output_dir,shuffling=False,fidx=0):
        wmlu.create_empty_dir(output_dir,remove_if_exists=True,yes_to_all=True)
        if shuffling:
            random.seed(time.time())
            random.shuffle(files)
        wmlu.show_list(files[:100])
        if len(files)>100:
            print("...")
        print(f"Total {len(files)} files.")
        sys.stdout.flush()
        files = wmlu.list_to_2dlist(files,SAMPLES_PER_FILES)
        files_data = list(enumerate(files))
        if fidx != 0:
            _files_data = []
            for fid,file_d in files_data:
                _files_data.append([fid+fidx,file_d])
            files_data = _files_data
        sys.stdout.flush()
        pool = Pool(13)
        pool.map(functools.partial(self.make_tfrecord,output_dir=output_dir),files_data)
        #list(map(functools.partial(self.make_tfrecord,output_dir=output_dir),files_data))
        pool.close()
        pool.join()

        print('\nFinished converting the dataset, total %d examples.' % len(files))
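
wmlu.list_to_2dlist presumably chunks the flat file list into sublists of at most SAMPLES_PER_FILES entries, so each Pool worker converts one shard. A minimal sketch of that chunking behaviour (an assumption about the helper, not the library's actual code):

def list_to_2dlist(items, n):
    # split a flat list into consecutive chunks of at most n items
    return [items[i:i + n] for i in range(0, len(items), n)]

print(list_to_2dlist(list(range(7)), 3))  # -> [[0, 1, 2], [3, 4, 5], [6]]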
Example #4
    def build_inference_net(self):
        try:
            if not os.path.exists(self.log_dir):
                wmlu.create_empty_dir(self.log_dir)
            if not os.path.exists(self.ckpt_dir):
                wmlu.create_empty_dir(self.ckpt_dir)
        except:
            pass
        '''
        At inference time, self.data is just a tensor
        '''
        data = {IMAGE: self.data}
        DataLoader.detection_image_summary(data, name="data_source")
        self.input_data = data
        '''if self.cfg.GLOBAL.DEBUG:
            data[IMAGE] = tf.Print(data[IMAGE],[tf.shape(data[IMAGE]),data[ORG_HEIGHT],data[ORG_WIDTH],data[HEIGHT],data[WIDTH]],summarize=100,
                                   name="XXXXX")'''
        self.res_data, loss_dict = self.model.forward(data)
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.top_variable_name_scope = "Model"

        print("batch_norm_ops.")
        wmlu.show_list(
            [x.name for x in tf.get_collection(tf.GraphKeys.UPDATE_OPS)])
Example #5
File: dist.py Project: vghost2008/wml
def all_reduce_norm(module):
    """
    All reduce norm statistics in different devices.
    """
    states = get_async_norm_states(module)
    print("Reduce keys:")
    wmlu.show_list(list(states.keys()))
    states = all_reduce(states, op="mean")
    module.load_state_dict(states, strict=False)
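
get_async_norm_states and all_reduce are wml helpers, so the snippet below is only an illustration of what an op="mean" all-reduce over a state dict could look like with torch.distributed (it assumes a process group is already initialized):

import torch.distributed as dist

def all_reduce_mean(states):
    world_size = dist.get_world_size()
    for name, t in states.items():
        dist.all_reduce(t, op=dist.ReduceOp.SUM)  # sum the tensor across ranks in place
        t.div_(world_size)                        # divide by the world size to get the mean
    return states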
Example #6
File: test.py Project: vghost2008/wml
 def test_orthogonal_regularizerv1(self):
     with self.test_session() as sess:
         print("test_orthogonal_regularizerv1")
         fn = wnnl.orthogonal_regularizer(1)
         weight = np.array(list(range(3)),np.float32)
         wmlu.show_list(weight)
         v = fn(weight)
         v = sess.run(v)
         self.assertAllClose(v,9,atol=1e-4)
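
The asserted value 9 is consistent with a regularizer of the form 0.5 * ||W^T W - I||_F^2 with the weight treated as a 1x3 matrix; the NumPy check below only verifies that guess (the actual formula inside wnnl.orthogonal_regularizer is not shown here):

import numpy as np

W = np.array([[0., 1., 2.]])                  # the weight as a 1x3 matrix
gram = W.T @ W                                # 3x3 Gram matrix
print(0.5 * np.sum((gram - np.eye(3)) ** 2))  # -> 9.0, matching the assertAllClose above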
Example #7

def multithread_create_tf_record(data_dir, output_dir, img_suffix="jpg", name="train", shuffling=True, fidx=0,
                                 label_text_to_id=None):
    files = get_files(data_dir, img_suffix=img_suffix)
    if os.path.exists(output_dir) and (data_dir != output_dir):
        shutil.rmtree(output_dir)
        print("删除文件夹%s" % (output_dir))
    wmlu.show_list(files)
    print(f"Find {len(files)} files.")
    return multithread_create_tf_record_by_files(files, output_dir,
                                                 name, shuffling, fidx,
                                                 label_text_to_id)
Example #8
def get_database(dataset_dir, num_parallel=1, file_pattern='*.record'):

    file_pattern = os.path.join(dataset_dir, file_pattern)
    files = glob.glob(file_pattern)
    if len(files) == 0:
        logging.error(f'No files found in {file_pattern}')
    else:
        print(f"Total {len(files)} files.")
        wmlu.show_list(files)
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=num_parallel)
    dataset = dataset.map(__parse_func, num_parallel_calls=num_parallel)
    return dataset
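
A minimal TF1-style usage sketch for get_database; the record path is a hypothetical placeholder and __parse_func is assumed to return the per-example feature tensors:

dataset = get_database("/path/to/records", num_parallel=4)
dataset = dataset.batch(8).prefetch(2)
features = dataset.make_one_shot_iterator().get_next()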
Example #9
    def forward(self,inputs):
        if self.tmp_path.startswith("/tmp"):
            wmlu.create_empty_dir(self.tmp_path,remove_if_exists=True,yes_to_all=True)
        print("inputs shape:",inputs.shape)
        raw_path = os.path.join(self.tmp_path,"input.raw")
        input_list = os.path.join(self.tmp_path,"file_list.txt")
        output_dir = os.path.join(self.tmp_path,"output")
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        with open(input_list, "w") as f:
            if self.output_layers is not None:
                v = f"#{self.output_layers[0]}"
                for x in self.output_layers[1:]:
                    v += f" {x}"
                v += "\n"
                f.write(v)
            f.write(raw_path)

        self.to_snpe_raw(inputs,raw_path)

        cmd = "{}  --container {} --input_list {} --output_dir {}".format(self.bin, self.model_path,
                                                                          input_list,
                                                                          output_dir)
        print(f"CMD:{cmd}")
        print(f"All output files.")
        os.system(cmd)
        all_files = wmlu.recurse_get_filepath_in_dir(output_dir,suffix=".raw")
        wmlu.show_list(all_files)
        print("-------------------------------")
        res_data = []
        output_dir = "/home/wj/0day/output" #for DEBUG
        if self.output_names is not None:
            for name,shape in zip(self.output_names,self.output_shapes):
                path = os.path.join(output_dir,"Result_0",name+self.output_suffix)
                if not os.path.exists(path):
                    print(f"{path} not exits")
                    res_data.append(None)
                else:
                    print(f"Read from {path}")
                    td = np.fromfile(path,dtype=np.float32)
                    if shape is not None:
                        td = np.reshape(td,shape)
                    res_data.append(td)
        else:
            path = all_files[0]
            print(f"Use ")
            td = np.fromfile(path, dtype=np.float32)
            res_data.append(td)

        return res_data
Example #10
def get_database(dataset_dir, num_parallel=1, file_pattern='*_train.record'):

    file_pattern = os.path.join(dataset_dir, file_pattern)
    files = glob.glob(file_pattern)
    if len(files) == 0:
        logging.error(f'No files found in {file_pattern}')
        a = input(f"No files found in {file_pattern}, continue?(y/n)")
        if a != 'y':
            exit(-1)
    else:
        wmlu.show_list(files)
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=num_parallel)
    dataset = dataset.map(__parse_func, num_parallel_calls=num_parallel)
    return dataset
Example #11
File: wnn.py Project: vghost2008/wml
def get_regularization_losses(scopes=None,re_pattern=None,reduction="sum"):
    with tf.name_scope("regularization_losses"):
        col = get_variables_of_collection(tf.GraphKeys.REGULARIZATION_LOSSES,scopes=scopes,re_pattern=re_pattern)
        if len(col)>0:
            print("wr_loss")
            wmlu.show_list(col)
            print("end wr_loss")
            if reduction == "mean":
                return tf.reduce_mean(col)
            elif reduction == "sum":
                return tf.reduce_sum(col)
            else:
                return col
        else:
            return None
Example #12
File: test.py Project: vghost2008/wml
 def test_channel_upsampling(self):
     with self.test_session() as sess:
         print("Test channel upsample")
         data_in = list(range(16))
         data_in = np.array(data_in)
         data_in = np.reshape(data_in,[1,2,2,4])
         expected_data = np.array([[0,1,4,5],[2,3,6,7],[8,9,12,13],[10,11,14,15]])
         wmlu.show_list(data_in)
         data_in = tf.constant(data_in,dtype=tf.float32)
         data_in = wmlt.channel_upsample(data_in,scale=2)
         data_in = tf.squeeze(data_in,axis=-1)
         data_in = tf.squeeze(data_in,axis=0)
         data_out = sess.run(data_in)
         wmlu.show_list(data_out)
         self.assertAllClose(expected_data,data_out,atol=1e-4)
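
The expected layout matches a depth-to-space rearrangement with block size 2: the four channels of each input pixel become a 2x2 spatial block. A NumPy sketch of that rearrangement, assuming channel_upsample behaves this way:

import numpy as np

x = np.arange(16).reshape(1, 2, 2, 4)  # [N, H, W, C]
y = x.reshape(1, 2, 2, 2, 2, 1)        # split C into a 2x2 block (plus one output channel)
y = y.transpose(0, 1, 3, 2, 4, 5)      # interleave the block rows/cols with H/W
print(y.reshape(4, 4))                 # -> the expected_data above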
Example #13
def sample_in_dir(dir_path, nr, split_nr=None):
    '''
    Sample data from dir_path's sub-directories.
    Sample nr images from each sub-directory; if split_nr is not None, the sampled
    images are split into split_nr parts and saved in different dirs.
    '''
    res = {}
    dirs = wmlu.get_subdir_in_dir(dir_path, absolute_path=True)
    print(f"Find dirs in {dir_path}")
    wmlu.show_list(dirs)

    for dir in dirs:
        data = sample_in_one_dir(dir, nr)
        if split_nr is None:
            append_to_dict(res, 0, data)
        else:
            data = wmlu.list_to_2dlistv2(data, split_nr)
            for i, d in enumerate(data):
                append_to_dict(res, i, d)

    return res
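
wmlu.list_to_2dlistv2 presumably splits the sampled list into split_nr roughly equal parts (unlike list_to_2dlist in Example #3, which chunks by a fixed size). A hypothetical sketch of that behaviour:

def list_to_2dlistv2(items, k):
    n = (len(items) + k - 1) // k            # ceil(len/k) items per part
    return [items[i:i + n] for i in range(0, len(items), n)]

print(list_to_2dlistv2(list(range(10)), 3))  # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]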
Example #14
    def __call__(self, dir_path):
        dir_path = os.path.abspath(dir_path)
        while True:
            files = wmlu.recurse_get_filepath_in_dir(dir_path,
                                                     suffix=".tar.gz")
            wmlu.show_list(files)
            process_nr = 0
            for file in files:
                if file in self.history:
                    continue
                ckpt_file = self.extract_data(file)
                if ckpt_file is None:
                    continue
                print("process file {}.".format(file))
                self.history.append(file)
                if self.evaler is not None:
                    result, info = self.evaler(ckpt_file)
                else:
                    result, info = self.eval(ckpt_file)
                if result < 0.01:
                    print("Unnormal result {}, ignored.".format(result))
                    continue
                if result < self.best_result:
                    print(
                        "{} not the best result, best result is {}/{}, achieved at {}, skip backup."
                        .format(file, self.best_result, self.best_file,
                                self.best_result_time))
                    continue
                print("RESULT:", self.best_result, result)
                print("New best result {}, {}.".format(file, info))
                self.best_file = file
                self.best_result = result
                self.best_result_time = time.strftime("%m-%d %H:%M:%S",
                                                      time.localtime())
                self.best_result_t = time.time()

            if process_nr == 0:
                print("sleep for 30 seconds.")
                sys.stdout.flush()
                time.sleep(30)
            continue

Example #15

        new_mask = odm.dense_mask_to_sparse_mask(binary_mask,category_ids,default_label=255)
        base_name = wmlu.base_name(full_path)+".png"
        save_path = os.path.join(save_dir,base_name)
        new_mask = new_mask.astype(np.uint8)
        if os.path.exists(save_path):
            print(f"WARNING: File {save_path} exists.")
        cv2.imwrite(save_path,new_mask)
        sys.stdout.write(f"\r{i}")


if __name__ == "__main__":
    data_dir ="/home/wj/ai/mldata/mapillary_vistas/"
    save_dir = os.path.join(data_dir,'boe_labels_validation')
    name_to_id_dict = update_name_to_id(name_to_id_dict,data_dir)
    idxs = list(range(0,18049,50))
    r_idxs = []
    for i in range(len(idxs)-1):
        r_idxs.append([idxs[i],idxs[i+1]])
    wmlu.show_list(r_idxs)
    pool = Pool(10)
    def fun(d):
        trans_data(data_dir,save_dir,d[0],d[1])
    res = list(pool.map(fun,r_idxs))
    pool.close()
    pool.join()
    print(res)
    #list(map(fun,r_idxs))

Example #16
        ddir = args.out_dir
        tdir = wmlu.base_name(sd, process_suffix=False)
        move = False
        if args.reverse:
            if tdir not in base_names:
                move = True
        else:
            if tdir in base_names:
                move = True

        if move:
            dirs_need_to_move.append(sd)
            move_data.append((rdir, ddir))
        else:
            dirs_dont_move.append(sd)

    print(f"dirs don't need move:")
    wmlu.show_list(dirs_dont_move)
    print(f"Total don't move dir {len(dirs_dont_move)}")
    print(f"dirs need move:")
    wmlu.show_list(dirs_need_to_move)
    print(f"Total dirs need to move {len(dirs_need_to_move)}")

    ans = input("Move dirs?[y/n]")
    if ans == 'y':
        wmlu.create_empty_dir(args.out_dir, remove_if_exists=False)
        for rdir, ddir in move_data:
            cmd = f"mv  \"{rdir}\" \"{ddir}\""
            print(cmd)
            os.system(cmd)
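
shutil.move would avoid shelling out to mv and needs no manual quoting; a sketch of the confirmation loop under that substitution (same intent, not the original code):

import shutil

for rdir, ddir in move_data:
    print(f"move {rdir} -> {ddir}")
    shutil.move(rdir, ddir)  # moves rdir into ddir when ddir is an existing directory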
Example #17
File: wnn.py Project: vghost2008/wml
def restore_variables(sess,path,exclude=None,only_scope=None,silent=False,restore_evckp=True,value_key=None,exclude_var=None,extend_vars=None,global_scopes=None,verbose=False):
    #if restore_evckp and os.path.isdir(path):
    #    evt.WEvalModel.restore_ckp(FLAGS.check_point_dir)
    bn_ops = tf.get_collection(key=tf.GraphKeys.UPDATE_OPS)
    bn_ops = [x.name for x in bn_ops]
    print("batch_norm_ops.")
    wmlu.show_list(bn_ops)
    if exclude is None and exclude_var is not None:
        exclude = exclude_var
    file_path = wmlt.get_ckpt_file_path(path)
    if file_path is None:
        return False
    print(f"resotre from {file_path}.")
    for v in tf.global_variables():
        if "moving_mean" in v.name or "moving_variance" in v.name:
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES,v)
    variables = []
    variables_restored_in_ckpt = []
    variables0 = restore_variables_by_key(sess,file_path,exclude,only_scope,key=tf.GraphKeys.TRAINABLE_VARIABLES,name="train",silent=silent,value_key=value_key)

    if len(variables0)>1:
        if isinstance(variables0[1],list):
            for v in variables0[1]:
                variables.append(v.name)
                variables_restored_in_ckpt.append(v.name)
        else:
            for k,v in variables0[1].items():
                variables.append(v.name)
                variables_restored_in_ckpt.append(k)
    variables1 = restore_variables_by_key(sess,file_path,exclude,only_scope,key=tf.GraphKeys.MOVING_AVERAGE_VARIABLES,name='moving',silent=silent,value_key=value_key)
    if len(variables1)>1:
        if isinstance(variables1[1],list):
            for v in variables1[1]:
                variables.append(v.name)
                variables_restored_in_ckpt.append(v.name)
        else:
            for k,v in variables1[1].items():
                variables.append(v.name)
                variables_restored_in_ckpt.append(k)

    if global_scopes is not None:
        variables2 = restore_variables_by_key(sess,file_path,None,global_scopes,key=tf.GraphKeys.GLOBAL_VARIABLES,name='global_variables',silent=silent,value_key=value_key)
        if len(variables2)>1:
            if isinstance(variables2[1],list):
                for v in variables2[1]:
                    variables.append(v.name)
                    variables_restored_in_ckpt.append(v.name)
            else:
                for k,v in variables2[1].items():
                    variables.append(v.name)
                    variables_restored_in_ckpt.append(k)

    if extend_vars is not None:
        variables2 = restore_variables_by_var_list(sess,file_path,extend_vars)
        if len(variables2)>1:
            if isinstance(variables2[1],list):
                for v in variables2[1]:
                    variables.append(v.name)
                    variables_restored_in_ckpt.append(v.name)
            else:
                for k,v in variables2[1].items():
                    variables.append(v.name)
                    variables_restored_in_ckpt.append(k)

    for i,v in enumerate(variables):
        index = v.find(':')
        if index>0:
            variables[i] = variables[i][:index]
    for i,v in enumerate(variables_restored_in_ckpt):
        index = v.find(':')
        if index>0:
            variables_restored_in_ckpt[i] = variables_restored_in_ckpt[i][:index]

    unrestored_variables = wmlt.get_variables_unrestored(variables_restored_in_ckpt,file_path,
                                                         exclude_var="Adam")
    unrestored_variables0 = wmlt.get_variables_unrestoredv1(variables,exclude_var="Adam")
    if not verbose:
        def v_filter(x:str):
            #return (not x.endswith("ExponentialMovingAverage")) and (not x.endswith("/u"))
            return (not x.endswith("ExponentialMovingAverage")) and (not x.endswith("Momentum"))
        unrestored_variables = filter(v_filter,unrestored_variables)
        unrestored_variables0 = filter(v_filter,unrestored_variables0)
    show_values(unrestored_variables, "Unrestored variables in ckpt")
    show_values(unrestored_variables0, "Unrestored variables of global variables")
    return True
Example #18
    exclude_dir = args.exc_dir
    suffix = args.ext
    
    files0 = wmlu.recurse_get_filepath_in_dir(src_dir,suffix=suffix)
    files1 = wmlu.recurse_get_filepath_in_dir(exclude_dir,suffix=suffix)
    files1 = [os.path.basename(file) for file in files1]
    total_skip = 0
    total_remove = 0
    files_to_remove = []
    for file in files0:
        base_name = os.path.basename(file)
        if base_name not in files1:
            print(f"Skip {base_name}")
            total_skip += 1
        else:
            print(f"Remove {file}")
            total_remove += 1
            files_to_remove.append(file)
    
    print(f"Files need to remove:")
    wmlu.show_list(files_to_remove)
    res = input("remove [y/n]")
    if res != 'y':
        print(f"Cancel.")
        exit(0)

    for file in files_to_remove:
        remove_file(file)
    
    print(f"Total files in src dir {len(files0)}, total files in exclude dir {len(files1)}.")
    print(f"Total skip {total_skip}, total remove {total_remove}, total process {total_skip+total_remove}")
Example #19
    def build_net(self):
        if not os.path.exists(self.log_dir):
            wmlu.create_empty_dir(self.log_dir)
        if not os.path.exists(self.ckpt_dir):
            wmlu.create_empty_dir(self.ckpt_dir)
        with tf.device(":/cpu:0"):
            data = self.data.get_next()
        DataLoader.detection_image_summary(data, name="data_source")
        self.input_data = data
        '''if self.cfg.GLOBAL.DEBUG:
            data[IMAGE] = tf.Print(data[IMAGE],[tf.shape(data[IMAGE]),data[ORG_HEIGHT],data[ORG_WIDTH],data[HEIGHT],data[WIDTH]],summarize=100,
                                   name="XXXXX")'''
        self.res_data, loss_dict = self.model.forward(data)
        if self.model.is_training:
            for k, v in loss_dict.items():
                tf.summary.scalar(f"loss/{k}", v)
                v = tf.cond(tf.logical_or(tf.is_nan(v), tf.is_inf(v)),
                            lambda: tf.zeros_like(v), lambda: v)
                tf.losses.add_loss(v)
        elif self.cfg.GLOBAL.SUMMARY_LEVEL <= SummaryLevel.RESEARCH:
            research = self.cfg.GLOBAL.RESEARCH
            if 'result_classes' in research:
                print("replace labels with gtlabels.")
                labels = odt.replace_with_gtlabels(
                    bboxes=self.res_data[RD_BOXES],
                    labels=self.res_data[RD_LABELS],
                    length=self.res_data[RD_LENGTH],
                    gtbboxes=data[GT_BOXES],
                    gtlabels=data[GT_LABELS],
                    gtlength=data[GT_LENGTH])
                self.res_data[RD_LABELS] = labels

            if 'result_bboxes' in research:
                print("replace bboxes with gtbboxes.")
                bboxes = odt.replace_with_gtbboxes(
                    bboxes=self.res_data[RD_BOXES],
                    labels=self.res_data[RD_LABELS],
                    length=self.res_data[RD_LENGTH],
                    gtbboxes=data[GT_BOXES],
                    gtlabels=data[GT_LABELS],
                    gtlength=data[GT_LENGTH])
                self.res_data[RD_BOXES] = bboxes

        self.loss_dict = loss_dict

        if not self.model.is_training and self.cfg.GLOBAL.GPU_MEM_FRACTION > 0.1:
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self.cfg.GLOBAL.GPU_MEM_FRACTION)
            config = tf.ConfigProto(allow_soft_placement=True,
                                    gpu_options=gpu_options)
        else:
            config = tf.ConfigProto(allow_soft_placement=True)
        if not self.model.is_training and self.cfg.GLOBAL.GPU_MEM_FRACTION <= 0.1:
            config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self.top_variable_name_scope = "Model"

        if self.model.is_training:
            steps = self.cfg.SOLVER.STEPS
            print("Train steps:", steps)
            lr = wnn.build_learning_rate(
                self.cfg.SOLVER.BASE_LR,
                global_step=self.global_step,
                lr_decay_type=self.cfg.SOLVER.LR_DECAY_TYPE,
                steps=steps,
                decay_factor=self.cfg.SOLVER.LR_DECAY_FACTOR,
                total_steps=steps[-1],
                warmup_steps=self.cfg.SOLVER.WARMUP_ITERS)
            tf.summary.scalar("lr", lr)
            opt = wnn.str2optimizer("Momentum", lr, momentum=0.9)
            self.max_train_step = steps[-1]
            self.train_op, self.total_loss, self.variables_to_train = wnn.nget_train_op(
                self.global_step,
                optimizer=opt,
                clip_norm=self.cfg.SOLVER.CLIP_NORM)
            print("variables to train:")
            wmlu.show_list(self.variables_to_train)
            for v in self.variables_to_train:
                wsummary.histogram_or_scalar(v, v.name[:-2])
            wnn.log_moving_variable()

            self.saver = tf.train.Saver(max_to_keep=100)
            tf.summary.scalar(self.cfg.GLOBAL.PROJ_NAME + "_total_loss",
                              self.total_loss)

        self.summary = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
        init = tf.global_variables_initializer()
        self.sess.run(init)
        print("batch_norm_ops.")
        wmlu.show_list(
            [x.name for x in tf.get_collection(tf.GraphKeys.UPDATE_OPS)])
Example #20
    def build_net_run_on_multi_gpus_nccl(self):
        if not os.path.exists(self.log_dir):
            wmlu.create_empty_dir(self.log_dir)
        if not os.path.exists(self.ckpt_dir):
            wmlu.create_empty_dir(self.ckpt_dir)
        '''if self.cfg.GLOBAL.DEBUG:
            data[IMAGE] = tf.Print(data[IMAGE],[tf.shape(data[IMAGE]),data[ORG_HEIGHT],data[ORG_WIDTH],data[HEIGHT],data[WIDTH]],summarize=100,
                                   name="XXXXX")'''
        all_loss_dict = {}
        steps = self.cfg.SOLVER.STEPS
        print("Train steps:", steps)
        lr = wnn.build_learning_rate(
            self.cfg.SOLVER.BASE_LR,
            global_step=self.global_step,
            lr_decay_type=self.cfg.SOLVER.LR_DECAY_TYPE,
            steps=steps,
            decay_factor=self.cfg.SOLVER.LR_DECAY_FACTOR,
            total_steps=steps[-1],
            min_lr=1e-6,
            warmup_steps=self.cfg.SOLVER.WARMUP_ITERS)
        tf.summary.scalar("lr", lr)
        self.max_train_step = steps[-1]

        if self.cfg.SOLVER.OPTIMIZER == "Momentum":
            opt = wnn.str2optimizer(
                "Momentum", lr, momentum=self.cfg.SOLVER.OPTIMIZER_momentum)
        else:
            opt = wnn.str2optimizer(self.cfg.SOLVER.OPTIMIZER, lr)

        tower_grads = []
        if len(self.gpus) == 0:
            self.gpus = [0]
        if len(self.cfg.SOLVER.TRAIN_SCOPES) > 1:
            train_scopes = self.cfg.SOLVER.TRAIN_SCOPES
        else:
            train_scopes = None
        if len(self.cfg.SOLVER.TRAIN_REPATTERN) > 1:
            train_repattern = self.cfg.SOLVER.TRAIN_REPATTERN
        else:
            train_repattern = None

        for i in range(len(self.gpus)):
            scope = tf.get_variable_scope()
            if i > 0:
                #scope._reuse = tf.AUTO_REUSE
                scope.reuse_variables()
            with tf.device(f"/gpu:{i}"):
                with tf.device(":/cpu:0"):
                    data = self.data.get_next()

                self.input_data = data
                with tf.name_scope(f"GPU{self.gpus[i]}"):
                    with tf.device(":/cpu:0"):
                        DataLoader.detection_image_summary(
                            data, name=f"data_source{i}")

                    self.res_data, loss_dict = self.model.forward(data)
                loss_values = []
                for k, v in loss_dict.items():
                    all_loss_dict[k + f"_stage{i}"] = v
                    tf.summary.scalar(f"loss/{k}", v)
                    ##
                    #v = tf.Print(v,[k,tf.is_nan(v), tf.is_inf(v)])
                    ##
                    v = tf.cond(tf.logical_or(tf.is_nan(v), tf.is_inf(v)),
                                lambda: tf.zeros_like(v), lambda: v)
                    loss_values.append(v)

                scope._reuse = tf.AUTO_REUSE
                '''if (i==0) and len(tf.get_collection(GRADIENT_DEBUG_COLLECTION))>0:
                    total_loss_sum = tf.add_n(loss_values)
                    xs = tf.get_collection(GRADIENT_DEBUG_COLLECTION)
                    grads = tf.gradients(total_loss_sum,xs)
                    grads = [tf.reduce_sum(tf.abs(x)) for x in grads]
                    loss_values[0] = tf.Print(loss_values[0],grads+["grads"],summarize=100)'''

                grads, total_loss, variables_to_train = wnn.nget_train_opv3(
                    optimizer=opt,
                    loss=loss_values,
                    scopes=train_scopes,
                    re_pattern=train_repattern)
                #
                if self.cfg.SOLVER.FILTER_NAN_AND_INF_GRADS:
                    grads = [list(x) for x in grads]
                    for j, (g, v) in enumerate(grads):
                        try:
                            if g is not None:
                                g = tf.where(
                                    tf.logical_or(tf.is_nan(g), tf.is_inf(g)),
                                    tf.random_normal(
                                        shape=wmlt.combined_static_and_dynamic_shape(g),
                                        stddev=1e-5), g)
                        except:
                            print(f"Error {g}/{v}")
                            raise Exception("Error")
                        grads[j][0] = g
                #
                tower_grads.append(grads)
        ########################
        '''tower_grads[0] = [list(x) for x in tower_grads[0]]
        for i,(g,v) in enumerate(tower_grads[0]):
            tower_grads[0][i][0] = tf.Print(g,["B_"+v.name,tf.reduce_min(g),tf.reduce_mean(g),tf.reduce_max(g)])'''
        ########################

        if self.cfg.SOLVER.CLIP_NORM > 1:
            avg_grads = wnn.average_grads_nccl(
                tower_grads, clip_norm=self.cfg.SOLVER.CLIP_NORM)
        else:
            avg_grads = wnn.average_grads_nccl(tower_grads, clip_norm=None)
        '''avg_grads = [list(x) for x in avg_grads]
        for i,(g,v) in enumerate(avg_grads):
            avg_grads[i][0] = tf.Print(g,[v.name,tf.reduce_min(g),tf.reduce_mean(g),tf.reduce_max(g)])'''

        opt0 = wnn.apply_gradientsv3(avg_grads, self.global_step, opt)
        opt1 = wnn.get_batch_norm_ops()
        self.train_op = tf.group(opt0, opt1)

        self.total_loss, self.variables_to_train = total_loss, variables_to_train

        self.loss_dict = all_loss_dict

        config = tf.ConfigProto(allow_soft_placement=True)
        #config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        if self.debug_tf:
            self.sess = tfdbg.LocalCLIDebugWrapperSession(self.sess)

        print("variables to train:")
        wmlu.show_list(self.variables_to_train)
        for v in self.variables_to_train:
            wsummary.histogram_or_scalar(v, v.name[:-2])
        wnn.log_moving_variable()

        self.saver = tf.train.Saver(max_to_keep=100)
        tf.summary.scalar("total_loss", self.total_loss)

        self.summary = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(self.log_dir,
                                                    self.sess.graph)
        init = tf.global_variables_initializer()
        self.sess.run(init)
        print("batch_norm_ops.")
        wmlu.show_list(
            [x.name for x in tf.get_collection(tf.GraphKeys.UPDATE_OPS)])
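
The tf.cond guard applied to each loss term above zeroes out a value that becomes NaN or Inf so a single bad batch does not poison the averaged gradients; an equivalent, slightly terser TF1 form using tf.where (assuming v is a scalar loss tensor):

v = tf.where(tf.logical_or(tf.is_nan(v), tf.is_inf(v)), tf.zeros_like(v), v)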