def main():
    """Configure OneFlow for (multi-)GPU training, restore or initialize a
    checkpoint, then drive the insightface training loop with optional
    periodic validation and snapshotting.

    Relies on module-level ``args`` and the job/helper functions defined in
    this file; returns nothing.
    """
    flow.config.gpu_device_num(args.gpu_num_per_node)

    # Buffered NCCL all-reduce is disabled for fp16 runs that span more
    # than one device.
    if args.use_fp16 and args.num_nodes * args.gpu_num_per_node > 1:
        flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
    if args.nccl_fusion_threshold_mb:
        flow.config.collective_boxing.nccl_fusion_threshold_mb(
            args.nccl_fusion_threshold_mb
        )
    if args.nccl_fusion_max_ops:
        flow.config.collective_boxing.nccl_fusion_max_ops(
            args.nccl_fusion_max_ops
        )

    # Multi-node setup: one {"addr": ip} dict per machine.
    if args.num_nodes > 1:
        assert args.num_nodes <= len(args.node_ips)
        flow.env.ctrl_port(12138)
        flow.env.machine([{"addr": ip} for ip in args.node_ips])

    flow.env.log_dir(args.log_dir)

    ckpt = flow.train.CheckPoint()
    if args.model_load_dir:
        print("Loading model from {}".format(args.model_load_dir))
        ckpt.load(args.model_load_dir)
    else:
        print("Init model on demand.")
        ckpt.init()

    metric = TrainMetric(
        desc="train", calculate_batches=1, batch_size=args.train_batch_size
    )

    for step in range(args.total_batch_num):
        # train
        insightface_train_job().async_get(metric.metric_cb(step))

        # validation — NOTE(review): the "validataion" spelling mirrors the
        # argument parser's attribute names and must stay as-is.
        should_validate = (
            args.do_validataion_while_train
            and (step + 1) % args.validataion_interval == 0
        )
        if should_validate:
            for ds in ["lfw", "cfp_fp", "agedb_30"]:
                issame_list, embeddings_list = do_validation(dataset=ds)
                validation_util.cal_validation_metrics(
                    embeddings_list,
                    issame_list,
                    nrof_folds=args.nrof_folds,
                )

        # snapshot
        if (step + 1) % args.num_of_batches_in_snapshot == 0:
            ckpt.save(
                args.model_save_dir
                + "/snapshot_"
                + str(step // args.num_of_batches_in_snapshot)
            )
def main():
    """Stand-alone validation entry point.

    Parses validation args, configures OneFlow, loads a checkpoint into a
    ``Validator``, and reports metrics for every configured target dataset.
    """
    args = get_val_args()

    # Environment / device setup.
    flow.env.log_dir(args.log_dir)
    flow.config.gpu_device_num(args.device_num_per_node)

    # validation
    print("args: ", args)
    validator = Validator(args)
    validator.load_checkpoint()

    for ds in config.val_targets:
        same_flags, embeddings = validator.do_validation(dataset=ds)
        validation_util.cal_validation_metrics(
            embeddings,
            same_flags,
            nrof_folds=args.nrof_folds,
        )
def main():
    """Load a saved model via the legacy CheckPoint API and evaluate it on
    the three standard face-verification benchmarks.

    Uses module-level ``args`` and the ``do_validation`` helper defined in
    this file; returns nothing.
    """
    # Environment / device setup.
    flow.env.log_dir(args.log_dir)
    flow.config.gpu_device_num(args.gpu_num_per_node)

    print("Loading model from {}".format(args.model_load_dir))
    ckpt = flow.train.CheckPoint()
    ckpt.load(args.model_load_dir)

    # validation
    for ds in ["lfw", "cfp_fp", "agedb_30"]:
        same_flags, embeddings = do_validation(dataset=ds)
        validation_util.cal_validation_metrics(
            embeddings,
            same_flags,
            nrof_folds=args.nrof_folds,
        )
def main(args):
    """Full training entry point: prepares the snapshot directory, sets up
    the (possibly multi-node) OneFlow environment, optionally restores a
    checkpoint, then runs the training loop with periodic validation and
    snapshotting.
    """
    flow.config.gpu_device_num(args.device_num_per_node)
    print("gpu num: ", args.device_num_per_node)
    if not os.path.exists(args.models_root):
        os.makedirs(args.models_root)

    def IsFileOrNonEmptyDir(path):
        # True when `path` is an existing file, or a directory that already
        # contains entries — i.e. anything we must not overwrite.
        if os.path.isfile(path):
            return True
        if os.path.isdir(path) and len(os.listdir(path)) != 0:
            return True
        return False

    # Refuse to train into a location that already holds data.
    assert not IsFileOrNonEmptyDir(
        args.models_root), "Non-empty directory {} already exists!".format(
        args.models_root)

    # Snapshots go under <models_root>/<network>-<loss>-<dataset>/.
    prefix = os.path.join(args.models_root,
                          "%s-%s-%s" % (args.network, args.loss, args.dataset),
                          "model")
    prefix_dir = os.path.dirname(prefix)
    print("prefix: ", prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)

    # Mirror the CLI cluster settings into the module-level defaults so the
    # rest of the project sees them.
    default.num_nodes = args.num_nodes
    default.node_ips = args.node_ips
    if args.num_nodes > 1:
        assert args.num_nodes <= len(
            args.node_ips
        ), "The number of nodes should not be greater than length of node_ips list."
        flow.env.ctrl_port(12138)
        # One {"addr": ip} dict per machine in the cluster.
        nodes = []
        for ip in args.node_ips:
            addr_dict = {}
            addr_dict["addr"] = ip
            nodes.append(addr_dict)
        flow.env.machine(nodes)

    if config.data_format.upper() != "NCHW" and config.data_format.upper(
    ) != "NHWC":
        raise ValueError("Invalid data format")
    flow.env.log_dir(args.log_dir)

    train_func = make_train_func(args)
    if args.do_validation_while_train:
        validator = Validator(args)

    if os.path.exists(args.model_load_dir):
        # Guard against resuming from — and then overwriting — the very
        # directory this run will save new snapshots into.
        assert os.path.abspath(
            os.path.dirname(os.path.split(
                args.model_load_dir)[0])) != os.path.abspath(
                    os.path.join(
                        args.models_root,
                        args.network + "-" + args.loss + "-" + args.dataset)
                ), "You should specify a new path to save new models."
        print("Loading model from {}".format(args.model_load_dir))
        variables = flow.checkpoint.get(args.model_load_dir)
        flow.load_variables(variables)

    print("num_classes ", config.num_classes)
    print("Called with argument: ", args, config)

    train_metric = TrainMetric(desc="train",
                               calculate_batches=args.loss_print_frequency,
                               batch_size=args.train_batch_size)
    lr = args.lr
    for step in range(args.total_iter_num):
        # train
        train_func().async_get(train_metric.metric_cb(step))

        # validation
        if default.do_validation_while_train and (
                step + 1) % args.validation_interval == 0:
            for ds in config.val_targets:
                issame_list, embeddings_list = validator.do_validation(
                    dataset=ds)
                validation_util.cal_validation_metrics(
                    embeddings_list,
                    issame_list,
                    nrof_folds=args.nrof_folds,
                )

        # NOTE(review): this local `lr` only feeds the log prints below —
        # presumably the actual LR schedule is applied inside the training
        # job itself; confirm against make_train_func.
        if step in args.lr_steps:
            lr *= 0.1
            print("lr_steps: ", step)
            print("lr change to ", lr)

        # snapshot
        if (step + 1) % args.iter_num_in_snapshot == 0:
            path = os.path.join(
                prefix_dir,
                "snapshot_" + str(step // args.iter_num_in_snapshot))
            flow.checkpoint.save(path)

    if args.save_last_snapshot is True:
        flow.checkpoint.save(os.path.join(prefix_dir, "snapshot_last"))