Example #1
def main():

    flow.config.gpu_device_num(args.gpu_num_per_node)

    if args.use_fp16 and (args.num_nodes * args.gpu_num_per_node) > 1:
        flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
    if args.nccl_fusion_threshold_mb:
        flow.config.collective_boxing.nccl_fusion_threshold_mb(args.nccl_fusion_threshold_mb)
    if args.nccl_fusion_max_ops:
        flow.config.collective_boxing.nccl_fusion_max_ops(args.nccl_fusion_max_ops)

    if args.num_nodes > 1:
        assert args.num_nodes <= len(args.node_ips)
        flow.env.ctrl_port(12138)
        nodes = []
        for ip in args.node_ips:
            addr_dict = {}
            addr_dict["addr"] = ip
            nodes.append(addr_dict)

        flow.env.machine(nodes)

    flow.env.log_dir(args.log_dir)
    check_point = flow.train.CheckPoint()
    if not args.model_load_dir:
        print("Init model on demand.")
        check_point.init()
    else:
        print("Loading model from {}".format(args.model_load_dir))
        check_point.load(args.model_load_dir)

    train_metric = TrainMetric(
        desc="train", calculate_batches=1, batch_size=args.train_batch_size
    )

    for step in range(args.total_batch_num):
        # train
        insightface_train_job().async_get(train_metric.metric_cb(step))

        # validation
        if (
            args.do_validataion_while_train
            and (step + 1) % args.validataion_interval == 0
        ):
            for ds in ["lfw", "cfp_fp", "agedb_30"]:
                issame_list, embeddings_list = do_validation(dataset=ds)
                validation_util.cal_validation_metrics(
                    embeddings_list, issame_list, nrof_folds=args.nrof_folds,
                )

        # snapshot
        if (step + 1) % args.num_of_batches_in_snapshot == 0:
            check_point.save(
                args.model_save_dir
                + "/snapshot_"
                + str(step // args.num_of_batches_in_snapshot)
            )
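Example #1 reads its settings from an `args` namespace defined elsewhere in the project. As orientation only, a minimal argparse sketch covering the attributes accessed above might look like the following; the flag names mirror the example (including its spelling), while the defaults are purely illustrative assumptions, not the project's real values.

import argparse

def get_train_args():
    # Hypothetical parser: flag names mirror the attributes Example #1 reads.
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_num_per_node", type=int, default=1)
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--node_ips", nargs="+", default=["192.168.1.1"])
    parser.add_argument("--use_fp16", action="store_true")
    parser.add_argument("--nccl_fusion_threshold_mb", type=int, default=0)
    parser.add_argument("--nccl_fusion_max_ops", type=int, default=0)
    parser.add_argument("--log_dir", type=str, default="./log")
    parser.add_argument("--model_load_dir", type=str, default="")
    parser.add_argument("--model_save_dir", type=str, default="./models")
    parser.add_argument("--train_batch_size", type=int, default=64)
    parser.add_argument("--total_batch_num", type=int, default=100000)
    parser.add_argument("--num_of_batches_in_snapshot", type=int, default=5000)
    # Spelling of the two flags below follows the example above.
    parser.add_argument("--do_validataion_while_train", action="store_true")
    parser.add_argument("--validataion_interval", type=int, default=5000)
    parser.add_argument("--nrof_folds", type=int, default=10)
    return parser.parse_args()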
Example #2
def main():
    args = get_val_args()
    flow.env.log_dir(args.log_dir)
    flow.config.gpu_device_num(args.device_num_per_node)

    # validation
    print("args: ", args)
    validator = Validator(args)
    validator.load_checkpoint()
    for ds in config.val_targets:
        issame_list, embeddings_list = validator.do_validation(dataset=ds)
        validation_util.cal_validation_metrics(
            embeddings_list, issame_list, nrof_folds=args.nrof_folds,
        )
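Example #2 only shows the driver loop; `get_val_args()`, `config`, and the `Validator` class come from the surrounding project and are not part of OneFlow itself. A hypothetical stub matching the interface used above (a `load_checkpoint()` method plus a `do_validation(dataset=...)` call returning an `(issame_list, embeddings_list)` pair) would look roughly like this; the project's real class is responsible for actually producing the embeddings.

class Validator:
    """Hypothetical stub mirroring the interface Example #2 relies on."""

    def __init__(self, args):
        self.args = args

    def load_checkpoint(self):
        # In the real project this restores model weights
        # (e.g. from a model directory) before validation runs.
        raise NotImplementedError

    def do_validation(self, dataset):
        # Expected to return ground-truth same/different labels and the
        # corresponding face embeddings for the named validation dataset.
        raise NotImplementedError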
Example #3
def main():
    flow.env.log_dir(args.log_dir)
    flow.config.gpu_device_num(args.gpu_num_per_node)

    check_point = flow.train.CheckPoint()
    print("Loading model from {}".format(args.model_load_dir))
    check_point.load(args.model_load_dir)

    # validation
    for ds in ["lfw", "cfp_fp", "agedb_30"]:
        issame_list, embeddings_list = do_validation(dataset=ds)
        validation_util.cal_validation_metrics(
            embeddings_list, issame_list, nrof_folds=args.nrof_folds,
        )
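Examples #1 through #3 repeat the same validation loop. A small helper that factors it out could look like the sketch below; `validate_fn` is a hypothetical parameter standing in for `do_validation` or `validator.do_validation`, and the project's `validation_util` module is assumed to be importable as in the examples.

def run_validation(datasets, validate_fn, nrof_folds):
    # Evaluate each verification dataset and report its metrics.
    for ds in datasets:
        issame_list, embeddings_list = validate_fn(dataset=ds)
        validation_util.cal_validation_metrics(
            embeddings_list, issame_list, nrof_folds=nrof_folds,
        )

# Usage, following Example #3:
# run_validation(["lfw", "cfp_fp", "agedb_30"], do_validation, args.nrof_folds)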
Example #4
def main(args):
    flow.config.gpu_device_num(args.device_num_per_node)
    print("gpu num: ", args.device_num_per_node)
    if not os.path.exists(args.models_root):
        os.makedirs(args.models_root)

    def IsFileOrNonEmptyDir(path):
        if os.path.isfile(path):
            return True
        if os.path.isdir(path) and len(os.listdir(path)) != 0:
            return True
        return False

    assert not IsFileOrNonEmptyDir(
        args.models_root), "Non-empty directory {} already exists!".format(
            args.models_root)
    prefix = os.path.join(args.models_root,
                          "%s-%s-%s" % (args.network, args.loss, args.dataset),
                          "model")
    prefix_dir = os.path.dirname(prefix)
    print("prefix: ", prefix)
    if not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)

    default.num_nodes = args.num_nodes
    default.node_ips = args.node_ips
    if args.num_nodes > 1:
        assert args.num_nodes <= len(
            args.node_ips
        ), "The number of nodes should not be greater than length of node_ips list."
        flow.env.ctrl_port(12138)
        nodes = []
        for ip in args.node_ips:
            addr_dict = {}
            addr_dict["addr"] = ip
            nodes.append(addr_dict)

        flow.env.machine(nodes)
    if config.data_format.upper() not in ("NCHW", "NHWC"):
        raise ValueError("Invalid data format")
    flow.env.log_dir(args.log_dir)
    train_func = make_train_func(args)
    if args.do_validation_while_train:
        validator = Validator(args)

    if os.path.exists(args.model_load_dir):
        assert os.path.abspath(
            os.path.dirname(os.path.split(args.model_load_dir)[0])
        ) != os.path.abspath(
            os.path.join(args.models_root,
                         args.network + "-" + args.loss + "-" + args.dataset)
        ), "You should specify a new path to save new models."
        print("Loading model from {}".format(args.model_load_dir))
        variables = flow.checkpoint.get(args.model_load_dir)
        flow.load_variables(variables)

    print("num_classes ", config.num_classes)
    print("Called with argument: ", args, config)
    train_metric = TrainMetric(desc="train",
                               calculate_batches=args.loss_print_frequency,
                               batch_size=args.train_batch_size)
    lr = args.lr

    for step in range(args.total_iter_num):
        # train
        train_func().async_get(train_metric.metric_cb(step))

        # validation
        if args.do_validation_while_train and (
                step + 1) % args.validation_interval == 0:
            for ds in config.val_targets:
                issame_list, embeddings_list = validator.do_validation(
                    dataset=ds)
                validation_util.cal_validation_metrics(
                    embeddings_list,
                    issame_list,
                    nrof_folds=args.nrof_folds,
                )
        if step in args.lr_steps:
            lr *= 0.1
            print("lr_steps: ", step)
            print("lr change to ", lr)

        # snapshot
        if (step + 1) % args.iter_num_in_snapshot == 0:
            path = os.path.join(
                prefix_dir,
                "snapshot_" + str(step // args.iter_num_in_snapshot))
            flow.checkpoint.save(path)

    if args.save_last_snapshot:
        flow.checkpoint.save(os.path.join(prefix_dir, "snapshot_last"))
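Example #4 writes snapshots with `flow.checkpoint.save` and loads weights with `flow.checkpoint.get` plus `flow.load_variables`. A snapshot saved above can therefore be restored the same way before resuming training or running validation; the sketch below assumes the legacy single-client OneFlow API used throughout these examples and an illustrative snapshot path.

import oneflow as flow  # assumed alias, matching the `flow` calls above

# Hypothetical path to a snapshot directory written by Example #4.
snapshot_dir = "/path/to/models/snapshot_last"
variables = flow.checkpoint.get(snapshot_dir)  # read variables from disk
flow.load_variables(variables)                 # push them into the current job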