示例#1
0
def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name,
                mode):
    """Train on one GPU as a single rank of a localhost NCCL process group.

    Joins the process group, optionally starts a wandb run on rank 0, divides
    the configured batch size across ``ngpus_per_node`` ranks, runs
    ``Trainer.train`` and then tears the process group down.

    Args:
        gpu: Rank of this process; also forwarded to ``Trainer`` as its
            device/rank argument.
        port: TCP port of the rendezvous on 127.0.0.1 (``str`` or ``int``).
        ngpus_per_node: World size of the process group.
        config: Raw configuration dict. ``config["train"]["batch_size"]`` is
            overwritten in place with the per-rank batch size.
        buffer_dict: Forwarded unchanged to ``Trainer.train``.
        exp_name: Name of the wandb run created on rank 0.
        mode: Training mode forwarded to ``Trainer``.
    """
    torch.distributed.init_process_group(
        backend="nccl",
        # f-string yields the same URL for a str port and also accepts an int
        # (plain ``+`` concatenation raised TypeError for int ports).
        init_method=f"tcp://127.0.0.1:{port}",
        world_size=ngpus_per_node,
        rank=gpu,
    )

    try:
        # Only rank 0 reports to wandb so the experiment is logged once.
        if gpu == 0 and config["wandb_opt"]:
            wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
            wandb.config.update(config)

        # The configured batch size is global; each rank gets an equal share
        # (any remainder is truncated away by int()).
        batch_size = int(config["train"]["batch_size"] / ngpus_per_node)
        config["train"]["batch_size"] = batch_size
        config = DotDict(config)

        # Start train
        trainer = Trainer(config, gpu, mode)
        trainer.train(buffer_dict)

        if gpu == 0 and config["wandb_opt"]:
            wandb.finish()

        # Wait for all ranks on the success path before tearing down.
        torch.distributed.barrier()
    finally:
        # Release the process group even when training raises; previously an
        # exception skipped destroy_process_group entirely.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
示例#2
0
    def iou_eval(self, dataset, train_step, buffer, model):
        """Evaluate ``model`` on ``dataset`` with the IoU detection evaluator.

        The result directory ``<results_dir>/<dataset>_iou/<train_step>`` is
        handed to ``main_eval``. On gpu 0 with ``wandb_opt`` enabled, the
        returned recall, precision and hmean (logged as F1-score) are rounded
        to three decimals and sent to wandb. Returns ``None``.
        """
        cfg = self.config
        dataset_test_cfg = DotDict(cfg.test[dataset])
        result_dir = os.path.join(
            cfg.results_dir, "/".join((dataset + "_iou", str(train_step)))
        )
        iou_evaluator = DetectionIoUEvaluator()

        scores = main_eval(
            None,
            cfg.train.backbone,
            dataset_test_cfg,
            iou_evaluator,
            result_dir,
            buffer,
            model,
            self.mode,
        )

        # Guard clause: only the primary process with wandb enabled logs.
        if not (self.gpu == 0 and cfg.wandb_opt):
            return

        # (key in the metrics returned by main_eval, label in the wandb key)
        reported = (
            ("recall", "Recall"),
            ("precision", "Precision"),
            ("hmean", "F1-score"),
        )
        wandb.log({
            "{} iou {}".format(dataset, label): np.round(scores[key], 3)
            for key, label in reported
        })
示例#3
0
def main():
    """Single-process entry point for CRAFT training on custom data.

    Parses ``--yaml`` (configuration name) from the command line and loads
    that configuration. Creates ``<results_dir>/<yaml>`` and copies
    ``config/<yaml>.yaml`` into it. Optionally starts a wandb run, then
    trains with ``Trainer`` on rank/device 0.
    """
    parser = argparse.ArgumentParser(description="CRAFT custom data train")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    # Accepted for CLI compatibility with the DDP launcher; not used here.
    parser.add_argument(
        "--port", "--use ddp port", default="2346", type=str, help="Port number"
    )

    args = parser.parse_args()

    # load configure
    exp_name = args.yaml
    config = load_yaml(args.yaml)

    print("-" * 20 + " Options " + "-" * 20)
    print(yaml.dump(config))
    print("-" * 40)

    # Results go to <results_dir>/<yaml>; the config is updated to match.
    res_dir = os.path.join(config["results_dir"], args.yaml)
    config["results_dir"] = res_dir
    # exist_ok replaces the racy "exists? then makedirs" check.
    os.makedirs(res_dir, exist_ok=True)

    # Keep a copy of the yaml next to the results for reproducibility.
    shutil.copy(
        "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
    )

    mode = "weak_supervision" if config["mode"] == "weak_supervision" else None

    # Apply config to wandb
    if config["wandb_opt"]:
        wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
        wandb.config.update(config)

    config = DotDict(config)

    # Start train
    buffer_dict = {"custom_data": None}
    trainer = Trainer(config, 0, mode)
    trainer.train(buffer_dict)

    if config["wandb_opt"]:
        wandb.finish()
示例#4
0
def cal_eval(config, data, res_dir_name, opt, mode, yaml_name=None):
    """Evaluate a trained model on one test dataset.

    Args:
        config: Attribute-accessible configuration (e.g. ``DotDict``). Reads
            ``config.test[data]``, ``config.test.trained_model`` and
            ``config.train.backbone``.
        data: Key of the dataset section under ``config.test``.
        res_dir_name: Name of the sub-directory of ``exp/<yaml_name>``
            handed to ``main_eval`` as its result directory.
        opt: Evaluation type; only ``"iou_eval"`` is supported. Anything
            else just prints ``"Undefined evaluation"``.
        mode: Forwarded to ``main_eval``.
        yaml_name: Experiment (yaml) name used for the ``exp/`` directory.
            When ``None``, falls back to the module-level ``args.yaml``
            created by the script entry point (the original behavior).
    """
    evaluator = DetectionIoUEvaluator()
    test_config = DotDict(config.test[data])

    if yaml_name is None:
        # Backward-compatible fallback. Relying only on the global ``args``
        # raised NameError when this function was imported rather than run
        # from ``__main__``; callers can now pass ``yaml_name`` explicitly.
        yaml_name = args.yaml
    res_dir = os.path.join("exp", yaml_name, str(res_dir_name))

    if opt == "iou_eval":
        main_eval(
            config.test.trained_model,
            config.train.backbone,
            test_config,
            evaluator,
            res_dir,
            buffer=None,
            model=None,
            mode=mode,
        )
    else:
        print("Undefined evaluation")
示例#5
0

if __name__ == "__main__":
    # Command-line entry point: evaluate the configuration named by --yaml
    # on the "custom_data" test set.
    parser = argparse.ArgumentParser(description="CRAFT Text Detection Eval")
    parser.add_argument(
        "--yaml",
        "--yaml_file_name",
        default="custom_data_train",
        type=str,
        help="Load configuration",
    )
    # NOTE: ``args`` must remain a module-level name: ``cal_eval`` reads
    # ``args.yaml`` to build its result directory.
    args = parser.parse_args()

    # load configure
    config = DotDict(load_yaml(args.yaml))

    if config["wandb_opt"]:
        wandb.init(project="evaluation", entity="gmuffiness", name=args.yaml)
        wandb.config.update(config)

    val_result_dir_name = args.yaml
    cal_eval(
        config,
        "custom_data",
        val_result_dir_name + "-ic15-iou",
        opt="iou_eval",
        mode=None,
    )

    # Close the wandb run explicitly, mirroring the training entry point;
    # previously the run opened above was never finished.
    if config["wandb_opt"]:
        wandb.finish()