Example #1
def user_info_task() -> (str, str):
    secret_username = flytekit.current_context().secrets.get(
        SECRET_GROUP, USERNAME_SECRET)
    secret_pwd = flytekit.current_context().secrets.get(
        SECRET_GROUP, PASSWORD_SECRET)
    # Please do not print the secret value, this is just a demonstration.
    print(f"{secret_username}={secret_pwd}")
    return secret_username, secret_pwd
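For reference, a task like the one above only receives its secrets at runtime if they are declared on the decorator. A minimal sketch, assuming the same SECRET_GROUP / USERNAME_SECRET / PASSWORD_SECRET constants (the values shown here are placeholders):

import flytekit
from flytekit import Secret, task

SECRET_GROUP = "user-info"        # placeholder values; the original constants
USERNAME_SECRET = "username"      # are defined elsewhere in the source project
PASSWORD_SECRET = "password"


@task(secret_requests=[
    Secret(group=SECRET_GROUP, key=USERNAME_SECRET),
    Secret(group=SECRET_GROUP, key=PASSWORD_SECRET),
])
def user_info_task() -> (str, str):
    ctx = flytekit.current_context()
    return (
        ctx.secrets.get(SECRET_GROUP, USERNAME_SECRET),
        ctx.secrets.get(SECRET_GROUP, PASSWORD_SECRET),
    )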
Example #2
def download_data(dataset: str) -> FlyteDirectory:
    # create a directory named 'data'
    print("==============")
    print("Downloading data")
    print("==============")

    working_dir = flytekit.current_context().working_directory
    data_dir = pathlib.Path(os.path.join(working_dir, "data"))
    data_dir.mkdir(exist_ok=True)

    # download the dataset
    download_subprocess = subprocess.run(
        [
            "curl",
            dataset,
        ],
        check=True,
        capture_output=True,
    )

    # extract the downloaded tarball into the data directory
    subprocess.run(
        [
            "tar",
            "-xz",
            "-C",
            data_dir,
        ],
        input=download_subprocess.stdout,
        check=True,
    )

    # return the directory populated with Rossmann data files
    return FlyteDirectory(path=str(data_dir))
Example #3
def create_entities() -> Tuple[FlyteFile, FlyteDirectory]:
    working_dir = flytekit.current_context().working_directory
    flytefile = os.path.join(working_dir, "test.txt")
    os.open(flytefile, os.O_CREAT)
    flytedir = os.path.join(working_dir, "testdata")
    os.makedirs(flytedir, exist_ok=True)
    return flytefile, flytedir
Example #4
def secret_file_task() -> (str, str):
    # SM here is a handle to the secrets manager
    sm = flytekit.current_context().secrets
    f = sm.get_secrets_file(SECRET_GROUP, SECRET_NAME)
    secret_val = sm.get(SECRET_GROUP, SECRET_NAME)
    # returning the filename and the secret_val
    return f, secret_val
Example #5
File: shell.py Project: flyteorg/flytekit
 def interpolate(
     self,
     tmpl: str,
     inputs: typing.Optional[typing.Dict[str, str]] = None,
     outputs: typing.Optional[typing.Dict[str, str]] = None,
 ) -> str:
     """
     Interpolate Python format-string templates with variables from the input and output
     argument dicts. The operation is non-destructive toward the given template string.
     """
     inputs = inputs or {}
     outputs = outputs or {}
     inputs = AttrDict(inputs)
     outputs = AttrDict(outputs)
     consolidated_args = {
         "inputs": inputs,
         "outputs": outputs,
         "ctx": flytekit.current_context(),
     }
     try:
         return self._Formatter().format(tmpl, **consolidated_args)
     except KeyError as e:
         raise ValueError(
             f"Variable {e} in Query not found in inputs {consolidated_args.keys()}"
         )
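As a rough, flytekit-independent illustration of what the interpolation above does, Python's str.format resolves dotted references such as {inputs.f} via attribute access, so wrapping the dicts in an attribute-style dict is enough. A minimal standalone sketch (the names and paths are made up):

import string


class AttrDict(dict):
    # Minimal attribute-access dict: d.f is the same as d["f"].
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


tmpl = "cat {inputs.f} > {outputs.result}"
rendered = string.Formatter().format(
    tmpl,
    inputs=AttrDict({"f": "/tmp/in.csv"}),
    outputs=AttrDict({"result": "/tmp/out.csv"}),
)
print(rendered)  # cat /tmp/in.csv > /tmp/out.csv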
Example #6
def hello_spark(partitions: int) -> float:
    print("Starting Spark with Partitions: {}".format(partitions))

    n = 100000 * partitions
    sess = flytekit.current_context().spark_session
    count = sess.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
    pi_val = 4.0 * count / n
    print("Pi val is :{}".format(pi_val))
    return pi_val
Example #7
def download_files() -> FlyteDirectory:
    working_dir = flytekit.current_context().working_directory
    pp = pathlib.Path(os.path.join(working_dir, "images"))
    pp.mkdir(exist_ok=True)
    for idx, remote_location in enumerate(default_images):
        local_image = os.path.join(working_dir, "images", f"image_{idx}.jpg")
        urllib.request.urlretrieve(remote_location, local_image)

    return FlyteDirectory(path=os.path.join(working_dir, "images"))
Example #8
def create_spark_df() -> my_schema:
    """
    This spark program returns a spark dataset that conforms to the defined schema. Failure to do so should result
    in a runtime error. TODO: runtime error enforcement
    """
    sess = flytekit.current_context().spark_session
    return sess.createDataFrame(
        [("Alice", 5), ("Bob", 10), ("Charlie", 15), ], my_schema.column_names(),
    )
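The my_schema alias used above is not shown in this snippet. One plausible definition, assuming the FlyteSchema type with columns matching the ("Alice", 5)-style rows (the column names and types are assumptions):

from flytekit import kwtypes
from flytekit.types.schema import FlyteSchema

# Assumed columns; the real schema may differ.
my_schema = FlyteSchema[kwtypes(name=str, age=int)]

print(my_schema.column_names())  # ['name', 'age']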
Example #9
File: task.py Project: jaychia/flytekit
 def execute(self, **kwargs) -> typing.Any:
     if self._secret_connect_args is not None:
         for key, secret in self._secret_connect_args.items():
             value = current_context().secrets.get(secret.group, secret.key)
             self._connect_args[key] = value
     engine = create_engine(self._uri,
                            connect_args=self._connect_args,
                            echo=False)
     print(f"Connecting to db {self._uri}")
     with engine.begin() as connection:
         df = pd.read_sql_query(self.get_query(**kwargs), connection)
     return df
Example #10
def mnist_pytorch_job(hp: Hyperparameters) -> PythonPickledFile:
    # PyTorch's save() function does not create the target directory if it does not exist,
    # so we must pass an existing path

    ctx = flytekit.current_context()
    data_dir = os.path.join(ctx.working_directory, "data")
    model_dir = os.path.join(ctx.working_directory, "model")
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    args = TrainingArgs(
        hosts=ctx.distributed_training_context.hosts,
        current_host=ctx.distributed_training_context.current_host,
        num_gpus=torch.cuda.device_count(),
        batch_size=hp.batch_size,
        test_batch_size=hp.test_batch_size,
        epochs=hp.epochs,
        learning_rate=hp.learning_rate,
        sgd_momentum=hp.sgd_momentum,
        seed=hp.seed,
        log_interval=hp.log_interval,
        backend=hp.backend,
        data_dir=data_dir,
        model_dir=model_dir,
    )

    # Data shouldn't be downloaded by the functions called in mp.spawn due to race conditions
    # These can be replaced by Flyte's blob type inputs. Note that the data here are assumed
    # to be accessible via a local path:
    download_training_data(args.data_dir)
    download_test_data(args.data_dir)

    if len(args.hosts) > 1:
        # Config MASTER_ADDR and MASTER_PORT for PyTorch Distributed Training
        os.environ["MASTER_ADDR"] = args.hosts[0]
        os.environ["MASTER_PORT"] = "29500"
        os.environ[
            "NCCL_SOCKET_IFNAME"] = ctx.distributed_training_context.network_interface_name
        os.environ["NCCL_DEBUG"] = "INFO"
        # The function is called as fn(i, *args), where i is the process index and args is the passed
        # through tuple of arguments.
        # https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
        mp.spawn(train, nprocs=args.num_gpus, args=(args, ))
    else:
        # Config for Multi GPU with a single instance training
        if args.num_gpus > 1:
            gpu_devices = ",".join(
                [str(gpu_id) for gpu_id in range(args.num_gpus)])
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
        train(-1, args)

    pth = os.path.join(model_dir, "model.pth")
    print(f"Returning model @ {pth}")
    return pth
Example #11
    def t1() -> FlyteDirectory:
        user_ctx = flytekit.current_context()
        # Create a local directory to work with
        p = os.path.join(user_ctx.working_directory, "test_wf")
        if os.path.exists(p):
            shutil.rmtree(p)
        pathlib.Path(p).mkdir(parents=True)
        for i in range(1, 6):
            with open(os.path.join(p, f"{i}.txt"), "w") as fh:
                fh.write(f"I'm file {i}\n")

        return FlyteDirectory(p)
Example #12
def download_files(csv_urls: List[str]) -> FlyteDirectory:
    working_dir = flytekit.current_context().working_directory
    local_dir = Path(os.path.join(working_dir, "csv_files"))
    local_dir.mkdir(exist_ok=True)

    # get the number of digits needed to preserve the order of files in the local directory
    zfill_len = len(str(len(csv_urls)))
    for idx, remote_location in enumerate(csv_urls):
        local_image = os.path.join(
            # prefix the file name with the index location of the file in the original csv_urls list
            local_dir,
            f"{str(idx).zfill(zfill_len)}_{os.path.basename(remote_location)}",
        )
        urllib.request.urlretrieve(remote_location, local_image)
    return FlyteDirectory(path=str(local_dir))
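The zero-padding above only exists so that a lexicographic listing of the directory preserves the original order of csv_urls. A quick worked example of that assumption (the URLs are made up):

csv_urls = [f"https://example.com/{i}.csv" for i in range(12)]
zfill_len = len(str(len(csv_urls)))  # 2, because "12" has two digits
names = [f"{str(i).zfill(zfill_len)}_{i}.csv" for i in range(len(csv_urls))]
print(names[:3])               # ['00_0.csv', '01_1.csv', '02_2.csv']
print(sorted(names) == names)  # True: sorted order matches download order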
Example #13
    def onnx_predict(model_file: ONNXFile) -> JPEGImageFile:
        ort_session = onnxruntime.InferenceSession(model_file.download())

        img = Image.open(
            requests.get(
                "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/cat.jpg",
                stream=True).raw)

        resize = transforms.Resize([224, 224])
        img = resize(img)

        img_ycbcr = img.convert("YCbCr")
        img_y, img_cb, img_cr = img_ycbcr.split()

        to_tensor = transforms.ToTensor()
        img_y = to_tensor(img_y)
        img_y.unsqueeze_(0)

        # compute ONNX Runtime output prediction
        ort_inputs = {
            ort_session.get_inputs()[0].name:
            img_y.detach().cpu().numpy()
            if img_y.requires_grad else img_y.cpu().numpy()
        }
        ort_outs = ort_session.run(None, ort_inputs)
        img_out_y = ort_outs[0]

        img_out_y = Image.fromarray(np.uint8(
            (img_out_y[0] * 255.0).clip(0, 255)[0]),
                                    mode="L")

        # get the output image, following the post-processing step from the PyTorch implementation
        final_img = Image.merge(
            "YCbCr",
            [
                img_out_y,
                img_cb.resize(img_out_y.size, Image.BICUBIC),
                img_cr.resize(img_out_y.size, Image.BICUBIC),
            ],
        ).convert("RGB")

        img_path = Path(flytekit.current_context().working_directory
                        ) / "cat_superres_with_ort.jpg"
        final_img.save(img_path)

        # The image was saved above; we will compare it with the output image from the mobile device
        return JPEGImageFile(path=str(img_path))
Example #14
    def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
        """
        In the case of distributed execution, we check the should_persist_predicate in the configuration to determine
        if the output should be persisted. This is because in distributed training, multiple nodes may produce partial
        outputs and only the user process knows the output that should be generated. They can control the choice using
        the predicate.

        To control whether an output is produced for a given execution, we override the post_execute method and
        raise IgnoreOutputs when the predicate is not met.
        """
        if self._is_distributed():
            logging.info("Distributed context detected!")
            dctx = flytekit.current_context().distributed_training_context
            if not self.task_config.should_persist_output(dctx):
                logging.info("output persistence predicate not met, Flytekit will ignore outputs")
                raise IgnoreOutputs(f"Distributed context - Persistence predicate not met. Ignoring outputs - {dctx}")
        return rval
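The docstring above refers to a user-supplied persistence predicate on the task config. A hedged sketch of one common choice, keeping outputs only on the first host (the predicate name is illustrative; the distributed-training context exposes hosts / current_host, as in the PyTorch example earlier on this page):

def persist_on_first_host_only(dctx) -> bool:
    # Keep outputs only on the "chief" worker; every other host's outputs are
    # dropped via the IgnoreOutputs path shown in post_execute above.
    return dctx.current_host == dctx.hosts[0]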
Example #15
def rotate(image_location: str) -> FlyteFile:
    """
    Download the given image, rotate it by 180 degrees
    """
    working_dir = flytekit.current_context().working_directory
    local_image = os.path.join(working_dir, "incoming.jpg")
    urllib.request.urlretrieve(image_location, local_image)
    img = cv2.imread(local_image, 0)
    if img is None:
        raise Exception("Failed to read image")
    (h, w) = img.shape[:2]
    center = (w / 2, h / 2)
    mat = cv2.getRotationMatrix2D(center, 180, 1)
    res = cv2.warpAffine(img, mat, (w, h))
    out_path = os.path.join(working_dir, "rotated.jpg")
    cv2.imwrite(out_path, res)
    return FlyteFile["jpg"](path=out_path)
Example #16
def horovod_spark_task(data_dir: FlyteDirectory, hp: Hyperparameters,
                       work_dir: FlyteDirectory) -> FlyteDirectory:

    max_sales, vocab, train_df, test_df = data_preparation(data_dir, hp)

    # working directory will have the model and predictions as separate files
    working_dir = flytekit.current_context().working_directory

    keras_model = train(
        max_sales,
        vocab,
        hp,
        work_dir,
        train_df,
        working_dir,
    )

    # generate predictions
    return test(keras_model, working_dir, test_df, hp)
Example #17
def test_secrets():
    @task(secret_requests=[Secret("my_group", "my_key")])
    def foo() -> str:
        return flytekit.current_context().secrets.get("my_group", "")

    with pytest.raises(ValueError):
        foo()

    @task(secret_requests=[Secret("group", group_version="v1", key="key")])
    def foo2() -> str:
        return flytekit.current_context().secrets.get("group", "key")

    os.environ[flytekit.current_context().secrets.get_secrets_env_var("group", "key")] = "super-secret-value2"
    assert foo2() == "super-secret-value2"

    with pytest.raises(AssertionError):

        @task(secret_requests=["test"])
        def foo() -> str:
            pass
Example #18
def fit(loc: str, train: pd.DataFrame, val: pd.DataFrame) -> JoblibSerializedFile:

    # fetch the features and target columns from the train dataset
    x = train[train.columns[1:]]
    y = train[train.columns[0]]

    # fetch the features and target columns from the validation dataset
    eval_x = val[val.columns[1:]]
    eval_y = val[val.columns[0]]

    m = XGBRegressor()
    # fit the model to the train data
    m.fit(x, y, eval_set=[(eval_x, eval_y)])

    working_dir = flytekit.current_context().working_directory
    fname = os.path.join(working_dir, f"model-{loc}.joblib.dat")
    joblib.dump(m, fname)

    # return the serialized model
    return JoblibSerializedFile(path=fname)
Example #19
def rotate(image_location: JPEGImageFile, location: str) -> JPEGImageFile:
    """
    Download the given image, rotate it by 180 degrees
    """
    working_dir = flytekit.current_context().working_directory
    image_location.download()
    img = cv2.imread(image_location.path, 0)
    if img is None:
        raise Exception("Failed to read image")
    (h, w) = img.shape[:2]
    center = (w / 2, h / 2)
    mat = cv2.getRotationMatrix2D(center, 180, 1)
    res = cv2.warpAffine(img, mat, (w, h))
    out_path = os.path.join(
        working_dir,
        f"rotated-{os.path.basename(image_location.path).rsplit('.')[0]}.jpg",
    )
    cv2.imwrite(out_path, res)
    if location:
        return JPEGImageFile(path=out_path, remote_path=location)
    else:
        return JPEGImageFile(path=out_path)
Example #20
def use_checkpoint(n_iterations: int) -> int:
    cp = current_context().checkpoint
    prev = cp.read()
    start = 0
    if prev:
        start = int(prev.decode())

    # create a failure interval so that the task fails every few iterations and then succeeds within the
    # configured retries
    failure_interval = n_iterations * 1.0 / RETRIES
    i = 0
    for i in range(start, n_iterations):
        # simulate a deterministic failure, for demonstration. We want to show how it eventually completes within
        # the given retries
        if i > start and i % failure_interval == 0:
            raise FlyteRecoverableException(
                f"Failed at iteration {start}, failure_interval {failure_interval}"
            )
        # save progress state. It is also entirely possible to save state only every few iterations.
        cp.write(f"{i + 1}".encode())

    return i
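For the recoverable failures above to actually resume from the checkpoint, the task needs a retry budget. A minimal sketch of how such a task is usually declared, assuming the RETRIES constant referenced above (the value here is made up):

from flytekit import task, workflow

RETRIES = 3  # assumed value; the original constant is defined elsewhere


@task(retries=RETRIES)
def use_checkpoint(n_iterations: int) -> int:
    # body as in the example above
    ...


@workflow
def checkpoint_wf(n_iterations: int = 50) -> int:
    return use_checkpoint(n_iterations=n_iterations)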
Example #21
def normalize_columns(
    csv_url: FlyteFile,
    column_names: List[str],
    columns_to_normalize: List[str],
    output_location: str,
) -> FlyteFile:
    # read the data from the raw csv file
    parsed_data = defaultdict(list)
    with open(csv_url, newline="\n") as input_file:
        reader = csv.DictReader(input_file, fieldnames=column_names)
        for row in (x for i, x in enumerate(reader) if i > 0):
            for column in columns_to_normalize:
                parsed_data[column].append(float(row[column].strip()))

    # normalize the data
    normalized_data = defaultdict(list)
    for colname, values in parsed_data.items():
        mean = sum(values) / len(values)
        std = (sum([(x - mean)**2 for x in values]) / len(values))**0.5
        normalized_data[colname] = [(x - mean) / std for x in values]

    # write to local path
    out_path = os.path.join(
        flytekit.current_context().working_directory,
        f"normalized-{os.path.basename(csv_url.path).rsplit('.')[0]}.csv",
    )
    with open(out_path, mode="w") as output_file:
        writer = csv.DictWriter(output_file, fieldnames=columns_to_normalize)
        writer.writeheader()
        for row in zip(*normalized_data.values()):
            writer.writerow(
                {k: row[i]
                 for i, k in enumerate(columns_to_normalize)})

    if output_location:
        return FlyteFile(path=out_path, remote_path=output_location)
    else:
        return FlyteFile(path=out_path)
Example #22
def test_input_substitution_files_ctx():
    sec = flytekit.current_context().secrets
    envvar = sec.get_secrets_env_var("group", "key")
    os.environ[envvar] = "value"
    assert sec.get("group", "key") == "value"

    t = ShellTask(
        name="test",
        script="""
            export EXEC={ctx.execution_id}
            export SECRET={ctx.secrets.group.key}
            cat {inputs.f}
            echo "Hello World {inputs.y} on  {inputs.j}"
            """,
        inputs=kwtypes(f=CSVFile, y=FlyteDirectory, j=datetime.datetime),
        debug=True,
    )

    if os.name == "nt":
        t._script = t._script.replace("cat", "type").replace("export", "set")

    assert t(f=test_csv, y=testdata, j=datetime.datetime(2021, 11, 10, 12, 15, 0)) is None
    del os.environ[envvar]
Example #23
File: task.py Project: flyteorg/flytekit
    def execute_from_model(self, tt: task_models.TaskTemplate,
                           **kwargs) -> typing.Any:
        if tt.custom["secret_connect_args"] is not None:
            for key, secret_dict in tt.custom["secret_connect_args"].items():
                value = current_context().secrets.get(
                    group=secret_dict["group"], key=secret_dict["key"])
                tt.custom["connect_args"][key] = value

        engine = create_engine(tt.custom["uri"],
                               connect_args=tt.custom["connect_args"],
                               echo=False)
        print(f"Connecting to db {tt.custom['uri']}")

        interpolated_query = SQLAlchemyTask.interpolate_query(
            tt.custom["query_template"], **kwargs)
        print(f"Interpolated query {interpolated_query}")
        with engine.begin() as connection:
            df = None
            if tt.interface.outputs:
                df = pd.read_sql_query(interpolated_query, connection)
            else:
                pandasSQL_builder(connection).execute(interpolated_query)
        return df
Example #24
def horovod_train_task(batch_size: int, buffer_size: int,
                       dataset_size: int) -> FlyteDirectory:
    """
    :param batch_size: Represents the number of consecutive elements of this dataset to combine in a single batch.
    :param buffer_size: Defines the size of the buffer used to hold elements of the dataset used for training.
    :param dataset_size: The number of elements of this dataset that should be taken to form the new dataset when
        running batched training.
    """
    hvd.init()

    (mnist_images,
     mnist_labels), _ = tf.keras.datasets.mnist.load_data(path="mnist-%d.npz" %
                                                          hvd.rank())

    dataset = tf.data.Dataset.from_tensor_slices((
        tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
        tf.cast(mnist_labels, tf.int64),
    ))
    dataset = dataset.repeat().shuffle(buffer_size).batch(batch_size)

    mnist_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, [3, 3], activation="relu"),
        tf.keras.layers.Conv2D(64, [3, 3], activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    loss = tf.losses.SparseCategoricalCrossentropy()

    # Horovod: adjust learning rate based on number of GPUs.
    opt = tf.optimizers.Adam(0.001 * hvd.size())

    checkpoint_dir = ".checkpoint"
    pathlib.Path(checkpoint_dir).mkdir(exist_ok=True)

    checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)

    # Horovod: adjust number of steps based on number of GPUs.
    for batch, (images,
                labels) in enumerate(dataset.take(dataset_size // hvd.size())):
        loss_value = training_step(images, labels, batch == 0, mnist_model,
                                   loss, opt)

        if batch % 10 == 0 and hvd.local_rank() == 0:
            print("Step #%d\tLoss: %.6f" % (batch, loss_value))

    if hvd.rank() != 0:
        raise IgnoreOutputs("I am not rank 0")

    working_dir = flytekit.current_context().working_directory
    checkpoint_prefix = pathlib.Path(os.path.join(working_dir, "checkpoint"))
    checkpoint.save(checkpoint_prefix)

    tf.keras.models.save_model(
        mnist_model,
        str(working_dir),
        overwrite=True,
        include_optimizer=True,
        save_format=None,
        signatures=None,
        options=None,
        save_traces=True,
    )
    return FlyteDirectory(path=str(working_dir))
Example #25
def secret_task() -> str:
    secret_val = flytekit.current_context().secrets.get(
        SECRET_GROUP, SECRET_NAME)
    # Please do not print the secret value, we are doing so just as a demonstration
    print(secret_val)
    return secret_val
Example #26
        y=alt.Y('new_cases_smoothed_per_million:Q', stack=None),
        color=alt.Color('continent:N', scale=alt.Scale(scheme='set1')),
        tooltip='continent:N').interactive().properties(width='container')

    dp.Report(dp.Plot(plot), dp.DataTable(df)).save(path='report.html', open=True)


@task
def transform_data(url: str) -> pandas.DataFrame:
    dataset = pd.read_csv(url)
    df = dataset.groupby(
        ['continent',
         'date'])['new_cases_smoothed_per_million'].mean().reset_index()
    return df


@workflow
def datapane_workflow(url: str):
    df = transform_data(url=url)
    publish_report(df=df)
    print(f"Report is published for {url}")


default_lp = LaunchPlan.get_default_launch_plan(
    current_context(),
    datapane_workflow)

if __name__ == "__main__":
    print(default_lp(url="https://covid.ourworldindata.org/data/owid-covid-data.csv"))

Example #27
def t1(n: int) -> int:
    ctx = flytekit.current_context()
    cp = ctx.checkpoint
    cp.write(bytes(n + 1))
    return n + 1
Example #28
def data_preparation(
    data_dir: FlyteDirectory, hp: Hyperparameters
) -> Tuple[float, Dict[str, List[Any]], pyspark.sql.DataFrame,
           pyspark.sql.DataFrame]:
    print("================")
    print("Data preparation")
    print("================")

    # 'current_context' gives a handle to the parameters of this specific ``data_preparation`` task execution
    spark = flytekit.current_context().spark_session
    data_dir_path = data_dir.remote_source
    # read the CSV data into Spark DataFrame
    train_csv = spark.read.csv("%s/train.csv" % data_dir_path, header=True)
    test_csv = spark.read.csv("%s/test.csv" % data_dir_path, header=True)

    store_csv = spark.read.csv("%s/store.csv" % data_dir_path, header=True)
    store_states_csv = spark.read.csv("%s/store_states.csv" % data_dir_path,
                                      header=True)
    state_names_csv = spark.read.csv("%s/state_names.csv" % data_dir_path,
                                     header=True)
    google_trend_csv = spark.read.csv("%s/googletrend.csv" % data_dir_path,
                                      header=True)
    weather_csv = spark.read.csv("%s/weather.csv" % data_dir_path, header=True)

    # retrieve a sampled subset of the train and test data
    if hp.sample_rate:
        train_csv = train_csv.sample(withReplacement=False,
                                     fraction=hp.sample_rate)
        test_csv = test_csv.sample(withReplacement=False,
                                   fraction=hp.sample_rate)

    # prepare the DataFrames from the CSV files
    train_df = prepare_df(
        train_csv,
        store_csv,
        store_states_csv,
        state_names_csv,
        google_trend_csv,
        weather_csv,
    ).cache()
    test_df = prepare_df(
        test_csv,
        store_csv,
        store_states_csv,
        state_names_csv,
        google_trend_csv,
        weather_csv,
    ).cache()

    # add elapsed times from the data spanning training & test datasets
    elapsed_cols = ["Promo", "StateHoliday", "SchoolHoliday"]
    elapsed = add_elapsed(
        train_df.select("Date", "Store", *elapsed_cols).unionAll(
            test_df.select("Date", "Store", *elapsed_cols)),
        elapsed_cols,
    )

    # join with the elapsed times
    train_df = train_df.join(elapsed, ["Date", "Store"]).select(
        train_df["*"],
        *[
            prefix + col for prefix in ["Before", "After"]
            for col in elapsed_cols
        ],
    )
    test_df = test_df.join(elapsed, ["Date", "Store"]).select(
        test_df["*"],
        *[
            prefix + col for prefix in ["Before", "After"]
            for col in elapsed_cols
        ],
    )

    # filter out zero sales
    train_df = train_df.filter(train_df.Sales > 0)

    print("===================")
    print("Prepared data frame")
    print("===================")
    train_df.show()

    all_cols = CATEGORICAL_COLS + CONTINUOUS_COLS

    # select features
    train_df = train_df.select(*(all_cols + ["Sales", "Date"])).cache()
    test_df = test_df.select(*(all_cols + ["Id", "Date"])).cache()

    # build a vocabulary of categorical columns
    vocab = build_vocabulary(
        train_df.select(*CATEGORICAL_COLS).unionAll(
            test_df.select(*CATEGORICAL_COLS)).cache(), )

    # cast continuous columns to float
    train_df = cast_columns(train_df, CONTINUOUS_COLS + ["Sales"])
    # look up each categorical value in the vocabulary built above
    train_df = lookup_columns(train_df, vocab)
    test_df = cast_columns(test_df, CONTINUOUS_COLS)
    test_df = lookup_columns(test_df, vocab)

    # split into training & validation
    # test set is in 2015, use the same period in 2014 from the training set as a validation set
    test_min_date = test_df.agg(F.min(test_df.Date)).collect()[0][0]
    test_max_date = test_df.agg(F.max(test_df.Date)).collect()[0][0]
    one_year = datetime.timedelta(365)
    train_df = train_df.withColumn(
        "Validation",
        (train_df.Date > test_min_date - one_year)
        & (train_df.Date <= test_max_date - one_year),
    )

    # determine max Sales number
    max_sales = train_df.agg(F.max(train_df.Sales)).collect()[0][0]

    # convert Sales to log domain
    train_df = train_df.withColumn("Sales", F.log(train_df.Sales))

    print("===================================")
    print("Data frame with transformed columns")
    print("===================================")
    train_df.show()

    print("================")
    print("Data frame sizes")
    print("================")

    # split on the Validation column and count the training and validation rows
    train_rows = train_df.filter(~train_df.Validation).count()
    val_rows = train_df.filter(train_df.Validation).count()
    test_rows = test_df.count()

    # print the number of rows in training, validation and test data
    print("Training: %d" % train_rows)
    print("Validation: %d" % val_rows)
    print("Test: %d" % test_rows)

    return max_sales, vocab, train_df, test_df
Example #29
 def my_spark(df: pyspark.sql.DataFrame) -> my_schema:
     session = flytekit.current_context().spark_session
     new_df = session.createDataFrame([("Bob", 10)], my_schema.column_names())
     return df.union(new_df)
Example #30
 def my_spark(a: int) -> my_schema:
     session = flytekit.current_context().spark_session
     df = session.createDataFrame([("Alice", a)], my_schema.column_names())
     print(type(df))
     return df