Example #1
0
def train(model, dataset, epochs):
    """Build a masked model from *model* and train it on *dataset*.

    :param model: identifier handed to ``create_model_masked`` (presumably a
        model name — confirm against callers).
    :param dataset: forwarded unchanged to ``train_model``.
    :param epochs: number of epochs, forwarded unchanged to ``train_model``.
    """
    masked_model = create_model_masked(model)
    # Fresh, empty per-metric lists; train_model receives this dict.
    metric_log = {
        key: [] for key in ("loss", "accuracy", "val_loss", "val_accuracy")
    }
    train_model(masked_model, metric_log, dataset, epochs)
Example #2
0
def _recompile_for_continuation(net) -> None:
    """Recompile *net* in place for continued training, keeping its weights.

    Installs the Jaccard distance loss, an Adam optimizer whose learning rate
    follows ``ExponentialDecay(9.2e-4, decay_steps=2000, decay_rate=0.96,
    staircase=True)``, and accuracy / two-class MeanIoU metrics.
    """
    # Snapshot the weights before compiling and restore them afterwards.
    saved_weights = net.get_weights()
    decaying_lr = ExponentialDecay(
        9.2e-4,
        decay_steps=2000,
        decay_rate=0.96,
        staircase=True,
    )
    # NOTE(review): Keras MeanIoU expects class-id predictions; if the model
    # emits sigmoid probabilities they may need thresholding first — confirm.
    net.compile(
        loss=jaccard_distance_loss,
        optimizer=Adam(learning_rate=decaying_lr),
        metrics=["accuracy", MeanIoU(num_classes=2)],
    )
    net.set_weights(saved_weights)


def train_wrapper(args: Namespace) -> None:
    """Train a network, either resuming a saved model or creating a new one.

    With ``args.cont`` set, the saved model ``args.model`` is loaded, its
    stored ``__asf_model_history`` is reused, and it is recompiled with a
    fresh optimizer (weights preserved). Otherwise a new masked CDL model is
    created; if its file already exists and ``args.overwrite`` is not set, a
    message is printed and nothing is trained.

    :param args: parsed CLI arguments; reads ``model``, ``cont``,
        ``overwrite``, ``dataset`` and ``epochs``.
    """
    name = args.model

    if args.cont:
        net = load_model(name)
        history = net.__asf_model_history
        _recompile_for_continuation(net)
    else:
        target_path = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target_path):
            print(f"File {name} already exists!")
            return
        net = create_cdl_model_masked(name)
        history = {"loss": [], "accuracy": [], "mean_io_u": []}

    train_model(net, history, args.dataset, args.epochs)
Example #3
0
def train_wrapper(args: Namespace) -> None:
    """Train a network, resuming a saved model or creating a fresh masked one.

    A new model is only created when no file exists for it yet, or when
    ``args.overwrite`` is set; otherwise a message is printed and the
    function returns without training. With ``args.cont`` set, the saved
    model is loaded and its stored ``__asf_model_history`` is reused.

    :param args: parsed CLI arguments; reads ``model``, ``cont``,
        ``overwrite``, ``dataset`` and ``epochs``.
    """
    name = args.model

    if not args.cont:
        target_path = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target_path):
            print(f"File {name} already exists!")
            return
        net = create_model_masked(name)
        history = {key: [] for key in ("loss", "acc", "val_loss", "val_acc")}
    else:
        net = load_model(name)
        history = net.__asf_model_history

    train_model(net, history, args.dataset, args.epochs)
Example #4
0
def train_wrapper(args: Namespace) -> None:
    """Train a network on ``args.dataset``, new or resumed.

    The dataset's type is determined first. A new model is created for that
    type unless its file already exists and ``args.overwrite`` is not set; with
    ``args.cont`` set, the saved model and its ``__asf_model_history`` are
    loaded instead. Training only starts if ``model_type`` of the model matches
    the dataset type; otherwise an error message is printed.

    :param args: parsed CLI arguments; reads ``dataset``, ``model``, ``cont``,
        ``overwrite`` and ``epochs``.
    """
    dataset = args.dataset
    expected_type = dataset_type(dataset)
    name = args.model

    if not args.cont:
        target_path = path_from_model_name(name)
        if not args.overwrite and os.path.isfile(target_path):
            print(f"File {name} already exists!")
            return
        net = create_model(name, expected_type)
        history = {key: [] for key in ("loss", "acc", "val_loss", "val_acc")}
    else:
        net = load_model(name)
        history = net.__asf_model_history

    # Checked for both branches: a resumed model may have been built for a
    # different kind of dataset than the one requested now.
    if model_type(net) != expected_type:
        print("ERROR: This dataset is not compatible with your model")
        return

    train_model(net, history, dataset, args.epochs)