def upload_checkpoint(return_page: str):
    """
    Handles the upload of a single checkpoint .pt file.

    The uploaded file is written to a temporary file so its training arguments can be read,
    recorded through db.insert_ckpt / db.insert_model, and then copied into the
    configured checkpoint folder under the new model id.

    :param return_page: The name of the page to redirect to after the upload has been handled.
    :return: A redirect to return_page carrying the JSON-encoded warnings and errors.
    """
    warning_messages = []
    error_messages = []

    # Fall back to the configured default user when the client's cookie is not set.
    user_id = request.cookies.get('currentUser') or app.config['DEFAULT_USER_ID']

    uploaded_file = request.files['checkpoint']
    requested_name = request.form['checkpointName']

    # A temporary file lets us inspect the checkpoint's args before deciding where it is stored.
    with NamedTemporaryFile() as scratch:
        scratch_path = scratch.name
        uploaded_file.save(scratch_path)

        saved_args = load_args(scratch)

        # A single .pt upload always counts as an ensemble of size 1.
        ckpt_id, stored_name = db.insert_ckpt(
            requested_name,
            user_id,
            saved_args.dataset_type,
            saved_args.epochs,
            1,
            saved_args.train_data_size,
        )
        model_id = db.insert_model(ckpt_id)
        destination = os.path.join(app.config['CHECKPOINT_FOLDER'], f'{model_id}.pt')

        # insert_ckpt may hand back a different name than the one requested; tell the user.
        renamed = stored_name != requested_name
        if renamed:
            warning_messages.append(
                name_already_exists_message('Checkpoint', requested_name, stored_name)
            )

        shutil.copy(scratch_path, destination)

    target_url = url_for(
        return_page,
        checkpoint_upload_warnings=json.dumps(warning_messages),
        checkpoint_upload_errors=json.dumps(error_messages),
    )
    return redirect(target_url)
def train():
    """
    Renders the train page and performs training if request method is POST.

    On GET the empty train page is rendered. On POST the form fields are turned into
    training args, the selected dataset type is checked against the labels in the data,
    a checkpoint entry is created via db.insert_ckpt, training is run in a temporary
    directory, and every resulting .pt file is moved into the checkpoint folder.

    The module-level TRAINING flag is set to 1 while training runs and, together with
    PROGRESS, is always reset afterwards, even when training raises.

    :return: The rendered train page, including scores on success or errors on invalid input.
    """
    global PROGRESS, TRAINING

    warnings, errors = [], []

    if request.method == 'GET':
        return render_train()

    # Get arguments
    data_name = request.form['dataName']
    epochs = int(request.form['epochs'])
    ensemble_size = int(request.form['ensembleSize'])
    checkpoint_name = request.form['checkpointName']
    gpu = request.form.get('gpu')
    data_path = os.path.join(app.config['DATA_FOLDER'], f'{data_name}.csv')
    dataset_type = request.form.get('datasetType', 'regression')

    # Create and modify args
    parser = ArgumentParser()
    add_train_args(parser)
    args = parser.parse_args([])

    args.data_path = data_path
    args.dataset_type = dataset_type
    args.epochs = epochs
    args.ensemble_size = ensemble_size

    # Check if regression/classification selection matches data
    data = get_data(path=data_path)
    targets = data.targets()
    unique_targets = {target for row in targets for target in row if target is not None}

    if dataset_type == 'classification' and unique_targets - {0, 1}:
        errors.append('Selected classification dataset but not all labels are 0 or 1. Select regression instead.')
        return render_train(warnings=warnings, errors=errors)

    if dataset_type == 'regression' and unique_targets <= {0, 1}:
        errors.append('Selected regression dataset but all labels are 0 or 1. Select classification instead.')
        return render_train(warnings=warnings, errors=errors)

    if gpu is not None:
        if gpu == 'None':
            args.no_cuda = True
        else:
            args.gpu = int(gpu)

    current_user = request.cookies.get('currentUser')
    if not current_user:
        # Use DEFAULT as current user if the client's cookie is not set.
        current_user = app.config['DEFAULT_USER_ID']

    ckpt_id, ckpt_name = db.insert_ckpt(checkpoint_name,
                                        current_user,
                                        args.dataset_type,
                                        args.epochs,
                                        args.ensemble_size,
                                        len(targets))

    with TemporaryDirectory() as temp_dir:
        args.save_dir = temp_dir
        modify_train_args(args)

        process = mp.Process(target=progress_bar, args=(args, PROGRESS))
        process.start()
        TRAINING = 1

        try:
            # Run training
            logger = create_logger(name='train', save_dir=args.save_dir, quiet=args.quiet)
            task_scores = run_training(args, logger)
        except BaseException:
            # If training died, the progress watcher cannot be assumed to finish on its own;
            # stop it so the join() below does not block forever.
            process.terminate()
            raise
        finally:
            process.join()

            # Reset globals, also on failure, so the app is not stuck in the "training" state.
            TRAINING = 0
            PROGRESS = mp.Value('d', 0.0)

        # Check if name overlap
        if checkpoint_name != ckpt_name:
            warnings.append(name_already_exists_message('Checkpoint', checkpoint_name, ckpt_name))

        # Move models
        for root, _, files in os.walk(args.save_dir):
            for fname in files:
                if fname.endswith('.pt'):
                    model_id = db.insert_model(ckpt_id)
                    save_path = os.path.join(app.config['CHECKPOINT_FOLDER'], f'{model_id}.pt')
                    # root already starts with args.save_dir, so it must not be prefixed again.
                    shutil.move(os.path.join(root, fname), save_path)

    return render_train(trained=True,
                        metric=args.metric,
                        num_tasks=len(args.task_names),
                        task_names=args.task_names,
                        task_scores=format_float_list(task_scores),
                        mean_score=format_float(np.mean(task_scores)),
                        warnings=warnings,
                        errors=errors)
def upload_checkpoint(return_page: str):
    """
    Uploads a checkpoint .pt file or a .zip file containing one or more checkpoint .pt files.

    All uploaded .pt files are recorded as models of a single checkpoint entry
    (via db.insert_ckpt / db.insert_model) and copied into the configured checkpoint
    folder. The training args of the first collected .pt file are used for that entry.

    :param return_page: The name of the page to render after uploading the checkpoint file.
    :return: A redirect to return_page carrying the JSON-encoded warnings and errors.
    """
    warnings, errors = [], []

    current_user = request.cookies.get('currentUser')
    if not current_user:
        # Use DEFAULT as current user if the client's cookie is not set.
        current_user = app.config['DEFAULT_USER_ID']

    ckpt = request.files['checkpoint']
    ckpt_name = request.form['checkpointName']
    # Guard against a missing filename so it is reported as a bad extension below
    # instead of raising a TypeError in splitext.
    ckpt_ext = os.path.splitext(ckpt.filename or '')[1]

    # The context manager removes the temporary directory even if something below raises.
    with TemporaryDirectory() as temp_dir:
        # Collect paths to all uploaded checkpoints (and unzip if necessary)
        ckpt_paths = []

        if ckpt_ext.endswith('.pt'):
            ckpt_path = os.path.join(temp_dir, 'model.pt')
            ckpt.save(ckpt_path)
            ckpt_paths = [ckpt_path]

        elif ckpt_ext.endswith('.zip'):
            ckpt_dir = os.path.join(temp_dir, 'models')
            zip_path = os.path.join(temp_dir, 'models.zip')
            ckpt.save(zip_path)

            # NOTE(review): the archive is user-supplied. zipfile is documented to strip absolute
            # paths and ".." components, but the extracted size is not bounded here.
            with zipfile.ZipFile(zip_path, mode='r') as z:
                z.extractall(ckpt_dir)

            for root, _, fnames in os.walk(ckpt_dir):
                ckpt_paths += [os.path.join(root, fname) for fname in fnames if fname.endswith('.pt')]

            if not ckpt_paths:
                # Without this the upload would be dropped with no feedback to the user.
                errors.append('Uploaded .zip file does not contain any .pt checkpoint files')

        else:
            errors.append(f'Uploaded checkpoint(s) file must be either .pt or .zip but got {ckpt_ext}')

        # Insert checkpoints into database
        if ckpt_paths:
            ckpt_args = load_args(ckpt_paths[0])
            ckpt_id, new_ckpt_name = db.insert_ckpt(ckpt_name,
                                                    current_user,
                                                    ckpt_args.dataset_type,
                                                    ckpt_args.epochs,
                                                    len(ckpt_paths),
                                                    ckpt_args.train_data_size)

            # The rename concerns the checkpoint as a whole, so warn once rather than once per model.
            if ckpt_name != new_ckpt_name:
                warnings.append(name_already_exists_message('Checkpoint', ckpt_name, new_ckpt_name))

            for ckpt_path in ckpt_paths:
                model_id = db.insert_model(ckpt_id)
                model_path = os.path.join(app.config['CHECKPOINT_FOLDER'], f'{model_id}.pt')
                shutil.copy(ckpt_path, model_path)

    warnings, errors = json.dumps(warnings), json.dumps(errors)

    return redirect(url_for(return_page, checkpoint_upload_warnings=warnings, checkpoint_upload_errors=errors))