Example 1
def upload_checkpoint(return_page: str):
    """
    Uploads a checkpoint .pt file.

    :param return_page: The name of the page to render after uploading the checkpoint file.
    """
    warnings, errors = [], []

    current_user = request.cookies.get('currentUser')

    if not current_user:
        # Fall back to the default user ID if the client's cookie is not set.
        current_user = app.config['DEFAULT_USER_ID']

    ckpt = request.files['checkpoint']

    ckpt_name = request.form['checkpointName']
    ckpt_ext = os.path.splitext(ckpt.filename)[1]

    # Collect paths to all uploaded checkpoints (and unzip if necessary)
    temp_dir = TemporaryDirectory()
    ckpt_paths = []

    if ckpt_ext.endswith('.pt'):
        ckpt_path = os.path.join(temp_dir.name, 'model.pt')
        ckpt.save(ckpt_path)
        ckpt_paths = [ckpt_path]

    elif ckpt_ext.endswith('.zip'):
        ckpt_dir = os.path.join(temp_dir.name, 'models')
        zip_path = os.path.join(temp_dir.name, 'models.zip')
        ckpt.save(zip_path)

        with zipfile.ZipFile(zip_path, mode='r') as z:
            z.extractall(ckpt_dir)

        for root, _, fnames in os.walk(ckpt_dir):
            ckpt_paths += [
                os.path.join(root, fname) for fname in fnames
                if fname.endswith('.pt')
            ]

    else:
        errors.append(
            f'Uploaded checkpoint file(s) must be either .pt or .zip, but got {ckpt_ext}'
        )

    # Insert checkpoints into database
    if len(ckpt_paths) > 0:
        ckpt_args = load_args(ckpt_paths[0])
        ckpt_id, new_ckpt_name = db.insert_ckpt(ckpt_name, current_user,
                                                ckpt_args.dataset_type,
                                                ckpt_args.epochs,
                                                len(ckpt_paths),
                                                ckpt_args.train_data_size)

        # Warn once (not once per model) if the checkpoint name was changed
        # to avoid a collision with an existing name.
        if ckpt_name != new_ckpt_name:
            warnings.append(
                name_already_exists_message('Checkpoint', ckpt_name,
                                            new_ckpt_name))

        for ckpt_path in ckpt_paths:
            model_id = db.insert_model(ckpt_id)
            model_path = os.path.join(app.config['CHECKPOINT_FOLDER'],
                                      f'{model_id}.pt')
            shutil.copy(ckpt_path, model_path)

    temp_dir.cleanup()

    # JSON-encode the lists so they can be passed as query parameters
    # on the redirect below.
    warnings, errors = json.dumps(warnings), json.dumps(errors)

    return redirect(
        url_for(return_page,
                checkpoint_upload_warnings=warnings,
                checkpoint_upload_errors=errors))
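A minimal sketch, not part of the original source, of how this view might be wired into the Flask app above and how the redirect target could decode the JSON-encoded warnings and errors. It reuses the `app` object and `upload_checkpoint` from the example; the route path, the `home` endpoint, and the template name are hypothetical.

import json

from flask import render_template, request

# Hypothetical registration: the upload form POSTs here, and return_page
# names the endpoint that upload_checkpoint redirects back to afterwards.
app.add_url_rule('/checkpoints/upload/<string:return_page>',
                 view_func=upload_checkpoint, methods=['POST'])


@app.route('/')
def home():
    # The redirect passes the JSON-encoded lists as query parameters,
    # so the receiving view decodes them before rendering the page.
    upload_warnings = json.loads(request.args.get('checkpoint_upload_warnings', '[]'))
    upload_errors = json.loads(request.args.get('checkpoint_upload_errors', '[]'))
    return render_template('home.html', warnings=upload_warnings, errors=upload_errors)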
Example 2
def train():
    """Renders the train page and performs training if request method is POST."""
    global PROGRESS, TRAINING
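    # PROGRESS is a shared multiprocessing.Value and TRAINING a 0/1 flag;
    # both are module-level globals, presumably polled by a separate
    # progress-reporting view.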

    warnings, errors = [], []

    if request.method == 'GET':
        return render_train()

    # Get arguments
    data_name = request.form['dataName']
    epochs = int(request.form['epochs'])
    ensemble_size = int(request.form['ensembleSize'])
    checkpoint_name = request.form['checkpointName']
    gpu = request.form.get('gpu')
    data_path = os.path.join(app.config['DATA_FOLDER'], f'{data_name}.csv')
    dataset_type = request.form.get('datasetType', 'regression')

    # Create and modify args
    args = TrainArgs().parse_args([
        '--data_path', data_path, '--dataset_type', dataset_type, '--epochs',
        str(epochs), '--ensemble_size',
        str(ensemble_size)
    ])

    # Check if regression/classification selection matches data
    data = get_data(path=data_path)
    targets = data.targets()
    unique_targets = {
        target
        for row in targets for target in row if target is not None
    }

    if dataset_type == 'classification' and len(unique_targets - {0, 1}) > 0:
        errors.append(
            'Selected classification dataset but not all labels are 0 or 1. Select regression instead.'
        )

        return render_train(warnings=warnings, errors=errors)

    if dataset_type == 'regression' and unique_targets <= {0, 1}:
        errors.append(
            'Selected regression dataset but all labels are 0 or 1. Select classification instead.'
        )

        return render_train(warnings=warnings, errors=errors)

    if gpu is not None:
        if gpu == 'None':
            args.cuda = False
        else:
            args.gpu = int(gpu)

    current_user = request.cookies.get('currentUser')

    if not current_user:
        # Fall back to the default user ID if the client's cookie is not set.
        current_user = app.config['DEFAULT_USER_ID']

    ckpt_id, ckpt_name = db.insert_ckpt(checkpoint_name, current_user,
                                        args.dataset_type, args.epochs,
                                        args.ensemble_size, len(targets))

    with TemporaryDirectory() as temp_dir:
        args.save_dir = temp_dir

        # Launch progress_bar in its own process so it can update the shared
        # PROGRESS value while run_training blocks this request handler.
        process = mp.Process(target=progress_bar, args=(args, PROGRESS))
        process.start()
        TRAINING = 1

        # Run training
        logger = create_logger(name='train',
                               save_dir=args.save_dir,
                               quiet=args.quiet)
        task_scores = run_training(args, logger)
        process.join()

        # Reset globals
        TRAINING = 0
        PROGRESS = mp.Value('d', 0.0)

        # Check if name overlap
        if checkpoint_name != ckpt_name:
            warnings.append(
                name_already_exists_message('Checkpoint', checkpoint_name,
                                            ckpt_name))

        # Move models
        for root, _, files in os.walk(args.save_dir):
            for fname in files:
                if fname.endswith('.pt'):
                    model_id = db.insert_model(ckpt_id)
                    save_path = os.path.join(app.config['CHECKPOINT_FOLDER'],
                                             f'{model_id}.pt')
                    # root already includes args.save_dir (it comes from os.walk above)
                    shutil.move(os.path.join(root, fname), save_path)

    return render_train(trained=True,
                        metric=args.metric,
                        num_tasks=len(args.task_names),
                        task_names=args.task_names,
                        task_scores=format_float_list(task_scores),
                        mean_score=format_float(np.mean(task_scores)),
                        warnings=warnings,
                        errors=errors)
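The classification/regression consistency check above reduces to a set comparison over the target values. Below is a self-contained sketch of that logic with a hypothetical helper name, assuming targets arrive as rows of floats or None (as returned by data.targets() in the example).

from typing import List, Optional


def check_dataset_type(targets: List[List[Optional[float]]],
                       dataset_type: str) -> Optional[str]:
    """Return an error message if dataset_type disagrees with the labels, else None.

    Illustrative helper, not part of the original app.
    """
    unique_targets = {t for row in targets for t in row if t is not None}

    if dataset_type == 'classification' and len(unique_targets - {0, 1}) > 0:
        return ('Selected classification dataset but not all labels are 0 or 1. '
                'Select regression instead.')

    if dataset_type == 'regression' and unique_targets <= {0, 1}:
        return ('Selected regression dataset but all labels are 0 or 1. '
                'Select classification instead.')

    return None


# A dataset whose labels are all 0/1 is flagged when regression is selected
# and accepted when classification is selected.
assert check_dataset_type([[0.0], [1.0], [None]], 'regression') is not None
assert check_dataset_type([[0.0], [1.0], [None]], 'classification') is None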