Example #1
def user_info_task() -> (str, str):
    secret_username = flytekit.current_context().secrets.get(
        SECRET_GROUP, USERNAME_SECRET)
    secret_pwd = flytekit.current_context().secrets.get(
        SECRET_GROUP, PASSWORD_SECRET)
    # Please do not print the secret value, this is just a demonstration.
    print(f"{secret_username}={secret_pwd}")
    return secret_username, secret_pwd
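The SECRET_GROUP, USERNAME_SECRET, and PASSWORD_SECRET constants are defined outside this snippet, and a task only receives secrets it explicitly requests. A minimal sketch of how the declaration might look, with assumed group and key names:

from flytekit import Secret, task

SECRET_GROUP = "user-info"      # assumed group name
USERNAME_SECRET = "username"    # assumed key names
PASSWORD_SECRET = "password"

@task(
    secret_requests=[
        Secret(group=SECRET_GROUP, key=USERNAME_SECRET),
        Secret(group=SECRET_GROUP, key=PASSWORD_SECRET),
    ]
)
def user_info_task() -> (str, str):
    ...  # body as above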
Example #2
def download_data(dataset: str) -> FlyteDirectory:
    # create a directory named 'data'
    print("==============")
    print("Downloading data")
    print("==============")

    working_dir = flytekit.current_context().working_directory
    data_dir = pathlib.Path(os.path.join(working_dir, "data"))
    data_dir.mkdir(exist_ok=True)

    # download the dataset
    download_subprocess = subprocess.run(
        [
            "curl",
            dataset,
        ],
        check=True,
        capture_output=True,
    )

    # untar the data
    subprocess.run(
        [
            "tar",
            "-xz",
            "-C",
            data_dir,
        ],
        input=download_subprocess.stdout,
    )

    # return the directory populated with Rossmann data files
    return FlyteDirectory(path=str(data_dir))
Example #3
def create_entities() -> Tuple[FlyteFile, FlyteDirectory]:
    working_dir = flytekit.current_context().working_directory
    flytefile = os.path.join(working_dir, "test.txt")
    os.close(os.open(flytefile, os.O_CREAT))  # create the file and close the descriptor
    flytedir = os.path.join(working_dir, "testdata")
    os.makedirs(flytedir, exist_ok=True)
    return flytefile, flytedir
Example #4
def secret_file_task() -> (str, str):
    # SM here is a handle to the secrets manager
    sm = flytekit.current_context().secrets
    f = sm.get_secrets_file(SECRET_GROUP, SECRET_NAME)
    secret_val = sm.get(SECRET_GROUP, SECRET_NAME)
    # returning the filename and the secret_val
    return f, secret_val
Example #5
 def interpolate(
     self,
     tmpl: str,
     inputs: typing.Optional[typing.Dict[str, str]] = None,
     outputs: typing.Optional[typing.Dict[str, str]] = None,
 ) -> str:
     """
     Interpolate python formatted string templates with variables from the input and output
     argument dicts. The result is non destructive towards the given template string.
     """
     inputs = inputs or {}
     outputs = outputs or {}
     inputs = AttrDict(inputs)
     outputs = AttrDict(outputs)
     consolidated_args = {
         "inputs": inputs,
         "outputs": outputs,
         "ctx": flytekit.current_context(),
     }
     try:
         return self._Formatter().format(tmpl, **consolidated_args)
     except KeyError as e:
         raise ValueError(
             f"Variable {e} in Query not found in inputs {consolidated_args.keys()}"
         )
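A standalone sketch of the attribute-style lookup this method depends on, using a hypothetical AttrDict stand-in for the helper the plugin actually ships:

class AttrDict(dict):
    # hypothetical stand-in: expose dict keys as attributes so "{inputs.limit}" resolves
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

tmpl = "SELECT * FROM tracks LIMIT {inputs.limit}"
print(tmpl.format(inputs=AttrDict({"limit": "5"})))  # SELECT * FROM tracks LIMIT 5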
Example #6
def hello_spark(partitions: int) -> float:
    print("Starting Spark with Partitions: {}".format(partitions))

    n = 100000 * partitions
    sess = flytekit.current_context().spark_session
    count = sess.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
    pi_val = 4.0 * count / n
    print("Pi val is :{}".format(pi_val))
    return pi_val
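The helpers f and add are not shown here. A sketch of the conventional Monte-Carlo pieces this pi estimate presumably relies on (an assumption, not part of the snippet):

import random
from operator import add  # reduce(add) sums the per-partition hit counts

def f(_):
    # sample a point in the unit square; count it if it falls inside the unit circle
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x ** 2 + y ** 2 <= 1 else 0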
Example #7
def download_files() -> FlyteDirectory:
    working_dir = flytekit.current_context().working_directory
    pp = pathlib.Path(os.path.join(working_dir, "images"))
    pp.mkdir(exist_ok=True)
    for idx, remote_location in enumerate(default_images):
        local_image = os.path.join(working_dir, "images", f"image_{idx}.jpg")
        urllib.request.urlretrieve(remote_location, local_image)

    return FlyteDirectory(path=os.path.join(working_dir, "images"))
Example #8
def create_spark_df() -> my_schema:
    """
    This spark program returns a spark dataset that conforms to the defined schema. Failure to do so should result
    in a runtime error. TODO: runtime error enforcement
    """
    sess = flytekit.current_context().spark_session
    return sess.createDataFrame(
        [("Alice", 5), ("Bob", 10), ("Charlie", 15), ], my_schema.column_names(),
    )
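my_schema is declared elsewhere; a sketch of how such a typed schema might be defined so that column_names() yields the two columns used above (the column names are assumptions):

from flytekit import kwtypes
from flytekit.types.schema import FlyteSchema

# a two-column typed schema; my_schema.column_names() returns ["name", "age"]
my_schema = FlyteSchema[kwtypes(name=str, age=int)]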
Example #9
 def execute(self, **kwargs) -> typing.Any:
     if self._secret_connect_args is not None:
         for key, secret in self._secret_connect_args.items():
             value = current_context().secrets.get(secret.group, secret.key)
             self._connect_args[key] = value
     engine = create_engine(self._uri,
                            connect_args=self._connect_args,
                            echo=False)
     print(f"Connecting to db {self._uri}")
     with engine.begin() as connection:
         df = pd.read_sql_query(self.get_query(**kwargs), connection)
     return df
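For context, a sketch of how secret_connect_args might be supplied when the task is constructed, assuming the flytekitplugins-sqlalchemy API and placeholder connection details:

from flytekit import Secret
from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask

sql_task = SQLAlchemyTask(
    name="my_query",
    query_template="SELECT 1",
    task_config=SQLAlchemyConfig(
        # placeholder URI; the password is injected at run time by the loop above
        uri="postgresql://user@host:5432/mydb",
        secret_connect_args={"password": Secret(group="db-creds", key="password")},
    ),
)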
Example #10
def mnist_pytorch_job(hp: Hyperparameters) -> PythonPickledFile:
    # pytorch's save() function does not create a path if the path specified does not exist
    # therefore we must pass an existing path

    ctx = flytekit.current_context()
    data_dir = os.path.join(ctx.working_directory, "data")
    model_dir = os.path.join(ctx.working_directory, "model")
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    args = TrainingArgs(
        hosts=ctx.distributed_training_context.hosts,
        current_host=ctx.distributed_training_context.current_host,
        num_gpus=torch.cuda.device_count(),
        batch_size=hp.batch_size,
        test_batch_size=hp.test_batch_size,
        epochs=hp.epochs,
        learning_rate=hp.learning_rate,
        sgd_momentum=hp.sgd_momentum,
        seed=hp.seed,
        log_interval=hp.log_interval,
        backend=hp.backend,
        data_dir=data_dir,
        model_dir=model_dir,
    )

    # Data shouldn't be downloaded by the functions called in mp.spawn due to race conditions
    # These can be replaced by Flyte's blob type inputs. Note that the data here are assumed
    # to be accessible via a local path:
    download_training_data(args.data_dir)
    download_test_data(args.data_dir)

    if len(args.hosts) > 1:
        # Config MASTER_ADDR and MASTER_PORT for PyTorch Distributed Training
        os.environ["MASTER_ADDR"] = args.hosts[0]
        os.environ["MASTER_PORT"] = "29500"
        os.environ[
            "NCCL_SOCKET_IFNAME"] = ctx.distributed_training_context.network_interface_name
        os.environ["NCCL_DEBUG"] = "INFO"
        # The function is called as fn(i, *args), where i is the process index and args is the passed
        # through tuple of arguments.
        # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
        mp.spawn(train, nprocs=args.num_gpus, args=(args, ))
    else:
        # Config for Multi GPU with a single instance training
        if args.num_gpus > 1:
            gpu_devices = ",".join(
                [str(gpu_id) for gpu_id in range(args.num_gpus)])
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
        train(-1, args)

    pth = os.path.join(model_dir, "model.pth")
    print(f"Returning model @ {pth}")
    return pth
Example #11
    def t1() -> FlyteDirectory:
        user_ctx = flytekit.current_context()
        # Create a local directory to work with
        p = os.path.join(user_ctx.working_directory, "test_wf")
        if os.path.exists(p):
            shutil.rmtree(p)
        pathlib.Path(p).mkdir(parents=True)
        for i in range(1, 6):
            with open(os.path.join(p, f"{i}.txt"), "w") as fh:
                fh.write(f"I'm file {i}\n")

        return FlyteDirectory(p)
Example #12
def download_files(csv_urls: List[str]) -> FlyteDirectory:
    working_dir = flytekit.current_context().working_directory
    local_dir = Path(os.path.join(working_dir, "csv_files"))
    local_dir.mkdir(exist_ok=True)

    # get the number of digits needed to preserve the order of files in the local directory
    zfill_len = len(str(len(csv_urls)))
    for idx, remote_location in enumerate(csv_urls):
        local_image = os.path.join(
            # prefix the file name with the index location of the file in the original csv_urls list
            local_dir,
            f"{str(idx).zfill(zfill_len)}_{os.path.basename(remote_location)}",
        )
        urllib.request.urlretrieve(remote_location, local_image)
    return FlyteDirectory(path=str(local_dir))
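A sketch of how this task might be wired into a workflow, assuming download_files is registered as a @task as in the original example:

from typing import List
from flytekit import workflow

@workflow
def csv_download_wf(csv_urls: List[str]) -> FlyteDirectory:
    # fan the URL list into the download task and hand back the populated directory
    return download_files(csv_urls=csv_urls)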
Example #13
    def onnx_predict(model_file: ONNXFile) -> JPEGImageFile:
        ort_session = onnxruntime.InferenceSession(model_file.download())

        img = Image.open(
            requests.get(
                "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/cat.jpg",
                stream=True).raw)

        resize = transforms.Resize([224, 224])
        img = resize(img)

        img_ycbcr = img.convert("YCbCr")
        img_y, img_cb, img_cr = img_ycbcr.split()

        to_tensor = transforms.ToTensor()
        img_y = to_tensor(img_y)
        img_y.unsqueeze_(0)

        # compute ONNX Runtime output prediction
        ort_inputs = {
            ort_session.get_inputs()[0].name:
            img_y.detach().cpu().numpy()
            if img_y.requires_grad else img_y.cpu().numpy()
        }
        ort_outs = ort_session.run(None, ort_inputs)
        img_out_y = ort_outs[0]

        img_out_y = Image.fromarray(np.uint8(
            (img_out_y[0] * 255.0).clip(0, 255)[0]),
                                    mode="L")

        # get the output image following the post-processing step from the PyTorch implementation
        final_img = Image.merge(
            "YCbCr",
            [
                img_out_y,
                img_cb.resize(img_out_y.size, Image.BICUBIC),
                img_cr.resize(img_out_y.size, Image.BICUBIC),
            ],
        ).convert("RGB")

        img_path = Path(flytekit.current_context().working_directory
                        ) / "cat_superres_with_ort.jpg"
        final_img.save(img_path)

        # Return the saved image; we will compare it with the output image from the mobile device
        return JPEGImageFile(path=str(img_path))
Example #14
    def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
        """
        In the case of distributed execution, we check the should_persist_predicate in the configuration to determine
        if the output should be persisted. This is because in distributed training, multiple nodes may produce partial
        outputs and only the user process knows the output that should be generated. They can control the choice using
        the predicate.

        To control whether output is generated for a given execution, we override the post_execute
        method and raise IgnoreOutputs when the persistence predicate is not met.
        """
        if self._is_distributed():
            logging.info("Distributed context detected!")
            dctx = flytekit.current_context().distributed_training_context
            if not self.task_config.should_persist_output(dctx):
                logging.info("output persistence predicate not met, Flytekit will ignore outputs")
                raise IgnoreOutputs(f"Distributed context - Persistence predicate not met. Ignoring outputs - {dctx}")
        return rval
Example #15
def rotate(image_location: str) -> FlyteFile:
    """
    Download the given image, rotate it by 180 degrees
    """
    working_dir = flytekit.current_context().working_directory
    local_image = os.path.join(working_dir, "incoming.jpg")
    urllib.request.urlretrieve(image_location, local_image)
    img = cv2.imread(local_image, 0)
    if img is None:
        raise Exception("Failed to read image")
    (h, w) = img.shape[:2]
    center = (w / 2, h / 2)
    mat = cv2.getRotationMatrix2D(center, 180, 1)
    res = cv2.warpAffine(img, mat, (w, h))
    out_path = os.path.join(working_dir, "rotated.jpg")
    cv2.imwrite(out_path, res)
    return FlyteFile["jpg"](path=out_path)
Example #16
def horovod_spark_task(data_dir: FlyteDirectory, hp: Hyperparameters,
                       work_dir: FlyteDirectory) -> FlyteDirectory:

    max_sales, vocab, train_df, test_df = data_preparation(data_dir, hp)

    # working directory will have the model and predictions as separate files
    working_dir = flytekit.current_context().working_directory

    keras_model = train(
        max_sales,
        vocab,
        hp,
        work_dir,
        train_df,
        working_dir,
    )

    # generate predictions
    return test(keras_model, working_dir, test_df, hp)
Example #17
def test_secrets():
    @task(secret_requests=[Secret("my_group", "my_key")])
    def foo() -> str:
        return flytekit.current_context().secrets.get("my_group", "")

    with pytest.raises(ValueError):
        foo()

    @task(secret_requests=[Secret("group", group_version="v1", key="key")])
    def foo2() -> str:
        return flytekit.current_context().secrets.get("group", "key")

    os.environ[flytekit.current_context().secrets.get_secrets_env_var("group", "key")] = "super-secret-value2"
    assert foo2() == "super-secret-value2"

    with pytest.raises(AssertionError):

        @task(secret_requests=["test"])
        def foo() -> str:
            pass
Example #18
def fit(loc: str, train: pd.DataFrame, val: pd.DataFrame) -> JoblibSerializedFile:

    # fetch the features and target columns from the train dataset
    x = train[train.columns[1:]]
    y = train[train.columns[0]]

    # fetch the features and target columns from the validation dataset
    eval_x = val[val.columns[1:]]
    eval_y = val[val.columns[0]]

    m = XGBRegressor()
    # fit the model to the train data
    m.fit(x, y, eval_set=[(eval_x, eval_y)])

    working_dir = flytekit.current_context().working_directory
    fname = os.path.join(working_dir, f"model-{loc}.joblib.dat")
    joblib.dump(m, fname)

    # return the serialized model
    return JoblibSerializedFile(path=fname)
Example #19
def rotate(image_location: JPEGImageFile, location: str) -> JPEGImageFile:
    """
    Download the given image, rotate it by 180 degrees
    """
    working_dir = flytekit.current_context().working_directory
    image_location.download()
    img = cv2.imread(image_location.path, 0)
    if img is None:
        raise Exception("Failed to read image")
    (h, w) = img.shape[:2]
    center = (w / 2, h / 2)
    mat = cv2.getRotationMatrix2D(center, 180, 1)
    res = cv2.warpAffine(img, mat, (w, h))
    out_path = os.path.join(
        working_dir,
        f"rotated-{os.path.basename(image_location.path).rsplit('.')[0]}.jpg",
    )
    cv2.imwrite(out_path, res)
    if location:
        return JPEGImageFile(path=out_path, remote_path=location)
    else:
        return JPEGImageFile(path=out_path)
Example #20
def use_checkpoint(n_iterations: int) -> int:
    cp = current_context().checkpoint
    prev = cp.read()
    start = 0
    if prev:
        start = int(prev.decode())

    # define a failure interval so the task fails every few iterations and still succeeds within the
    # configured retries
    failure_interval = n_iterations * 1.0 / RETRIES
    i = 0
    for i in range(start, n_iterations):
        # simulate a deterministic failure, for demonstration. We want to show how it eventually completes within
        # the given retries
        if i > start and i % failure_interval == 0:
            raise FlyteRecoverableException(
                f"Failed at iteration {start}, failure_interval {failure_interval}"
            )
        # save progress state. It is also entirely possible to save state every few iterations.
        cp.write(f"{i + 1}".encode())

    return i
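RETRIES is defined outside the snippet; a sketch of an assumed value and of how the task might be declared so the simulated failures stay within the retry budget:

from flytekit import task, workflow

RETRIES = 3  # assumed retry budget; the failure interval above is derived from it

@task(retries=RETRIES)
def use_checkpoint(n_iterations: int) -> int:
    ...  # body as above

@workflow
def checkpoint_wf(n_iterations: int = 10) -> int:
    return use_checkpoint(n_iterations=n_iterations)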
Example #21
def normalize_columns(
    csv_url: FlyteFile,
    column_names: List[str],
    columns_to_normalize: List[str],
    output_location: str,
) -> FlyteFile:
    # read the data from the raw csv file
    parsed_data = defaultdict(list)
    with open(csv_url, newline="\n") as input_file:
        reader = csv.DictReader(input_file, fieldnames=column_names)
        for row in (x for i, x in enumerate(reader) if i > 0):
            for column in columns_to_normalize:
                parsed_data[column].append(float(row[column].strip()))

    # normalize the data
    normalized_data = defaultdict(list)
    for colname, values in parsed_data.items():
        mean = sum(values) / len(values)
        std = (sum([(x - mean)**2 for x in values]) / len(values))**0.5
        normalized_data[colname] = [(x - mean) / std for x in values]

    # write to local path
    out_path = os.path.join(
        flytekit.current_context().working_directory,
        f"normalized-{os.path.basename(csv_url.path).rsplit('.')[0]}.csv",
    )
    with open(out_path, mode="w") as output_file:
        writer = csv.DictWriter(output_file, fieldnames=columns_to_normalize)
        writer.writeheader()
        for row in zip(*normalized_data.values()):
            writer.writerow(
                {k: row[i]
                 for i, k in enumerate(columns_to_normalize)})

    if output_location:
        return FlyteFile(path=out_path, remote_path=output_location)
    else:
        return FlyteFile(path=out_path)
Example #22
def test_input_substitution_files_ctx():
    sec = flytekit.current_context().secrets
    envvar = sec.get_secrets_env_var("group", "key")
    os.environ[envvar] = "value"
    assert sec.get("group", "key") == "value"

    t = ShellTask(
        name="test",
        script="""
            export EXEC={ctx.execution_id}
            export SECRET={ctx.secrets.group.key}
            cat {inputs.f}
            echo "Hello World {inputs.y} on  {inputs.j}"
            """,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        debug=True,
    )

    if os.name == "nt":
        t._script = t._script.replace("cat", "type").replace("export", "set")

    assert t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) is None
    del os.environ[envvar]
Example #23
    def execute_from_model(self, tt: task_models.TaskTemplate,
                           **kwargs) -> typing.Any:
        if tt.custom["secret_connect_args"] is not None:
            for key, secret_dict in tt.custom["secret_connect_args"].items():
                value = current_context().secrets.get(
                    group=secret_dict["group"], key=secret_dict["key"])
                tt.custom["connect_args"][key] = value

        engine = create_engine(tt.custom["uri"],
                               connect_args=tt.custom["connect_args"],
                               echo=False)
        print(f"Connecting to db {tt.custom['uri']}")

        interpolated_query = SQLAlchemyTask.interpolate_query(
            tt.custom["query_template"], **kwargs)
        print(f"Interpolated query {interpolated_query}")
        with engine.begin() as connection:
            df = None
            if tt.interface.outputs:
                df = pd.read_sql_query(interpolated_query, connection)
            else:
                pandasSQL_builder(connection).execute(interpolated_query)
        return df
Example #24
def horovod_train_task(batch_size: int, buffer_size: int,
                       dataset_size: int) -> FlyteDirectory:
    """
    :param batch_size: Represents the number of consecutive elements of this dataset to combine in a single batch.
    :param buffer_size: Defines the size of the buffer used to hold elements of the dataset used for training.
    :param dataset_size: The number of elements of this dataset that should be taken to form the new dataset when
        running batched training.
    """
    hvd.init()

    (mnist_images,
     mnist_labels), _ = tf.keras.datasets.mnist.load_data(path="mnist-%d.npz" %
                                                          hvd.rank())

    dataset = tf.data.Dataset.from_tensor_slices((
        tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
        tf.cast(mnist_labels, tf.int64),
    ))
    dataset = dataset.repeat().shuffle(buffer_size).batch(batch_size)

    mnist_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, [3, 3], activation="relu"),
        tf.keras.layers.Conv2D(64, [3, 3], activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    loss = tf.losses.SparseCategoricalCrossentropy()

    # Horovod: adjust learning rate based on number of GPUs.
    opt = tf.optimizers.Adam(0.001 * hvd.size())

    checkpoint_dir = ".checkpoint"
    pathlib.Path(checkpoint_dir).mkdir(exist_ok=True)

    checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)

    # Horovod: adjust number of steps based on number of GPUs.
    for batch, (images,
                labels) in enumerate(dataset.take(dataset_size // hvd.size())):
        loss_value = training_step(images, labels, batch == 0, mnist_model,
                                   loss, opt)

        if batch % 10 == 0 and hvd.local_rank() == 0:
            print("Step #%d\tLoss: %.6f" % (batch, loss_value))

    if hvd.rank() != 0:
        raise IgnoreOutputs("I am not rank 0")

    working_dir = flytekit.current_context().working_directory
    checkpoint_prefix = pathlib.Path(os.path.join(working_dir, "checkpoint"))
    checkpoint.save(checkpoint_prefix)

    tf.keras.models.save_model(
        mnist_model,
        str(working_dir),
        overwrite=True,
        include_optimizer=True,
        save_format=None,
        signatures=None,
        options=None,
        save_traces=True,
    )
    return FlyteDirectory(path=str(working_dir))
Example #25
def secret_task() -> str:
    secret_val = flytekit.current_context().secrets.get(
        SECRET_GROUP, SECRET_NAME)
    # Please do not print the secret value, we are doing so just as a demonstration
    print(secret_val)
    return secret_val
Example #26
        y=alt.Y('new_cases_smoothed_per_million:Q', stack=None),
        color=alt.Color('continent:N', scale=alt.Scale(scheme='set1')),
        tooltip='continent:N').interactive().properties(width='container')

    dp.Report(dp.Plot(plot), dp.DataTable(df)).save(path='report.html',open=True)


@task
def transform_data(url: str) -> pandas.DataFrame:
    dataset = pd.read_csv(url)
    df = dataset.groupby(
        ['continent',
         'date'])['new_cases_smoothed_per_million'].mean().reset_index()
    return df


@workflow
def datapane_workflow(url: str):
    df = transform_data(url=url)
    publish_report(df=df)
    print(f"Report is published for {url}")


default_lp = LaunchPlan.get_default_launch_plan(
    current_context(),
    datapane_workflow)

if __name__ == "__main__":
    print(default_lp(url="https://covid.ourworldindata.org/data/owid-covid-data.csv"))

Example #27
def t1(n: int) -> int:
    ctx = flytekit.current_context()
    cp = ctx.checkpoint
    cp.write(bytes(n + 1))
    return n + 1
Example #28
def data_preparation(
    data_dir: FlyteDirectory, hp: Hyperparameters
) -> Tuple[float, Dict[str, List[Any]], pyspark.sql.DataFrame,
           pyspark.sql.DataFrame]:
    print("================")
    print("Data preparation")
    print("================")

    # 'current_context' gives a handle to task-specific parameters inside the ``data_preparation`` task
    spark = flytekit.current_context().spark_session
    data_dir_path = data_dir.remote_source
    # read the CSV data into Spark DataFrame
    train_csv = spark.read.csv("%s/train.csv" % data_dir_path, header=True)
    test_csv = spark.read.csv("%s/test.csv" % data_dir_path, header=True)

    store_csv = spark.read.csv("%s/store.csv" % data_dir_path, header=True)
    store_states_csv = spark.read.csv("%s/store_states.csv" % data_dir_path,
                                      header=True)
    state_names_csv = spark.read.csv("%s/state_names.csv" % data_dir_path,
                                     header=True)
    google_trend_csv = spark.read.csv("%s/googletrend.csv" % data_dir_path,
                                      header=True)
    weather_csv = spark.read.csv("%s/weather.csv" % data_dir_path, header=True)

    # retrieve a sampled subset of the train and test data
    if hp.sample_rate:
        train_csv = train_csv.sample(withReplacement=False,
                                     fraction=hp.sample_rate)
        test_csv = test_csv.sample(withReplacement=False,
                                   fraction=hp.sample_rate)

    # prepare the DataFrames from the CSV files
    train_df = prepare_df(
        train_csv,
        store_csv,
        store_states_csv,
        state_names_csv,
        google_trend_csv,
        weather_csv,
    ).cache()
    test_df = prepare_df(
        test_csv,
        store_csv,
        store_states_csv,
        state_names_csv,
        google_trend_csv,
        weather_csv,
    ).cache()

    # add elapsed times from the data spanning training & test datasets
    elapsed_cols = ["Promo", "StateHoliday", "SchoolHoliday"]
    elapsed = add_elapsed(
        train_df.select("Date", "Store", *elapsed_cols).unionAll(
            test_df.select("Date", "Store", *elapsed_cols)),
        elapsed_cols,
    )

    # join with the elapsed times
    train_df = train_df.join(elapsed, ["Date", "Store"]).select(
        train_df["*"],
        *[
            prefix + col for prefix in ["Before", "After"]
            for col in elapsed_cols
        ],
    )
    test_df = test_df.join(elapsed, ["Date", "Store"]).select(
        test_df["*"],
        *[
            prefix + col for prefix in ["Before", "After"]
            for col in elapsed_cols
        ],
    )

    # filter out zero sales
    train_df = train_df.filter(train_df.Sales > 0)

    print("===================")
    print("Prepared data frame")
    print("===================")
    train_df.show()

    all_cols = CATEGORICAL_COLS + CONTINUOUS_COLS

    # select features
    train_df = train_df.select(*(all_cols + ["Sales", "Date"])).cache()
    test_df = test_df.select(*(all_cols + ["Id", "Date"])).cache()

    # build a vocabulary of categorical columns
    vocab = build_vocabulary(
        train_df.select(*CATEGORICAL_COLS).unionAll(
            test_df.select(*CATEGORICAL_COLS)).cache(), )

    # cast continuous columns to float
    train_df = cast_columns(train_df, CONTINUOUS_COLS + ["Sales"])
    # look up each categorical value in the vocabulary and replace it with its index
    train_df = lookup_columns(train_df, vocab)
    test_df = cast_columns(test_df, CONTINUOUS_COLS)
    test_df = lookup_columns(test_df, vocab)

    # split into training & validation
    # test set is in 2015, use the same period in 2014 from the training set as a validation set
    test_min_date = test_df.agg(F.min(test_df.Date)).collect()[0][0]
    test_max_date = test_df.agg(F.max(test_df.Date)).collect()[0][0]
    one_year = datetime.timedelta(365)
    train_df = train_df.withColumn(
        "Validation",
        (train_df.Date > test_min_date - one_year)
        & (train_df.Date <= test_max_date - one_year),
    )

    # determine max Sales number
    max_sales = train_df.agg(F.max(train_df.Sales)).collect()[0][0]

    # convert Sales to log domain
    train_df = train_df.withColumn("Sales", F.log(train_df.Sales))

    print("===================================")
    print("Data frame with transformed columns")
    print("===================================")
    train_df.show()

    print("================")
    print("Data frame sizes")
    print("================")

    # count the training and validation rows using the Validation flag
    train_rows = train_df.filter(~train_df.Validation).count()
    val_rows = train_df.filter(train_df.Validation).count()
    test_rows = test_df.count()

    # print the number of rows in training, validation and test data
    print("Training: %d" % train_rows)
    print("Validation: %d" % val_rows)
    print("Test: %d" % test_rows)

    return max_sales, vocab, train_df, test_df
Example #29
 def my_spark(df: pyspark.sql.DataFrame) -> my_schema:
     session = flytekit.current_context().spark_session
     new_df = session.createDataFrame([("Bob", 10)], my_schema.column_names())
     return df.union(new_df)
Example #30
 def my_spark(a: int) -> my_schema:
     session = flytekit.current_context().spark_session
     df = session.createDataFrame([("Alice", a)], my_schema.column_names())
     print(type(df))
     return df
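For these Spark examples, spark_session is only available when the task is configured with the Spark plugin. A sketch of how such a task might be decorated, assuming flytekitplugins-spark and placeholder Spark settings:

from flytekit import task
from flytekitplugins.spark import Spark

@task(
    task_config=Spark(
        # placeholder Spark configuration; tune for the target cluster
        spark_conf={
            "spark.driver.memory": "1000M",
            "spark.executor.instances": "2",
        }
    )
)
def my_spark(a: int) -> my_schema:
    ...  # body as above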