import os

import torch
from torch import optim

# get_args, Net, mnist_data_loader, TrainingState, train, test and the
# save_* helpers are defined elsewhere in the surrounding project.


def main() -> None:
    args = get_args()

    os.makedirs(args.data_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)

    n_epochs = args.n_epochs
    batch_size_train = args.batch_size_train
    batch_size_test = args.batch_size_test
    learning_rate = args.learning_rate
    momentum = args.momentum
    log_interval = args.log_interval

    random_seed = 1  # fixed seed for the demo; use a random seed in real runs
    torch.backends.cudnn.enabled = False  # disable cuDNN for reproducibility
    torch.manual_seed(random_seed)

    # Create loaders for the training and test data

    train_loader = mnist_data_loader(args.data_dir, batch_size_train)
    test_loader = mnist_data_loader(args.data_dir,
                                    batch_size_test,
                                    is_training=False)

    example_data = save_example_training_data(train_loader)

    # Train our model and test every epoch

    network = Net()
    optimizer = optim.SGD(network.parameters(),
                          lr=learning_rate,
                          momentum=momentum)

    train_losses = []
    train_counter = []
    test_losses = []
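    # One test-loss x-coordinate per epoch boundary, in units of
    # training examples seen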
    test_counter = [i * len(train_loader.dataset) for i in range(n_epochs + 1)]

    training_state = TrainingState(train_losses, train_counter, args.out_dir)
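    # Evaluate once before training to record the untrained baseline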
    test(test_loader, network, test_losses)
    for epoch in range(1, n_epochs + 1):
        train(train_loader, network, optimizer, epoch, log_interval,
              training_state)
        test(test_loader, network, test_losses)

    save_loss_data_file(args.out_dir, train_counter, train_losses,
                        test_counter, test_losses, n_epochs)

    save_example_prediction_data(args.out_dir, network, example_data)
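
The example assumes a project-local mnist_data_loader helper. As a rough sketch, such a loader could be built on torchvision's bundled MNIST dataset (the project's real helper may differ):

import torch
import torchvision

def mnist_data_loader(data_dir, batch_size, is_training=True):
    # Normalize with the MNIST training set's global mean and std
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(data_dir, train=is_training,
                                         download=True, transform=transform)
    # Shuffle only during training so evaluation batches stay deterministic
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=is_training)
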
Example #2
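    # Relies on project-local names: SocketMessageResponse, log(), get_args(),
    # self.ulabapi and self.actualState, plus typing.Dict, os and asyncio.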
    async def print(self, data: Dict[str, str]) -> SocketMessageResponse:
        log().info("printing...")
        if 'file' not in data:
            return SocketMessageResponse(1, "file not specified")

        if self.actualState['download']['file'] is not None:
            return SocketMessageResponse(
                1, "file " + self.actualState['download']['file'] +
                " has already been scheduled to download and print")

        if self.actualState["status"]["state"]["text"] != 'Operational':
            return SocketMessageResponse(
                1, "pandora is not in an operational state")

        upload_path = get_args().octoprint_upload_path
        if not os.path.isdir(upload_path):
            os.mkdir(upload_path)
        filename = (data['file'] if data['file'].endswith('.gcode')
                    else data['file'] + '.gcode')
        gcode = os.path.join(upload_path, filename)

        if not os.path.isfile(gcode):
            log().info("file " + gcode + " not found, downloading it...")

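            # Download in the background so the handler can reply immediately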
            async def download_and_print():
                self.actualState["download"]["file"] = data['file']
                self.actualState["download"]["completion"] = 0.0
                r = await self.ulabapi.download(data['file'])

                if r.status != 200:
                    log().warning("error downloading file " + data['file'] +
                                  ": got HTTP status " + str(r.status))
                    self.actualState["download"]["file"] = None
                    self.actualState["download"]["completion"] = -1
                    return  # nothing was downloaded, so nothing to print

                await self._download_file(r, gcode)
                await self._print_file(gcode)

            # todo: get the running loop from somewhere cleaner
            asyncio.get_running_loop().create_task(download_and_print())
            return SocketMessageResponse(
                0, "file was not on ucloud, downloading it and printing it...")

        await self._print_file(gcode)
        return SocketMessageResponse(0, "ok")
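
The handler leaves the actual disk write to self._download_file. A minimal sketch, assuming ulabapi.download() returns an aiohttp-style response and that the aiofiles package is available (both assumptions, not confirmed by this example):

import aiofiles

async def _download_file(self, response, path):
    # Stream the body to disk in chunks so large gcode files
    # never need to fit in memory
    async with aiofiles.open(path, 'wb') as f:
        async for chunk in response.content.iter_chunked(64 * 1024):
            await f.write(chunk)
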
Example #3
from lib.args import get_args
from lib.io import load, save
from lib import objectives

if __name__ == '__main__':
    # Get commandline args from args.py
    args = get_args()

    if args.dataset == 'reuters8':
        features, labels = load.reuters8()
    elif args.dataset == 'classic4':
        features, labels = load.classic4()
    elif args.dataset == 'ng20':
        features, labels = load.ng20()
    elif args.dataset == 'webkb':
        features, labels = load.webkb()
    else:
        raise Exception('Unknown dataset')

    if args.save_dense_matrix:
        save.dense_matrix(features, labels, args.dataset)
    if args.save_sparse_matrix:
        save.sparse_matrix(features, labels, args.dataset)

    if args.objective == 'I1':
        objective_value = objectives.I1(features, labels)
    elif args.objective == 'I2':
        objective_value = objectives.I2(features, labels)
    elif args.objective == 'E1':
        objective_value = objectives.E1(features, labels)
    elif args.objective == 'H1':
        objective_value = objectives.H1(features, labels)
    else:
        raise Exception('Unknown objective')

Example #4
import sys
from pyspark.sql.functions import col as sql_col, lit
from pyspark.sql.types import TimestampType, BooleanType, StringType

from lib.args import get_args
from lib.constants import CHANGES_METADATA_OPERATION, CHANGES_METADATA_TIMESTAMP
from lib.metadata import get_batch_metadata, get_metadata_file_list
from lib.spark import get_spark
from lib.table import process_special_fields, get_delta_table

cmd_args = get_args()
spark = get_spark()

# List "change" files
print(f">>> Searching for batch metadata files in: {cmd_args.changes_path}...")
dfm_files = get_metadata_file_list(cmd_args.changes_path)
if not dfm_files:
    print(">>> Nothing to do, exiting...")
    sys.exit(0)

# Get batch metadata and validate columns
print(f">>> Found {len(dfm_files)} batch metadata files, loading metadata...")
batch = get_batch_metadata(dfm_files=dfm_files,
                           src_path_override=cmd_args.changes_path)
print(
    f">>> Metadata loaded, num_files={len(batch.files)}, records={batch.record_count}"
)
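# Fail fast if the metadata references no data files at all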
if not batch.files:
    raise Exception("Did not find any files to load.")

if len(batch.primary_key_columns) > 1: