Example #1
def run_training(args, save_dir):
    tgt_data, val_data, test_data, src_data = prepare_data(args)
    inv_model = prepare_model(args)

    print('invariant', inv_model)

    optimizer = build_optimizer(inv_model, args)
    scheduler = build_lr_scheduler(optimizer, args)
    inv_opt = (optimizer, scheduler)

    loss_func = get_loss_func(args)
    metric_func = get_metric_func(metric=args.metric)

    best_score = float('inf') if args.minimize_score else -float('inf')
    best_epoch = 0
    for epoch in range(args.epochs):
        print(f'Epoch {epoch}')
        train(inv_model, src_data, tgt_data, loss_func, inv_opt, args)

        val_scores = evaluate(inv_model, val_data, args.num_tasks, metric_func,
                              args.batch_size, args.dataset_type)
        avg_val_score = np.nanmean(val_scores)
        print(f'Validation {args.metric} = {avg_val_score:.4f}')
        if (args.minimize_score and avg_val_score < best_score) or \
                (not args.minimize_score and avg_val_score > best_score):
            best_score, best_epoch = avg_val_score, epoch
            save_checkpoint(os.path.join(save_dir, 'model.pt'),
                            inv_model,
                            args=args)

    print(f'Loading model checkpoint from epoch {best_epoch}')
    model = load_checkpoint(os.path.join(save_dir, 'model.pt'), cuda=args.cuda)
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    test_preds = predict(model, test_data, args.batch_size)
    test_scores = evaluate_predictions(test_preds, test_targets,
                                       args.num_tasks, metric_func,
                                       args.dataset_type)

    avg_test_score = np.nanmean(test_scores)
    print(f'Test {args.metric} = {avg_test_score:.4f}')
    return avg_test_score
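
The checkpoint-selection logic used above (and repeated in the later examples) reduces to a small pattern: track the best validation score under a lower-is-better or higher-is-better metric and remember the epoch it occurred on. A minimal, self-contained sketch with made-up scores; the names here are illustrative and not part of the repository:

import math

def is_improvement(score: float, best: float, minimize: bool) -> bool:
    # Lower is better for losses/errors; higher is better for AUC-style metrics.
    return score < best if minimize else score > best

minimize_score = True
best_score = math.inf if minimize_score else -math.inf
best_epoch = 0
for epoch, val_score in enumerate([0.52, 0.47, 0.49, 0.45]):  # dummy validation scores
    if is_improvement(val_score, best_score, minimize_score):
        best_score, best_epoch = val_score, epoch
        # save_checkpoint(...) would be called here
print(best_score, best_epoch)  # 0.45 3
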
def run_training_gnn_xgb(args: TrainArgs,
                         logger: Logger = None):
    """
    Trains a model and returns test scores on the model checkpoint with the highest validation score.

    :param args: Arguments.
    :param logger: Logger.
    :return: A tuple of ensemble test scores, GNN-derived train/val/test features, and the train/val/test targets.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Print command line
    debug('Command line')
    debug(f'python {" ".join(sys.argv)}')

    # Print args
    debug('Args')
    debug(args)

    # Save args
    args.save(os.path.join(args.save_dir, 'args.json'))

    # Set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seed)

    # Get data
    debug('Loading data')
    args.task_names = args.target_columns or get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()
    debug(f'Number of tasks = {args.num_tasks}')

    # Split data
    debug(f'Splitting data with seed {args.seed}')
    if args.separate_test_path:
        test_data = get_data(path=args.separate_test_path,
                             args=args,
                             features_path=args.separate_test_features_path,
                             logger=logger)
    if args.separate_val_path:
        val_data = get_data(path=args.separate_val_path,
                            args=args,
                            features_path=args.separate_val_features_path,
                            logger=logger)

    if args.separate_val_path and args.separate_test_path:
        train_data = data
    elif args.separate_val_path:
        train_data, _, test_data = split_data(data=data,
                                              split_type=args.split_type,
                                              sizes=(0.8, 0.0, 0.2),
                                              seed=args.seed,
                                              args=args,
                                              logger=logger)
    elif args.separate_test_path:
        train_data, val_data, _ = split_data(data=data,
                                             split_type=args.split_type,
                                             sizes=(0.8, 0.2, 0.0),
                                             seed=args.seed,
                                             args=args,
                                             logger=logger)
    else:
        train_data, val_data, test_data = split_data(
            data=data,
            split_type=args.split_type,
            sizes=args.split_sizes,
            seed=args.seed,
            args=args,
            logger=logger)

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(
                f'{args.task_names[i]} '
                f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}'
            )

    if args.save_smiles_splits:
        save_smiles_splits(train_data=train_data,
                           val_data=val_data,
                           test_data=test_data,
                           data_path=args.data_path,
                           save_dir=args.save_dir)

    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

    args.train_data_size = len(train_data)

    debug(
        f'Total size = {len(data):,} | '
        f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}'
    )

    # Initialize scaler and scale training targets by subtracting the mean and dividing by the standard deviation (regression only)
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        train_smiles, train_targets = train_data.smiles(), train_data.targets()
        scaler = StandardScaler().fit(train_targets)
        scaled_targets = scaler.transform(train_targets).tolist()
        train_data.set_targets(scaled_targets)
    else:
        scaler = None

    # Get loss and metric functions
    loss_func = get_loss_func(args)
    metric_func = get_metric_func(metric=args.metric)

    # Set up test set evaluation
    val_smiles, val_targets = val_data.smiles(), val_data.targets()
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    if args.dataset_type == 'multiclass':
        sum_test_preds = np.zeros(
            (len(test_smiles), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Automatically determine whether to cache
    if len(data) <= args.cache_cutoff:
        cache = True
        num_workers = 0
    else:
        cache = False
        num_workers = args.num_workers

    # Create data loaders
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           cache=cache,
                                           class_balance=args.class_balance,
                                           shuffle=True,
                                           seed=args.seed)
    val_data_loader = MoleculeDataLoader(dataset=val_data,
                                         batch_size=args.batch_size,
                                         num_workers=num_workers,
                                         cache=cache)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers,
                                          cache=cache)

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        makedirs(save_dir)
        try:
            writer = SummaryWriter(log_dir=save_dir)
        except:
            writer = SummaryWriter(logdir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(
                f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}'
            )
            model = load_checkpoint(args.checkpoint_paths[model_idx],
                                    logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = MoleculeModel(args)

        debug(model)
        debug(f'Number of parameters = {param_count(model):,}')
        if args.cuda:
            debug('Moving model to cuda')
        model = model.to(args.device)

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler,
                        features_scaler, args)

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug(f'Epoch {epoch}')

            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer)
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(model=model,
                                  data_loader=val_data_loader,
                                  num_tasks=args.num_tasks,
                                  metric_func=metric_func,
                                  dataset_type=args.dataset_type,
                                  scaler=scaler,
                                  logger=logger)

            # Average validation score
            avg_val_score = np.nanmean(val_scores)
            debug(f'Validation {args.metric} = {avg_val_score:.6f}')
            writer.add_scalar(f'validation_{args.metric}', avg_val_score,
                              n_iter)

            if args.show_individual_scores:
                # Individual validation scores
                for task_name, val_score in zip(args.task_names, val_scores):
                    debug(
                        f'Validation {task_name} {args.metric} = {val_score:.6f}'
                    )
                    writer.add_scalar(f'validation_{task_name}_{args.metric}',
                                      val_score, n_iter)

            # Save model checkpoint if improved validation score
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score:
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(os.path.join(save_dir, 'model.pt'), model,
                                scaler, features_scaler, args)

        # Evaluate on test set using model with best validation score
        info(
            f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}'
        )
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'),
                                device=args.device,
                                logger=logger)

        test_preds, _ = predict(model=model,
                                data_loader=test_data_loader,
                                scaler=scaler)
        test_scores = evaluate_predictions(preds=test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metric_func=metric_func,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

        if len(test_preds) != 0:
            sum_test_preds += np.array(test_preds)

        # Average test score
        avg_test_score = np.nanmean(test_scores)
        info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')
        writer.add_scalar(f'test_{args.metric}', avg_test_score, 0)

        if args.show_individual_scores:
            # Individual test scores
            for task_name, test_score in zip(args.task_names, test_scores):
                info(
                    f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}'
                )
                writer.add_scalar(f'test_{task_name}_{args.metric}',
                                  test_score, n_iter)
        writer.close()

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(preds=avg_test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metric_func=metric_func,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

    # Average ensemble score
    avg_ensemble_test_score = np.nanmean(ensemble_scores)
    info(f'Ensemble test {args.metric} = {avg_ensemble_test_score:.6f}')

    # Individual ensemble scores
    if args.show_individual_scores:
        for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
            info(
                f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}'
            )

    _, train_feature = predict(model=model,
                               data_loader=train_data_loader,
                               scaler=scaler)
    _, val_feature = predict(model=model,
                             data_loader=val_data_loader,
                             scaler=scaler)
    _, test_feature = predict(model=model,
                              data_loader=test_data_loader,
                              scaler=scaler)

    return ensemble_scores, train_feature, val_feature, test_feature, train_targets, val_targets, test_targets
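
The function's name and its extra return values suggest that the GNN-derived features are intended as input to a downstream XGBoost model. A hedged sketch of that hand-off, assuming the xgboost package and a regression task; the dummy arrays stand in for the features and targets the function would return, and the actual wiring in this codebase may differ:

import numpy as np
from xgboost import XGBRegressor  # use XGBClassifier for classification tasks

rng = np.random.default_rng(0)
train_feature, train_targets = rng.normal(size=(100, 300)), rng.normal(size=(100, 1))
val_feature, val_targets = rng.normal(size=(20, 300)), rng.normal(size=(20, 1))
test_feature = rng.normal(size=(20, 300))

xgb = XGBRegressor(n_estimators=200, learning_rate=0.05)
xgb.fit(train_feature, train_targets.ravel(),
        eval_set=[(val_feature, val_targets.ravel())],
        verbose=False)
test_preds = xgb.predict(test_feature)  # predictions in the original target units
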
Example #3
def run_training(args: Namespace, logger: Logger = None) -> List[float]:
    """
    Trains a model and returns test scores on the model checkpoint with the highest validation score.

    :param args: Arguments.
    :param logger: Logger.
    :return: A list of ensemble scores for each task.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Set GPU
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    # Print args
    # debug(pformat(vars(args)))

    # Get data
    debug('Loading data')
    args.task_names = get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()
    debug(f'Number of tasks = {args.num_tasks}')

    # Split data
    debug(f'Splitting data with seed {args.seed}')
    if args.separate_test_path:
        test_data = get_data(path=args.separate_test_path,
                             args=args,
                             features_path=args.separate_test_features_path,
                             logger=logger)
    if args.separate_val_path:
        val_data = get_data(path=args.separate_val_path,
                            args=args,
                            features_path=args.separate_val_features_path,
                            logger=logger)

    if args.separate_val_path and args.separate_test_path:
        train_data = data
    elif args.separate_val_path:
        train_data, _, test_data = split_data(data=data,
                                              split_type=args.split_type,
                                              sizes=(0.8, 0.0, 0.2),
                                              seed=args.seed,
                                              args=args,
                                              logger=logger)
    elif args.separate_test_path:
        train_data, val_data, _ = split_data(data=data,
                                             split_type=args.split_type,
                                             sizes=(0.8, 0.2, 0.0),
                                             seed=args.seed,
                                             args=args,
                                             logger=logger)
    else:
        print('=' * 100)
        train_data, val_data, test_data = split_data(
            data=data,
            split_type=args.split_type,
            sizes=args.split_sizes,
            seed=args.seed,
            args=args,
            logger=logger)

        ###my_code###
        train_df = get_data_df(train_data)
        train_df.to_csv(
            '~/PycharmProjects/CMPNN-master/data/24w_train_df_seed0.csv')
        val_df = get_data_df(val_data)
        val_df.to_csv(
            '~/PycharmProjects/CMPNN-master/data/24w_val_df_seed0.csv')
        test_df = get_data_df(test_data)
        test_df.to_csv(
            '~/PycharmProjects/CMPNN-master/data/24w_test_df_seed0.csv')

        ##########

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(
                f'{args.task_names[i]} '
                f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}'
            )

    if args.save_smiles_splits:
        with open(args.data_path, 'r') as f:
            reader = csv.reader(f)
            header = next(reader)

            lines_by_smiles = {}
            indices_by_smiles = {}
            for i, line in enumerate(reader):
                smiles = line[0]
                lines_by_smiles[smiles] = line
                indices_by_smiles[smiles] = i

        all_split_indices = []
        for dataset, name in [(train_data, 'train'), (val_data, 'val'),
                              (test_data, 'test')]:
            with open(os.path.join(args.save_dir, name + '_smiles.csv'),
                      'w') as f:
                writer = csv.writer(f)
                writer.writerow(['smiles'])
                for smiles in dataset.smiles():
                    writer.writerow([smiles])
            with open(os.path.join(args.save_dir, name + '_full.csv'),
                      'w') as f:
                writer = csv.writer(f)
                writer.writerow(header)
                for smiles in dataset.smiles():
                    writer.writerow(lines_by_smiles[smiles])
            split_indices = []
            for smiles in dataset.smiles():
                split_indices.append(indices_by_smiles[smiles])
            split_indices = sorted(split_indices)
            all_split_indices.append(split_indices)
        with open(os.path.join(args.save_dir, 'split_indices.pckl'),
                  'wb') as f:
            pickle.dump(all_split_indices, f)

    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

    args.train_data_size = len(train_data)

    debug(
        f'Total size = {len(data):,} | '
        f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}'
    )

    # Initialize scaler and scale training targets by subtracting the mean and dividing by the standard deviation (regression only)
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        train_smiles, train_targets = train_data.smiles(), train_data.targets()
        scaler = StandardScaler().fit(train_targets)
        scaled_targets = scaler.transform(train_targets).tolist()
        train_data.set_targets(scaled_targets)
    else:
        scaler = None

    # Get loss and metric functions
    loss_func = get_loss_func(args)
    metric_func = get_metric_func(metric=args.metric)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    if args.dataset_type == 'multiclass':
        sum_test_preds = np.zeros(
            (len(test_smiles), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        makedirs(save_dir)
        try:
            writer = SummaryWriter(log_dir=save_dir)
        except:
            writer = SummaryWriter(logdir=save_dir)
        # Load/build model
        if args.checkpoint_paths is not None:
            debug(
                f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}'
            )
            model = load_checkpoint(args.checkpoint_paths[model_idx],
                                    current_args=args,
                                    logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = build_model(args)

        debug(model)
        debug(f'Number of parameters = {param_count(model):,}')
        if args.cuda:
            debug('Moving model to cuda')
            model = model.cuda()

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler,
                        features_scaler, args)

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in range(args.epochs):
            debug(f'Epoch {epoch}')

            n_iter = train(model=model,
                           data=train_data,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer)
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(model=model,
                                  data=val_data,
                                  num_tasks=args.num_tasks,
                                  metric_func=metric_func,
                                  batch_size=args.batch_size,
                                  dataset_type=args.dataset_type,
                                  scaler=scaler,
                                  logger=logger)

            # Average validation score
            avg_val_score = np.nanmean(val_scores)
            debug(f'Validation {args.metric} = {avg_val_score:.6f}')
            writer.add_scalar(f'validation_{args.metric}', avg_val_score,
                              n_iter)

            if args.show_individual_scores:
                # Individual validation scores
                for task_name, val_score in zip(args.task_names, val_scores):
                    debug(
                        f'Validation {task_name} {args.metric} = {val_score:.6f}'
                    )
                    writer.add_scalar(f'validation_{task_name}_{args.metric}',
                                      val_score, n_iter)

            # Save model checkpoint if improved validation score
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score:
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(os.path.join(save_dir, 'model.pt'), model,
                                scaler, features_scaler, args)

        # Evaluate on test set using model with best validation score
        info(
            f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}'
        )
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'),
                                cuda=args.cuda,
                                logger=logger)

        test_preds = predict(model=model,
                             data=test_data,
                             batch_size=args.batch_size,
                             scaler=scaler)
        test_scores = evaluate_predictions(preds=test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metric_func=metric_func,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

        if len(test_preds) != 0:
            sum_test_preds += np.array(test_preds)

        # Average test score
        avg_test_score = np.nanmean(test_scores)
        info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')
        writer.add_scalar(f'test_{args.metric}', avg_test_score, 0)

        if args.show_individual_scores:
            # Individual test scores
            for task_name, test_score in zip(args.task_names, test_scores):
                info(
                    f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}'
                )
                writer.add_scalar(f'test_{task_name}_{args.metric}',
                                  test_score, n_iter)

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(preds=avg_test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metric_func=metric_func,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

    # Average ensemble score
    avg_ensemble_test_score = np.nanmean(ensemble_scores)
    info(f'Ensemble test {args.metric} = {avg_ensemble_test_score:.6f}')
    writer.add_scalar(f'ensemble_test_{args.metric}', avg_ensemble_test_score,
                      0)

    # Individual ensemble scores
    if args.show_individual_scores:
        for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
            info(
                f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}'
            )

    return ensemble_scores
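
The ensemble evaluation above simply accumulates each model's test predictions in sum_test_preds and divides by the ensemble size. A tiny, self-contained illustration of that averaging step with random stand-in predictions (3 models, 4 molecules, 2 tasks):

import numpy as np

model_preds = [np.random.rand(4, 2) for _ in range(3)]  # one array of test predictions per model
sum_test_preds = np.zeros((4, 2))
for preds in model_preds:
    sum_test_preds += preds
avg_test_preds = (sum_test_preds / len(model_preds)).tolist()  # what evaluate_predictions receives
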
Example #4
def predict(model: nn.Module,
            data: MoleculeDataset,
            args: Namespace,
            scaler: StandardScaler = None,
            bert_save_memory: bool = False,
            logger: logging.Logger = None) -> List[List[float]]:
    """
    Makes predictions on a dataset using a single model.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param args: Arguments.
    :param scaler: A StandardScaler object fit on the training targets.
    :param bert_save_memory: Store unused predictions as None to avoid unnecessary memory use.
    :param logger: Logger.
    :return: A list of lists of predictions. The outer list is examples
    while the inner list is tasks.
    """
    model.eval()

    preds = []
    if args.dataset_type == 'bert_pretraining':
        features_preds = []

    if args.maml:
        num_iters, iter_step = data.num_tasks() * args.maml_batches_per_epoch, 1
        full_targets = []
    else:
        num_iters, iter_step = len(data), args.batch_size
    
    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph, args=(batch_queue, data, args, num_iters, iter_step, exit_queue, True))
        batch_process.start()
        currently_loaded_batches = []

    for i in trange(0, num_iters, iter_step):
        if args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(args, seed=0)
            mol_batch = task_test_data
            smiles_batch, features_batch, targets_batch = task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
            targets = torch.Tensor(targets_batch).unsqueeze(1)
            if args.cuda:
                targets = targets.cuda()
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop(0)
            else:
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch = mol_batch.smiles(), mol_batch.features()

        # Run model
        if args.dataset_type == 'bert_pretraining':
            batch = mol2graph(smiles_batch, args)
            batch.bert_mask(mol_batch.mask())
        else:
            batch = smiles_batch
        
        if args.maml:  # TODO refactor with train loop
            model.zero_grad()
            intermediate_preds = model(batch, features_batch)
            loss = get_loss_func(args)(intermediate_preds, targets)
            loss = loss.sum() / len(batch)
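            # MAML inner loop: one gradient step on the sampled task's support set
            # (task_train_data) produces adapted parameters theta_prime, which are
            # then evaluated on the task's query set (task_test_data) below.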
            grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
            theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
            theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
            for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
                theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch, targets_batch = task_test_data.smiles(), task_test_data.features(), task_test_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            with torch.no_grad():
                batch_preds = model_prime(smiles_batch, features_batch)
            full_targets.extend([[t] for t in targets_batch])
        else:
            with torch.no_grad():
                if args.parallel_featurization:
                    previous_graph_input_mode = model.encoder.graph_input
                    model.encoder.graph_input = True  # force model to accept already processed input
                    batch_preds = model(featurized_mol_batch, features_batch)
                    model.encoder.graph_input = previous_graph_input_mode
                else:
                    batch_preds = model(batch, features_batch)

                if args.dataset_type == 'bert_pretraining':
                    if batch_preds['features'] is not None:
                        features_preds.extend(batch_preds['features'].data.cpu().numpy())
                    batch_preds = batch_preds['vocab']
                
                if args.dataset_type == 'kernel':
                    batch_preds = batch_preds.view(int(batch_preds.size(0)/2), 2, batch_preds.size(1))
                    batch_preds = model.kernel_output_layer(batch_preds)

        batch_preds = batch_preds.data.cpu().numpy()

        if scaler is not None:
            batch_preds = scaler.inverse_transform(batch_preds)
        
        if args.dataset_type == 'regression_with_binning':
            batch_preds = batch_preds.reshape((batch_preds.shape[0], args.num_tasks, args.num_bins))
            indices = np.argmax(batch_preds, axis=2)
            preds.extend(indices.tolist())
        else:
            batch_preds = batch_preds.tolist()
            if args.dataset_type == 'bert_pretraining' and bert_save_memory:
                for atom_idx, mask_val in enumerate(mol_batch.mask()):
                    if mask_val != 0:
                        batch_preds[atom_idx] = None  # not going to predict, so save some memory when passing around
            preds.extend(batch_preds)
    
    if args.dataset_type == 'regression_with_binning':
        preds = args.bin_predictions[np.array(preds)].tolist()

    if args.dataset_type == 'bert_pretraining':
        preds = {
            'features': features_preds if len(features_preds) > 0 else None,
            'vocab': preds
        }

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    if args.maml:
        # return the task targets here to guarantee alignment;
        # there's probably no reasonable scenario where we'd use MAML directly to predict something that's actually unknown
        return preds, full_targets

    return preds
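
The scaler.inverse_transform call above maps scaled model outputs back to the original target units. The StandardScaler used in these examples exposes the same fit/transform/inverse_transform interface as scikit-learn's, so the round trip can be sketched with the scikit-learn class (an assumption of equivalence, for illustration only):

import numpy as np
from sklearn.preprocessing import StandardScaler

train_targets = np.array([[1.0], [2.0], [3.0], [4.0]])
scaler = StandardScaler().fit(train_targets)      # fit on training targets only
scaled = scaler.transform(train_targets)          # what the model is trained to predict
model_output = scaled                             # pretend the model predicted perfectly
preds = scaler.inverse_transform(model_output)    # back to the original target units
assert np.allclose(preds, train_targets)
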
Example #5
def run_training(args: TrainArgs,
                 data: MoleculeDataset,
                 logger: Logger = None) -> Dict[str, List[float]]:
    """
    Loads data, trains a Chemprop model, and returns test scores for the model checkpoint with the highest validation score.

    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                 loading data and training the Chemprop model.
    :param data: A :class:`~chemprop.data.MoleculeDataset` containing the data.
    :param logger: A logger to record output.
    :return: A dictionary mapping each metric in :code:`args.metrics` to a list of values for each task.

    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seed)

    # Split data
    debug(f'Splitting data with seed {args.seed}')
    if args.separate_test_path:
        test_data = get_data(
            path=args.separate_test_path,
            args=args,
            features_path=args.separate_test_features_path,
            atom_descriptors_path=args.separate_test_atom_descriptors_path,
            bond_features_path=args.separate_test_bond_features_path,
            phase_features_path=args.separate_test_phase_features_path,
            smiles_columns=args.smiles_columns,
            logger=logger)
    if args.separate_val_path:
        val_data = get_data(
            path=args.separate_val_path,
            args=args,
            features_path=args.separate_val_features_path,
            atom_descriptors_path=args.separate_val_atom_descriptors_path,
            bond_features_path=args.separate_val_bond_features_path,
            phase_features_path=args.separate_val_phase_features_path,
            smiles_columns=args.smiles_columns,
            logger=logger)

    if args.separate_val_path and args.separate_test_path:
        train_data = data
    elif args.separate_val_path:
        train_data, _, test_data = split_data(data=data,
                                              split_type=args.split_type,
                                              sizes=(0.8, 0.0, 0.2),
                                              seed=args.seed,
                                              num_folds=args.num_folds,
                                              args=args,
                                              logger=logger)
    elif args.separate_test_path:
        train_data, val_data, _ = split_data(data=data,
                                             split_type=args.split_type,
                                             sizes=(0.8, 0.2, 0.0),
                                             seed=args.seed,
                                             num_folds=args.num_folds,
                                             args=args,
                                             logger=logger)
    else:
        train_data, val_data, test_data = split_data(
            data=data,
            split_type=args.split_type,
            sizes=args.split_sizes,
            seed=args.seed,
            num_folds=args.num_folds,
            args=args,
            logger=logger)

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(
                f'{args.task_names[i]} '
                f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}'
            )

    if args.save_smiles_splits:
        save_smiles_splits(
            data_path=args.data_path,
            save_dir=args.save_dir,
            task_names=args.task_names,
            features_path=args.features_path,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            smiles_columns=args.smiles_columns,
            logger=logger,
        )

    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

    if args.atom_descriptor_scaling and args.atom_descriptors is not None:
        atom_descriptor_scaler = train_data.normalize_features(
            replace_nan_token=0, scale_atom_descriptors=True)
        val_data.normalize_features(atom_descriptor_scaler,
                                    scale_atom_descriptors=True)
        test_data.normalize_features(atom_descriptor_scaler,
                                     scale_atom_descriptors=True)
    else:
        atom_descriptor_scaler = None

    if args.bond_feature_scaling and args.bond_features_size > 0:
        bond_feature_scaler = train_data.normalize_features(
            replace_nan_token=0, scale_bond_features=True)
        val_data.normalize_features(bond_feature_scaler,
                                    scale_bond_features=True)
        test_data.normalize_features(bond_feature_scaler,
                                     scale_bond_features=True)
    else:
        bond_feature_scaler = None

    args.train_data_size = len(train_data)

    debug(
        f'Total size = {len(data):,} | '
        f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}'
    )

    # Initialize scaler and scale training targets by subtracting the mean and dividing by the standard deviation (regression only)
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        scaler = train_data.normalize_targets()
    elif args.dataset_type == 'spectra':
        debug(
            'Normalizing spectra and excluding spectra regions based on phase')
        args.spectra_phase_mask = load_phase_mask(args.spectra_phase_mask_path)
        for dataset in [train_data, test_data, val_data]:
            data_targets = normalize_spectra(
                spectra=dataset.targets(),
                phase_features=dataset.phase_features(),
                phase_mask=args.spectra_phase_mask,
                excluded_sub_value=None,
                threshold=args.spectra_target_floor,
            )
            dataset.set_targets(data_targets)
        scaler = None
    else:
        scaler = None

    # Get loss function
    loss_func = get_loss_func(args)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    if args.dataset_type == 'multiclass':
        sum_test_preds = np.zeros(
            (len(test_smiles), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Automatically determine whether to cache
    if len(data) <= args.cache_cutoff:
        set_cache_graph(True)
        num_workers = 0
    else:
        set_cache_graph(False)
        num_workers = args.num_workers

    # Create data loaders
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           class_balance=args.class_balance,
                                           shuffle=True,
                                           seed=args.seed)
    val_data_loader = MoleculeDataLoader(dataset=val_data,
                                         batch_size=args.batch_size,
                                         num_workers=num_workers)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers)

    if args.class_balance:
        debug(
            f'With class_balance, effective train size = {train_data_loader.iter_size:,}'
        )

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        makedirs(save_dir)
        try:
            writer = SummaryWriter(log_dir=save_dir)
        except:
            writer = SummaryWriter(logdir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(
                f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}'
            )
            model = load_checkpoint(args.checkpoint_paths[model_idx],
                                    logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = MoleculeModel(args)

        # Optionally, overwrite weights:
        if args.checkpoint_frzn is not None:
            debug(
                f'Loading and freezing parameters from {args.checkpoint_frzn}.'
            )
            model = load_frzn_model(model=model,
                                    path=args.checkpoint_frzn,
                                    current_args=args,
                                    logger=logger)

        debug(model)

        if args.checkpoint_frzn is not None:
            debug(f'Number of unfrozen parameters = {param_count(model):,}')
            debug(f'Total number of parameters = {param_count_all(model):,}')
        else:
            debug(f'Number of parameters = {param_count_all(model):,}')

        if args.cuda:
            debug('Moving model to cuda')
        model = model.to(args.device)

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME), model, scaler,
                        features_scaler, atom_descriptor_scaler,
                        bond_feature_scaler, args)

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug(f'Epoch {epoch}')
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer)
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(model=model,
                                  data_loader=val_data_loader,
                                  num_tasks=args.num_tasks,
                                  metrics=args.metrics,
                                  dataset_type=args.dataset_type,
                                  scaler=scaler,
                                  logger=logger)

            for metric, scores in val_scores.items():
                # Average validation score
                avg_val_score = np.nanmean(scores)
                debug(f'Validation {metric} = {avg_val_score:.6f}')
                writer.add_scalar(f'validation_{metric}', avg_val_score,
                                  n_iter)

                if args.show_individual_scores:
                    # Individual validation scores
                    for task_name, val_score in zip(args.task_names, scores):
                        debug(
                            f'Validation {task_name} {metric} = {val_score:.6f}'
                        )
                        writer.add_scalar(f'validation_{task_name}_{metric}',
                                          val_score, n_iter)

            # Save model checkpoint if improved validation score
            avg_val_score = np.nanmean(val_scores[args.metric])
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score:
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(os.path.join(save_dir,
                                             MODEL_FILE_NAME), model, scaler,
                                features_scaler, atom_descriptor_scaler,
                                bond_feature_scaler, args)

        # Evaluate on test set using model with best validation score
        info(
            f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}'
        )
        model = load_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME),
                                device=args.device,
                                logger=logger)

        test_preds = predict(model=model,
                             data_loader=test_data_loader,
                             scaler=scaler)
        test_scores = evaluate_predictions(preds=test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metrics=args.metrics,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

        if len(test_preds) != 0:
            sum_test_preds += np.array(test_preds)

        # Average test score
        for metric, scores in test_scores.items():
            avg_test_score = np.nanmean(scores)
            info(f'Model {model_idx} test {metric} = {avg_test_score:.6f}')
            writer.add_scalar(f'test_{metric}', avg_test_score, 0)

            if args.show_individual_scores and args.dataset_type != 'spectra':
                # Individual test scores
                for task_name, test_score in zip(args.task_names, scores):
                    info(
                        f'Model {model_idx} test {task_name} {metric} = {test_score:.6f}'
                    )
                    writer.add_scalar(f'test_{task_name}_{metric}', test_score,
                                      n_iter)
        writer.close()

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(preds=avg_test_preds,
                                           targets=test_targets,
                                           num_tasks=args.num_tasks,
                                           metrics=args.metrics,
                                           dataset_type=args.dataset_type,
                                           logger=logger)

    for metric, scores in ensemble_scores.items():
        # Average ensemble score
        avg_ensemble_test_score = np.nanmean(scores)
        info(f'Ensemble test {metric} = {avg_ensemble_test_score:.6f}')

        # Individual ensemble scores
        if args.show_individual_scores:
            for task_name, ensemble_score in zip(args.task_names, scores):
                info(
                    f'Ensemble test {task_name} {metric} = {ensemble_score:.6f}'
                )

    # Save scores
    with open(os.path.join(args.save_dir, 'test_scores.json'), 'w') as f:
        json.dump(ensemble_scores, f, indent=4, sort_keys=True)

    # Optionally save test preds
    if args.save_preds:
        test_preds_dataframe = pd.DataFrame(
            data={'smiles': test_data.smiles()})

        for i, task_name in enumerate(args.task_names):
            test_preds_dataframe[task_name] = [
                pred[i] for pred in avg_test_preds
            ]

        test_preds_dataframe.to_csv(os.path.join(args.save_dir,
                                                 'test_preds.csv'),
                                    index=False)

    return ensemble_scores
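
After training, this example leaves two artifacts in args.save_dir: test_scores.json (a metric-to-per-task-scores mapping) and, when args.save_preds is set, test_preds.csv (one row per test SMILES with a column per task). A short sketch of reading them back, assuming a hypothetical save_dir path:

import json
import os

import pandas as pd

save_dir = 'path/to/save_dir'  # hypothetical; corresponds to args.save_dir above
with open(os.path.join(save_dir, 'test_scores.json')) as f:
    test_scores = json.load(f)  # e.g. {'rmse': [score_task_1, score_task_2, ...]}
test_preds = pd.read_csv(os.path.join(save_dir, 'test_preds.csv'))  # columns: smiles + one per task
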
Example #6
def main(args: Namespace):
    print('Loading data')
    dataset = get_data(args.data_path)
    num_mols, num_tasks = len(dataset), dataset.num_tasks()
    args.num_mols, args.num_tasks = num_mols, num_tasks
    data = convert_moleculedataset_to_moleculefactordataset(dataset)

    print(f'Number of molecules = {num_mols:,}')
    print(f'Number of tasks = {num_tasks:,}')
    print(f'Number of real tasks = {args.num_real_tasks:,}')
    print(f'Number of known molecule-task pairs = {len(data):,}')

    print('Splitting data')
    train_data, val_data, test_data = split_data(data, args)

    if args.dataset_type == 'regression':
        print('Scaling data')
        targets = train_data.targets()
        scaler = StandardScaler().fit(targets)
        scaled_targets = scaler.transform(targets).tolist()
        train_data.set_targets(scaled_targets)
    else:
        scaler = None

    print(
        f'Total size = {len(data):,} | '
        f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}'
    )

    if args.checkpoint_path is not None:
        print('Loading saved model')
        model, loaded_args = load(args.checkpoint_path)
        assert args.num_mols == loaded_args.num_mols and args.num_tasks == loaded_args.num_tasks
        args.embedding_size, args.hidden_size, args.dropout, args.activation, args.dataset_type = \
            loaded_args.embedding_size, loaded_args.hidden_size, loaded_args.dropout, loaded_args.activation, loaded_args.dataset_type
    else:
        print('Building model')
        model = MatrixFactorizer(
            args,
            num_mols=num_mols,
            num_tasks=num_tasks,
            embedding_size=args.embedding_size,
            hidden_size=args.hidden_size,
            dropout=args.dropout,
            activation=args.activation,
            classification=(args.dataset_type == 'classification'))

    print(model)
    print(f'Number of parameters = {param_count(model):,}')

    loss_func = get_loss_func(args)
    metric_func = get_metric_func(metric=args.metric)
    optimizer = Adam(model.parameters(), lr=args.lr)

    if args.cuda:
        print('Moving model to cuda')
        model = model.cuda()

    print('Training')
    for epoch in trange(args.epochs):
        print(f'Epoch {epoch}')
        train(model=model,
              data=train_data,
              loss_func=loss_func,
              optimizer=optimizer,
              batch_size=args.batch_size,
              random_mol_embeddings=args.random_mol_embeddings)
        val_scores = evaluate(model=model,
                              data=val_data,
                              num_tasks=args.num_real_tasks,
                              metric_func=metric_func,
                              batch_size=args.batch_size,
                              scaler=scaler,
                              random_mol_embeddings=args.random_mol_embeddings)
        print(val_scores)
        avg_val_score = np.mean(val_scores)
        print(f'Validation {args.metric} = {avg_val_score:.6f}')

    test_scores = evaluate(model=model,
                           data=test_data,
                           num_tasks=args.num_real_tasks,
                           metric_func=metric_func,
                           batch_size=args.batch_size,
                           scaler=scaler,
                           random_mol_embeddings=args.random_mol_embeddings)
    print(test_scores)
    avg_test_score = np.mean(test_scores)
    print(f'Test {args.metric} = {avg_test_score:.6f}')

    if args.save_path is not None:
        print('Saving model')
        save(model, args, args.save_path)

    if args.filled_matrix_path is not None:
        print('Filling matrix of data')
        fill_matrix(model, args, data)
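
For context, the molecule-by-task matrix-factorization idea behind a model like MatrixFactorizer can be reduced to a dot product of learned molecule and task embeddings. The toy module below only illustrates that idea under that assumption; the repository's model also has hidden layers, dropout, and an optional classification head, per its constructor arguments:

import torch
import torch.nn as nn

class TinyMatrixFactorizer(nn.Module):
    """Predict a molecule-task value as a dot product of learned embeddings."""

    def __init__(self, num_mols: int, num_tasks: int, embedding_size: int):
        super().__init__()
        self.mol_emb = nn.Embedding(num_mols, embedding_size)
        self.task_emb = nn.Embedding(num_tasks, embedding_size)

    def forward(self, mol_idx: torch.Tensor, task_idx: torch.Tensor) -> torch.Tensor:
        return (self.mol_emb(mol_idx) * self.task_emb(task_idx)).sum(dim=-1)

model = TinyMatrixFactorizer(num_mols=10, num_tasks=3, embedding_size=8)
preds = model(torch.tensor([0, 1]), torch.tensor([2, 0]))  # scores for two (molecule, task) pairs
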
Example #7
def run_training(args: Namespace, logger: Logger = None):
    """
    Trains a model and returns test scores on the model checkpoint with the highest validation score.
    :param args: Arguments.
    :param logger: Logger.
    :return: Optimal average test score (for use in hyperparameter optimization via Hyperopt)
    """

    # Set up logger
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    debug(pformat(vars(args)))

    # Load metadata
    metadata = json.load(open(args.data_path, 'r'))

    # Train/val/test split
    if args.k_fold_split:
        data_splits = []
        kf = KFold(n_splits=args.num_folds, shuffle=True, random_state=args.seed)
        for train_index, test_index in kf.split(metadata):
            splits = [train_index, test_index]
            data_splits.append(splits)
        data_splits = data_splits[args.fold_index]

        if args.use_inner_test:
            train_indices, remaining_indices = train_test_split(data_splits[0], test_size=args.val_test_size,
                                                                random_state=args.seed)
            validation_indices, test_indices = train_test_split(remaining_indices, test_size=0.5,
                                                                random_state=args.seed)

        else:
            train_indices = data_splits[0]
            validation_indices, test_indices = train_test_split(data_splits[1], test_size=0.5, random_state=args.seed)

        train_metadata = list(np.asarray(metadata)[list(train_indices)])
        validation_metadata = list(np.asarray(metadata)[list(validation_indices)])
        test_metadata = list(np.asarray(metadata)[list(test_indices)])

    else:
        train_metadata, remaining_metadata = train_test_split(metadata, test_size=args.val_test_size,
                                                              random_state=args.seed)
        validation_metadata, test_metadata = train_test_split(remaining_metadata, test_size=0.5, random_state=args.seed)

    # Load datasets
    debug('Loading data')
    transform = Compose([Augmentation(args.augmentation_length), NNGraph(args.num_neighbors), Distance(False)])
    train_data = GlassDataset(train_metadata, transform=transform)
    val_data = GlassDataset(validation_metadata, transform=transform)
    test_data = GlassDataset(test_metadata, transform=transform)
    args.atom_fdim = 3
    args.bond_fdim = args.atom_fdim + 1

    # Dataset lengths
    train_data_length, val_data_length, test_data_length = len(train_data), len(val_data), len(test_data)
    debug('train size = {:,} | val size = {:,} | test size = {:,}'.format(
        train_data_length,
        val_data_length,
        test_data_length)
    )

    # Convert to iterators
    train_data = DataLoader(train_data, args.batch_size)
    val_data = DataLoader(val_data, args.batch_size)
    test_data = DataLoader(test_data, args.batch_size)

    # Get loss and metric functions
    loss_func = get_loss_func(args)
    metric_func = get_metric_func(args.metric)

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, 'model_{}'.format(model_idx))
        os.makedirs(save_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug('Loading model {} from {}'.format(model_idx, args.checkpoint_paths[model_idx]))
            model = load_checkpoint(args.checkpoint_paths[model_idx], args.save_dir, attention_viz=args.attention_viz)
        else:
            debug('Building model {}'.format(model_idx))
            model = build_model(args)

        debug(model)
        debug('Number of parameters = {:,}'.format(param_count(model)))

        if args.cuda:
            debug('Moving model to cuda')
            model = model.cuda()

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(model, args, os.path.join(save_dir, 'model.pt'))

        # Optimizer and learning rate scheduler
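        # NoamLR below: linear warmup of the learning rate from init_lr to max_lr over warmup_epochs, then exponential decay toward final_lr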
        optimizer = Adam(model.parameters(), lr=args.init_lr[model_idx], weight_decay=args.weight_decay[model_idx])

        scheduler = NoamLR(
            optimizer,
            warmup_epochs=args.warmup_epochs,
            total_epochs=[args.epochs],
            steps_per_epoch=train_data_length // args.batch_size,
            init_lr=args.init_lr,
            max_lr=args.max_lr,
            final_lr=args.final_lr
        )

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug('Epoch {}'.format(epoch))

            n_iter = train(
                model=model,
                data=train_data,
                loss_func=loss_func,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                logger=logger,
                writer=writer
            )

            val_scores = []
            for val_runs in range(args.num_val_runs):

                val_batch_scores = evaluate(
                    model=model,
                    data=val_data,
                    metric_func=metric_func,
                    args=args,
                )

                val_scores.append(np.mean(val_batch_scores))

            # Average validation score
            avg_val_score = np.mean(val_scores)
            debug('Validation {} = {:.3f}'.format(args.metric, avg_val_score))
            writer.add_scalar('validation_{}'.format(args.metric), avg_val_score, n_iter)

            # Save model checkpoint if improved validation score
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score:
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(model, args, os.path.join(save_dir, 'model.pt'))

        # Evaluate on test set using model with best validation score
        info('Model {} best validation {} = {:.3f} on epoch {}'.format(model_idx, args.metric, best_score, best_epoch))
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'), args.save_dir, cuda=args.cuda,
                                attention_viz=args.attention_viz)

        test_scores = []
        for test_runs in range(args.num_test_runs):

            test_batch_scores = evaluate(
                model=model,
                data=test_data,
                metric_func=metric_func,
                args=args
            )

            test_scores.append(np.mean(test_batch_scores))

        # Get accuracy (assuming args.metric is set to AUC)
        metric_func_accuracy = get_metric_func('accuracy')
        test_scores_accuracy = []
        for test_runs in range(args.num_test_runs):

            test_batch_scores = evaluate(
                model=model,
                data=test_data,
                metric_func=metric_func_accuracy,
                args=args
            )

            test_scores_accuracy.append(np.mean(test_batch_scores))

        # Average test score
        avg_test_score = np.mean(test_scores)
        avg_test_accuracy = np.mean(test_scores_accuracy)
        info('Model {} test {} = {:.3f}, test {} = {:.3f}'.format(model_idx, args.metric,
                                                                  avg_test_score, 'accuracy', avg_test_accuracy))
        writer.add_scalar('test_{}'.format(args.metric), avg_test_score, n_iter)

        return avg_test_score, avg_test_accuracy  # Returns from inside the ensemble loop, for hyperparameter optimization or cross-validation use
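
The checkpoint condition used above, args.minimize_score and avg_val_score < best_score or not args.minimize_score and avg_val_score > best_score, relies on 'and' binding more tightly than 'or'. A minimal, self-contained sketch of the same best-score bookkeeping, with hypothetical validation scores, is:

# Illustrative sketch only: the improvement test from the snippet above, made explicit.
def is_improvement(score: float, best: float, minimize: bool) -> bool:
    # Parenthesized form of: minimize and score < best or not minimize and score > best
    return (minimize and score < best) or (not minimize and score > best)

best_score, best_epoch = float('inf'), 0  # minimizing a metric such as RMSE
for epoch, avg_val_score in enumerate([0.92, 0.85, 0.88]):  # hypothetical scores
    if is_improvement(avg_val_score, best_score, minimize=True):
        best_score, best_epoch = avg_val_score, epoch  # a checkpoint would be saved here
print(best_score, best_epoch)  # -> 0.85 1
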
Ejemplo n.º 8
0
def run_training(args: Namespace, logger: Logger = None) -> List[float]:
    """
    Trains a model and returns test scores on the model checkpoint with the highest validation score.

    :param args: Arguments.
    :param logger: Logger.
    :return: A list of ensemble scores for each task.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Set GPU
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    # Print args
    debug(pformat(vars(args)))

    # Get data
    debug('Loading data')
    args.task_names = get_task_names(args.data_path)
    desired_labels = get_desired_labels(args, args.task_names)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()
    args.real_num_tasks = args.num_tasks - args.features_size if args.predict_features else args.num_tasks
    debug(f'Number of tasks = {args.num_tasks}')

    if args.dataset_type == 'bert_pretraining':
        data.bert_init(args, logger)

    # Split data
    if args.dataset_type == 'regression_with_binning':  # Note: for now, binning based on whole dataset, not just training set
        data, bin_predictions, regression_data = data
        args.bin_predictions = bin_predictions
        debug(f'Splitting data with seed {args.seed}')
        train_data, _, _ = split_data(data=data,
                                      split_type=args.split_type,
                                      sizes=args.split_sizes,
                                      seed=args.seed,
                                      args=args,
                                      logger=logger)
        _, val_data, test_data = split_data(regression_data,
                                            split_type=args.split_type,
                                            sizes=args.split_sizes,
                                            seed=args.seed,
                                            args=args,
                                            logger=logger)
    else:
        debug(f'Splitting data with seed {args.seed}')
        if args.separate_test_set:
            test_data = get_data(path=args.separate_test_set,
                                 args=args,
                                 features_path=args.separate_test_set_features,
                                 logger=logger)
            if args.separate_val_set:
                val_data = get_data(
                    path=args.separate_val_set,
                    args=args,
                    features_path=args.separate_val_set_features,
                    logger=logger)
                train_data = data  # nothing to split; we already got our test and val sets
            else:
                train_data, val_data, _ = split_data(
                    data=data,
                    split_type=args.split_type,
                    sizes=(0.8, 0.2, 0.0),
                    seed=args.seed,
                    args=args,
                    logger=logger)
        else:
            train_data, val_data, test_data = split_data(
                data=data,
                split_type=args.split_type,
                sizes=args.split_sizes,
                seed=args.seed,
                args=args,
                logger=logger)

    # Optionally replace test data with train or val data
    if args.test_split == 'train':
        test_data = train_data
    elif args.test_split == 'val':
        test_data = val_data

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(
                f'{args.task_names[i]} '
                f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}'
            )

        if args.class_balance:
            train_class_sizes = get_class_sizes(train_data)
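            # Weight each class inversely to its expected count per training batch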
            class_batch_counts = torch.Tensor(train_class_sizes) * args.batch_size
            args.class_weights = 1 / class_batch_counts

    if args.save_smiles_splits:
        with open(args.data_path, 'r') as f:
            reader = csv.reader(f)
            header = next(reader)

            lines_by_smiles = {}
            indices_by_smiles = {}
            for i, line in enumerate(reader):
                smiles = line[0]
                lines_by_smiles[smiles] = line
                indices_by_smiles[smiles] = i

        all_split_indices = []
        for dataset, name in [(train_data, 'train'), (val_data, 'val'),
                              (test_data, 'test')]:
            with open(os.path.join(args.save_dir, name + '_smiles.csv'),
                      'w') as f:
                writer = csv.writer(f)
                writer.writerow(['smiles'])
                for smiles in dataset.smiles():
                    writer.writerow([smiles])
            with open(os.path.join(args.save_dir, name + '_full.csv'),
                      'w') as f:
                writer = csv.writer(f)
                writer.writerow(header)
                for smiles in dataset.smiles():
                    writer.writerow(lines_by_smiles[smiles])
            split_indices = []
            for smiles in dataset.smiles():
                split_indices.append(indices_by_smiles[smiles])
            split_indices = sorted(split_indices)
            all_split_indices.append(split_indices)
        with open(os.path.join(args.save_dir, 'split_indices.pckl'),
                  'wb') as f:
            pickle.dump(all_split_indices, f)
        # Short circuit out when just generating splits
        return [1 for _ in range(args.num_tasks)]

    if args.features_scaling:
        features_scaler = train_data.normalize_features(
            replace_nan_token=None if args.predict_features else 0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

    args.train_data_size = (len(train_data)
                            if args.prespecified_chunk_dir is None
                            else args.prespecified_chunks_max_examples_per_epoch)

    if args.adversarial or args.moe:
        val_smiles, test_smiles = val_data.smiles(), test_data.smiles()

    debug(
        f'Total size = {len(data):,} | '
        f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}'
    )

    # Optionally truncate outlier values
    if args.truncate_outliers:
        debug('Truncating outliers in train set')
        train_data = truncate_outliers(train_data)

    # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only)
    if args.dataset_type == 'regression' and args.target_scaling:
        debug('Fitting scaler')
        train_smiles, train_targets = train_data.smiles(), train_data.targets()
        scaler = StandardScaler().fit(train_targets)
        scaled_targets = scaler.transform(train_targets).tolist()
        train_data.set_targets(scaled_targets)
    else:
        scaler = None

    if args.moe:
        train_data = cluster_split(train_data,
                                   args.num_sources,
                                   args.cluster_max_ratio,
                                   seed=args.cluster_split_seed,
                                   logger=logger)

    # Chunk training data if too large to load in memory all at once
    if args.num_chunks > 1:
        os.makedirs(args.chunk_temp_dir, exist_ok=True)
        train_paths = []
        if args.moe:
            chunked_sources = [td.chunk(args.num_chunks) for td in train_data]
            chunks = []
            for i in range(args.num_chunks):
                chunks.append([source[i] for source in chunked_sources])
        else:
            chunks = train_data.chunk(args.num_chunks)
        for i in range(args.num_chunks):
            chunk_path = os.path.join(args.chunk_temp_dir, str(i) + '.txt')
            memo_path = os.path.join(args.chunk_temp_dir,
                                     'memo' + str(i) + '.txt')
            with open(chunk_path, 'wb') as f:
                pickle.dump(chunks[i], f)
            train_paths.append((chunk_path, memo_path))
        train_data = train_paths

    # Get loss and metric functions
    loss_func = get_loss_func(args)
    metric_func = get_metric_func(metric=args.metric, args=args)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    if args.maml:  # TODO refactor
        test_targets = []
        for task_idx in range(len(data.data[0].targets)):
            _, task_test_data, _ = test_data.sample_maml_task(args, seed=0)
            test_targets += task_test_data.targets()

    if args.dataset_type == 'bert_pretraining':
        sum_test_preds = {
            'features':
            np.zeros((len(test_smiles), args.features_size))
            if args.features_size is not None else None,
            'vocab':
            np.zeros((len(test_targets['vocab']), args.vocab.output_size))
        }
    elif args.dataset_type == 'kernel':
        sum_test_preds = np.zeros((len(test_targets), args.num_tasks))
    else:
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    if args.maml:
        sum_test_preds = None  # annoying to determine exact size; will initialize later

    if args.dataset_type == 'bert_pretraining':
        # Only predict targets that are masked out
        test_targets['vocab'] = [
            target if mask == 0 else None
            for target, mask in zip(test_targets['vocab'], test_data.mask())
        ]

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        os.makedirs(save_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(
                f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}'
            )
            model = load_checkpoint(args.checkpoint_paths[model_idx],
                                    current_args=args,
                                    logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = build_model(args)

        debug(model)
        debug(f'Number of parameters = {param_count(model):,}')
        if args.cuda:
            debug('Moving model to cuda')
            model = model.cuda()

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler,
                        features_scaler, args)

        if args.adjust_weight_decay:
            args.pnorm_target = compute_pnorm(model)

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug(f'Epoch {epoch}')

            if args.prespecified_chunk_dir is not None:
                # load some different random chunks each epoch
                train_data, val_data = load_prespecified_chunks(args, logger)
                debug('Loaded prespecified chunks for epoch')

            if args.dataset_type == 'unsupervised':  # won't work with moe
                full_data = MoleculeDataset(train_data.data + val_data.data)
                # Cluster with a new random init
                generate_unsupervised_cluster_labels(build_model(args), full_data, args)
                # Reset the ffn since we're changing targets -- we're just pretraining the encoder
                model.create_ffn(args)
                optimizer.param_groups.pop()  # remove ffn parameters
                optimizer.add_param_group({
                    'params': model.ffn.parameters(),
                    'lr': args.init_lr[1],
                    'weight_decay': args.weight_decay[1]
                })
                if args.cuda:
                    model.ffn.cuda()

            if args.gradual_unfreezing:
                if epoch % args.epochs_per_unfreeze == 0:
                    # Consider just stopping early once there is nothing left to unfreeze?
                    unfroze_layer = model.unfreeze_next()
                    if unfroze_layer:
                        debug('Unfroze last frozen layer')

            n_iter = train(model=model,
                           data=train_data,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=(args.num_chunks > 1),
                           val_smiles=val_smiles if args.adversarial else None,
                           test_smiles=test_smiles
                           if args.adversarial or args.moe else None)
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(model=model,
                                  data=val_data,
                                  metric_func=metric_func,
                                  args=args,
                                  scaler=scaler,
                                  logger=logger)

            if args.dataset_type == 'bert_pretraining':
                if val_scores['features'] is not None:
                    debug(
                        f'Validation features rmse = {val_scores["features"]:.6f}'
                    )
                    writer.add_scalar('validation_features_rmse',
                                      val_scores['features'], n_iter)
                val_scores = [val_scores['vocab']]

            # Average validation score
            avg_val_score = np.nanmean(val_scores)
            debug(f'Validation {args.metric} = {avg_val_score:.6f}')
            writer.add_scalar(f'validation_{args.metric}', avg_val_score,
                              n_iter)

            if args.show_individual_scores:
                # Individual validation scores
                for task_name, val_score in zip(args.task_names, val_scores):
                    if task_name in desired_labels:
                        debug(
                            f'Validation {task_name} {args.metric} = {val_score:.6f}'
                        )
                        writer.add_scalar(
                            f'validation_{task_name}_{args.metric}', val_score,
                            n_iter)

            # Save model checkpoint if improved validation score, or always save it if unsupervised
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score or \
                    args.dataset_type == 'unsupervised':
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(os.path.join(save_dir, 'model.pt'), model,
                                scaler, features_scaler, args)

        if args.dataset_type == 'unsupervised':
            return [0]  # rest of this is meaningless when unsupervised

        # Evaluate on test set using model with best validation score
        info(
            f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}'
        )
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'),
                                cuda=args.cuda,
                                logger=logger)

        if args.split_test_by_overlap_dataset is not None:
            overlap_data = get_data(path=args.split_test_by_overlap_dataset,
                                    logger=logger)
            overlap_smiles = set(overlap_data.smiles())
            test_data_intersect, test_data_nonintersect = [], []
            for d in test_data.data:
                if d.smiles in overlap_smiles:
                    test_data_intersect.append(d)
                else:
                    test_data_nonintersect.append(d)
            test_data_intersect, test_data_nonintersect = MoleculeDataset(
                test_data_intersect), MoleculeDataset(test_data_nonintersect)
            for name, td in [('Intersect', test_data_intersect),
                             ('Nonintersect', test_data_nonintersect)]:
                test_preds = predict(model=model,
                                     data=td,
                                     args=args,
                                     scaler=scaler,
                                     logger=logger)
                test_scores = evaluate_predictions(
                    preds=test_preds,
                    targets=td.targets(),
                    metric_func=metric_func,
                    dataset_type=args.dataset_type,
                    args=args,
                    logger=logger)
                avg_test_score = np.nanmean(test_scores)
                info(
                    f'Model {model_idx} test {args.metric} for {name} = {avg_test_score:.6f}'
                )

        if len(test_data) == 0:
            # Just get some garbage results without crashing; in this case we didn't care anyway
            test_preds, test_scores = sum_test_preds, [0] * len(args.task_names)
        else:
            test_preds = predict(model=model,
                                 data=test_data,
                                 args=args,
                                 scaler=scaler,
                                 logger=logger)
            test_scores = evaluate_predictions(preds=test_preds,
                                               targets=test_targets,
                                               metric_func=metric_func,
                                               dataset_type=args.dataset_type,
                                               args=args,
                                               logger=logger)

        if args.maml:
            if sum_test_preds is None:
                sum_test_preds = np.zeros(np.array(test_preds).shape)

        if args.dataset_type == 'bert_pretraining':
            if test_preds['features'] is not None:
                sum_test_preds['features'] += np.array(test_preds['features'])
            sum_test_preds['vocab'] += np.array(test_preds['vocab'])
        else:
            sum_test_preds += np.array(test_preds)

        if args.dataset_type == 'bert_pretraining':
            if test_preds['features'] is not None:
                debug(
                    f'Model {model_idx} test features rmse = {test_scores["features"]:.6f}'
                )
                writer.add_scalar('test_features_rmse',
                                  test_scores['features'], 0)
            test_scores = [test_scores['vocab']]

        # Average test score
        avg_test_score = np.nanmean(test_scores)
        info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')
        writer.add_scalar(f'test_{args.metric}', avg_test_score, 0)

        if args.show_individual_scores:
            # Individual test scores
            for task_name, test_score in zip(args.task_names, test_scores):
                if task_name in desired_labels:
                    info(
                        f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}'
                    )
                    writer.add_scalar(f'test_{task_name}_{args.metric}',
                                      test_score, n_iter)

    # Evaluate ensemble on test set
    if args.dataset_type == 'bert_pretraining':
        avg_test_preds = {
            'features':
            (sum_test_preds['features'] / args.ensemble_size).tolist()
            if sum_test_preds['features'] is not None else None,
            'vocab': (sum_test_preds['vocab'] / args.ensemble_size).tolist()
        }
    else:
        avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    if len(test_data) == 0:
        # Just return some garbage when we didn't want test data
        ensemble_scores = test_scores
    else:
        ensemble_scores = evaluate_predictions(preds=avg_test_preds,
                                               targets=test_targets,
                                               metric_func=metric_func,
                                               dataset_type=args.dataset_type,
                                               args=args,
                                               logger=logger)

    # Average ensemble score
    if args.dataset_type == 'bert_pretraining':
        if ensemble_scores['features'] is not None:
            info(
                f'Ensemble test features rmse = {ensemble_scores["features"]:.6f}'
            )
            writer.add_scalar('ensemble_test_features_rmse',
                              ensemble_scores['features'], 0)
        ensemble_scores = [ensemble_scores['vocab']]

    avg_ensemble_test_score = np.nanmean(ensemble_scores)
    info(f'Ensemble test {args.metric} = {avg_ensemble_test_score:.6f}')
    writer.add_scalar(f'ensemble_test_{args.metric}', avg_ensemble_test_score,
                      0)

    # Individual ensemble scores
    if args.show_individual_scores:
        for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
            info(
                f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}'
            )

    return ensemble_scores
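
Ejemplo n.º 8 accumulates each model's test predictions in sum_test_preds and divides by args.ensemble_size at the end. A small self-contained sketch of that ensemble-averaging pattern, using made-up prediction arrays, is:

import numpy as np

# Hypothetical setup: 3 ensemble members, each predicting 4 molecules x 2 tasks
ensemble_size, n_molecules, n_tasks = 3, 4, 2
model_preds = [np.random.rand(n_molecules, n_tasks) for _ in range(ensemble_size)]

sum_test_preds = np.zeros((n_molecules, n_tasks))
for test_preds in model_preds:
    sum_test_preds += np.array(test_preds)  # same accumulation as in the ensemble loop above

avg_test_preds = (sum_test_preds / ensemble_size).tolist()
assert np.allclose(avg_test_preds, np.mean(model_preds, axis=0))
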
Ejemplo n.º 9
0
def run_training(args: TrainArgs,
                 data: MoleculeDataset,
                 logger: Logger = None) -> Dict[str, List[float]]:
    """
    Loads data, trains a Chemprop model, and returns test scores for the model checkpoint with the highest validation score.

    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
                 loading data and training the Chemprop model.
    :param data: A :class:`~chemprop.data.MoleculeDataset` containing the data.
    :param logger: A logger to record output.
    :return: A dictionary mapping each metric in :code:`args.metrics` to a list of values for each task.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seed)

    # Split data
    debug(f"Splitting data with seed {args.seed}")
    # if args.separate_test_path:
    #    test_data = get_data(
    #        path=args.separate_test_path,
    #        args=args,
    #        features_path=args.separate_test_features_path,
    #        atom_descriptors_path=args.separate_test_atom_descriptors_path,
    #        bond_features_path=args.separate_test_bond_features_path,
    #        smiles_columns=args.smiles_columns,
    #        logger=logger,
    #    )
    # if args.separate_val_path:
    #    val_data = get_data(
    #        path=args.separate_val_path,
    #        args=args,
    #        features_path=args.separate_val_features_path,
    #        atom_descriptors_path=args.separate_val_atom_descriptors_path,
    #        bond_features_path=args.separate_val_bond_features_path,
    #        smiles_columns=args.smiles_columns,
    #        logger=logger,
    #    )

    # if args.separate_val_path and args.separate_test_path:
    #    train_data = data
    # elif args.separate_val_path:
    #    train_data, _, test_data = split_data(
    #        data=data,
    #        split_type=args.split_type,
    #        sizes=(0.8, 0.0, 0.2),
    #        seed=args.seed,
    #        num_folds=args.num_folds,
    #        args=args,
    #        logger=logger,
    #    )
    # elif args.separate_test_path:
    #    train_data, val_data, _ = split_data(
    #        data=data,
    #        split_type=args.split_type,
    #        sizes=(0.8, 0.2, 0.0),
    #        seed=args.seed,
    #        num_folds=args.num_folds,
    #        args=args,
    #        logger=logger,
    #    )
    # else:  # Default
    train_data, val_data, test_data = split_data(
        data=data,
        split_type=args.split_type,
        sizes=args.split_sizes,
        seed=args.seed,
        num_folds=args.num_folds,
        args=args,
        logger=logger,
    )

    if args.dataset_type == "classification":
        class_sizes = get_class_sizes(data)
        debug("Class sizes")
        for i, task_class_sizes in enumerate(class_sizes):
            debug(
                f"{args.task_names[i]} "
                f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}'
            )

    if args.save_smiles_splits:
        save_smiles_splits(
            data_path=args.data_path,
            save_dir=args.save_dir,
            task_names=args.task_names,
            features_path=args.features_path,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            smiles_columns=args.smiles_columns,
        )

    if args.features_scaling:
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None

    if args.atom_descriptor_scaling and args.atom_descriptors is not None:
        atom_descriptor_scaler = train_data.normalize_features(
            replace_nan_token=0, scale_atom_descriptors=True)
        val_data.normalize_features(atom_descriptor_scaler,
                                    scale_atom_descriptors=True)
        test_data.normalize_features(atom_descriptor_scaler,
                                     scale_atom_descriptors=True)
    else:
        atom_descriptor_scaler = None

    if args.bond_feature_scaling and args.bond_features_size > 0:
        bond_feature_scaler = train_data.normalize_features(
            replace_nan_token=0, scale_bond_features=True)
        val_data.normalize_features(bond_feature_scaler,
                                    scale_bond_features=True)
        test_data.normalize_features(bond_feature_scaler,
                                     scale_bond_features=True)
    else:
        bond_feature_scaler = None

    args.train_data_size = len(train_data)

    debug(
        f"Total size = {len(data):,} | "
        f"train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}"
    )

    # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only)
    if args.dataset_type == "regression":
        debug("Fitting scaler")
        scaler = train_data.normalize_targets()
    else:
        scaler = None

    # Get loss function
    loss_func = get_loss_func(args)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    if args.dataset_type == "multiclass":
        sum_test_preds = np.zeros(
            (len(test_smiles), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Automatically determine whether to cache
    if len(data) <= args.cache_cutoff:
        set_cache_graph(True)
        num_workers = 0
    else:
        set_cache_graph(False)
        num_workers = args.num_workers

    # Create data loaders
    train_data_loader = MoleculeDataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        num_workers=num_workers,
        class_balance=args.class_balance,
        shuffle=True,
        seed=args.seed,
    )
    val_data_loader = MoleculeDataLoader(dataset=val_data,
                                         batch_size=args.batch_size,
                                         num_workers=num_workers)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers)

    if args.class_balance:
        debug(
            f"With class_balance, effective train size = {train_data_loader.iter_size:,}"
        )

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f"model_{model_idx}")
        makedirs(save_dir)
        try:
            writer = SummaryWriter(log_dir=save_dir)
        except:
            writer = SummaryWriter(logdir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(
                f"Loading model {model_idx} from {args.checkpoint_paths[model_idx]}"
            )
            model = load_checkpoint(args.checkpoint_paths[model_idx],
                                    logger=logger)
        else:
            debug(f"Building model {model_idx}")
            model = MoleculeModel(args)

        debug(model)
        debug(f"Number of parameters = {param_count(model):,}")
        if args.cuda:
            debug("Moving model to cuda")
        model = model.to(args.device)

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(
            os.path.join(save_dir, MODEL_FILE_NAME),
            model,
            scaler,
            features_scaler,
            atom_descriptor_scaler,
            bond_feature_scaler,
            args,
        )

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float("inf") if args.minimize_score else -float("inf")
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug(f"Epoch {epoch}")

            n_iter = train(
                model=model,
                data_loader=train_data_loader,
                loss_func=loss_func,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                logger=logger,
                writer=writer,
            )
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(
                model=model,
                data_loader=val_data_loader,
                num_tasks=args.num_tasks,
                metrics=args.metrics,
                dataset_type=args.dataset_type,
                scaler=scaler,
                logger=logger,
            )

            for metric, scores in val_scores.items():
                # Average validation score
                avg_val_score = np.nanmean(scores)
                debug(f"Validation {metric} = {avg_val_score:.6f}")
                writer.add_scalar(f"validation_{metric}", avg_val_score,
                                  n_iter)

                if args.show_individual_scores:
                    # Individual validation scores
                    for task_name, val_score in zip(args.task_names, scores):
                        debug(
                            f"Validation {task_name} {metric} = {val_score:.6f}"
                        )
                        writer.add_scalar(f"validation_{task_name}_{metric}",
                                          val_score, n_iter)

            # Save model checkpoint if improved validation score
            avg_val_score = np.nanmean(val_scores[args.metric])
            if (args.minimize_score and avg_val_score < best_score
                    or not args.minimize_score and avg_val_score > best_score):
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(
                    os.path.join(save_dir, MODEL_FILE_NAME),
                    model,
                    scaler,
                    features_scaler,
                    atom_descriptor_scaler,
                    bond_feature_scaler,
                    args,
                )

        # Evaluate on test set using model with best validation score
        info(
            f"Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}"
        )
        model = load_checkpoint(os.path.join(save_dir, MODEL_FILE_NAME),
                                device=args.device,
                                logger=logger)

        test_preds = predict(model=model,
                             data_loader=test_data_loader,
                             scaler=scaler)
        test_scores = evaluate_predictions(
            preds=test_preds,
            targets=test_targets,
            num_tasks=args.num_tasks,
            metrics=args.metrics,
            dataset_type=args.dataset_type,
            logger=logger,
        )

        if len(test_preds) != 0:
            sum_test_preds += np.array(test_preds)

        # Average test score
        for metric, scores in test_scores.items():
            avg_test_score = np.nanmean(scores)
            info(f"Model {model_idx} test {metric} = {avg_test_score:.6f}")
            writer.add_scalar(f"test_{metric}", avg_test_score, 0)

            if args.show_individual_scores:
                # Individual test scores
                for task_name, test_score in zip(args.task_names, scores):
                    info(
                        f"Model {model_idx} test {task_name} {metric} = {test_score:.6f}"
                    )
                    writer.add_scalar(f"test_{task_name}_{metric}", test_score,
                                      n_iter)
        writer.close()

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(
        preds=avg_test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metrics=args.metrics,
        dataset_type=args.dataset_type,
        logger=logger,
    )

    for metric, scores in ensemble_scores.items():
        # Average ensemble score
        avg_ensemble_test_score = np.nanmean(scores)
        info(f"Ensemble test {metric} = {avg_ensemble_test_score:.6f}")

        # Individual ensemble scores
        if args.show_individual_scores:
            for task_name, ensemble_score in zip(args.task_names, scores):
                info(
                    f"Ensemble test {task_name} {metric} = {ensemble_score:.6f}"
                )

    # Optionally save test preds
    if args.save_preds:
        test_preds_dataframe = pd.DataFrame(
            data={"smiles": test_data.smiles()})

        for i, task_name in enumerate(args.task_names):
            test_preds_dataframe[task_name] = [
                pred[i] for pred in avg_test_preds
            ]

        test_preds_dataframe.to_csv(os.path.join(args.save_dir,
                                                 "test_preds.csv"),
                                    index=False)

    return ensemble_scores
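
A minimal driver for this run_training variant, assuming chemprop-style TrainArgs, get_data, and get_task_names helpers are available (the paths and flags below are placeholders), might look like the following sketch:

# Hypothetical usage sketch; argument handling follows chemprop's TrainArgs and may differ in this fork.
if __name__ == '__main__':
    args = TrainArgs().parse_args([
        '--data_path', 'data/train.csv',   # placeholder CSV with a smiles column plus target columns
        '--dataset_type', 'regression',
        '--save_dir', 'checkpoints/run_0',
    ])
    args.task_names = get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args)
    ensemble_scores = run_training(args, data)
    print(ensemble_scores)
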
def run_meta_training(args: TrainArgs, logger: Logger = None) -> List[float]:
    """
    Trains a model and returns test scores on the model checkpoint with the highest validation score.

    :param args: Arguments.
    :param logger: Logger.
    :return: A list of ensemble scores for each task.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Print command line
    debug('Command line')
    debug(f'python {" ".join(sys.argv)}')

    # Print args
    debug('Args')
    debug(args)

    # Save args
    args.save(os.path.join(args.save_dir, 'args.json'))

    # Set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seed)

    # Get data
    debug('Loading data')
    args.task_names = args.target_columns or get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()
    debug(f'Number of tasks = {args.num_tasks}')

    # Split data
    # debug(f'Splitting data with seed {args.seed}')
    # if args.separate_test_path:
    #     test_data = get_data(path=args.separate_test_path, args=args, features_path=args.separate_test_features_path, logger=logger)
    # if args.separate_val_path:
    #     val_data = get_data(path=args.separate_val_path, args=args, features_path=args.separate_val_features_path, logger=logger)

    # if args.separate_val_path and args.separate_test_path:
    #     train_data = data
    # elif args.separate_val_path:
    #     train_data, _, test_data = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.0, 0.2), seed=args.seed, args=args, logger=logger)
    # elif args.separate_test_path:
    #     train_data, val_data, _ = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger)
    # else:
    #     train_data, val_data, test_data = split_data(data=data, split_type=args.split_type, sizes=args.split_sizes, seed=args.seed, args=args, logger=logger)

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(f'{args.task_names[i]} '
                  f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')

    # if args.save_smiles_splits:
    #     save_smiles_splits(
    #         train_data=train_data,
    #         val_data=val_data,
    #         test_data=test_data,
    #         data_path=args.data_path,
    #         save_dir=args.save_dir
    #     )

    # If features scaling is enabled, this logic will need to move into the task
    # data loader when it creates the datasets!
    # if args.features_scaling:
    #     features_scaler = train_data.normalize_features(replace_nan_token=0)
    #     val_data.normalize_features(features_scaler)
    #     test_data.normalize_features(features_scaler)
    # else:
    #     features_scaler = None
    features_scaler = None  # features scaling is disabled above; keep the placeholder expected by save_checkpoint below

    # args.train_data_size = len(train_data)
    
    # debug(f'Total size = {len(data):,} | '
    #       f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')

    # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only)
    # if args.dataset_type == 'regression':
    #     debug('Fitting scaler')
    #     train_smiles, train_targets = train_data.smiles(), train_data.targets()
    #     scaler = StandardScaler().fit(train_targets)
    #     scaled_targets = scaler.transform(train_targets).tolist()
    #     train_data.set_targets(scaled_targets)
    # else:
    #     scaler = None
    scaler = None  # target scaling is disabled above; keep the placeholder expected by save_checkpoint and evaluate below

    # Get loss and metric functions
    loss_func = get_loss_func(args)
    metric_func = get_metric_func(metric=args.metric)

    # Set up test set evaluation
    # test_smiles, test_targets = test_data.smiles(), test_data.targets()
    # if args.dataset_type == 'multiclass':
    #     sum_test_preds = np.zeros((len(test_smiles), args.num_tasks, args.multiclass_num_classes))
    # else:
    #     sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Automatically determine whether to cache
    if len(data) <= args.cache_cutoff:
        cache = True
        num_workers = 0
    else:
        cache = False
        num_workers = args.num_workers

    # Set up MetaTaskDataLoaders, which take care of the task splits (T_tr, T_val, T_test) under the hood

    assert args.chembl_assay_metadata_pickle_path is not None
    with open(args.chembl_assay_metadata_pickle_path +
            'chembl_128_assay_type_to_names.pickle', 'rb') as handle:
        chembl_128_assay_type_to_names = pickle.load(handle)
    with open(args.chembl_assay_metadata_pickle_path +
            'chembl_128_assay_name_to_type.pickle', 'rb') as handle:
        chembl_128_assay_name_to_type = pickle.load(handle)

    """ 
    Copy GSK implementation of task split 
    We have 5 Task types remaining
    ADME (A)
    Toxicity (T)
    Unassigned (U) 
    Binding (B)
    Functional (F)
    resulting in 902 tasks.

    For T_val, randomly select 10 B and F tasks
    For T_test, select another 10 B and F tasks and allocate all A, T, and U
    tasks to the test split.
    For T_train, allocate the remaining B and F tasks. 

    """
    # import pdb; pdb.set_trace()  # debugging breakpoint
    T_val_num_BF_tasks = args.meta_split_sizes_BF[0]
    T_test_num_BF_tasks = args.meta_split_sizes_BF[1]
    T_val_idx = T_val_num_BF_tasks
    T_test_idx = T_val_num_BF_tasks + T_test_num_BF_tasks

    chembl_id_to_idx = {chembl_id: idx for idx, chembl_id in enumerate(args.task_names)}

    # Shuffle B and F tasks
    randomized_B_tasks = np.copy(chembl_128_assay_type_to_names['B'])
    np.random.shuffle(randomized_B_tasks)
    randomized_B_task_indices = [chembl_id_to_idx[assay] for assay in
            randomized_B_tasks]

    randomized_F_tasks = np.copy(chembl_128_assay_type_to_names['F'])
    np.random.shuffle(randomized_F_tasks)
    randomized_F_task_indices = [chembl_id_to_idx[assay] for assay in
            randomized_F_tasks]

    # Grab B and F indices for T_val
    T_val_B_task_indices = randomized_B_task_indices[:T_val_idx]
    T_val_F_task_indices = randomized_F_task_indices[:T_val_idx]

    # Grab B and F indices for T_test
    T_test_B_task_indices = randomized_B_task_indices[T_val_idx:T_test_idx]
    T_test_F_task_indices = randomized_F_task_indices[T_val_idx:T_test_idx]
    # Grab all A, T and U indices for T_test
    T_test_A_task_indices = [chembl_id_to_idx[assay] for assay in chembl_128_assay_type_to_names['A']]
    T_test_T_task_indices = [chembl_id_to_idx[assay] for assay in chembl_128_assay_type_to_names['T']]
    T_test_U_task_indices = [chembl_id_to_idx[assay] for assay in chembl_128_assay_type_to_names['U']]

    # Slot remaining BF tasks into T_tr
    T_tr_B_task_indices = randomized_B_task_indices[T_test_idx:]
    T_tr_F_task_indices = randomized_F_task_indices[T_test_idx:]

    T_tr = [0] * len(args.task_names)
    T_val = [0] * len(args.task_names)
    T_test = [0] * len(args.task_names)

    # Now make task bit vectors
    for idx_list in (T_tr_B_task_indices, T_tr_F_task_indices):
        for idx in idx_list:
            T_tr[idx] = 1

    for idx_list in (T_val_B_task_indices, T_val_F_task_indices):
        for idx in idx_list:
            T_val[idx] = 1

    for idx_list in (T_test_B_task_indices, T_test_F_task_indices, T_test_A_task_indices, T_test_T_task_indices, T_test_U_task_indices):
        for idx in idx_list:
            T_test[idx] = 1
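    # Illustrative example (hypothetical): with task_names = ['B1', 'F1', 'A1'], B1 kept
    # for training, F1 drawn for T_val, and A1 (ADME) always routed to test, the bit
    # vectors would be T_tr = [1, 0, 0], T_val = [0, 1, 0], T_test = [0, 0, 1].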


    """
    Random task split for testing
    task_indices = list(range(len(args.task_names)))
    np.random.shuffle(task_indices)
    train_task_split, val_task_split, test_task_split = 0.9, 0, 0.1
    train_task_cutoff = int(len(task_indices) * train_task_split)
    train_task_idxs, test_task_idxs = [0] * len(task_indices), [0] * len(task_indices)
    for idx in task_indices[:train_task_cutoff]:
        train_task_idxs[idx] = 1
    for idx in task_indices[train_task_cutoff:]:
        test_task_idxs[idx] = 1
    """

    train_meta_task_data_loader = MetaTaskDataLoader(
            dataset=data,
            tasks=T_tr,
            sizes=args.meta_train_split_sizes,
            args=args,
            logger=logger)

    val_meta_task_data_loader = MetaTaskDataLoader(
            dataset=data,
            tasks=T_val,
            sizes=args.meta_test_split_sizes,
            args=args,
            logger=logger)

    test_meta_task_data_loader = MetaTaskDataLoader(
            dataset=data,
            tasks=T_test,
            sizes=args.meta_test_split_sizes,
            args=args,
            logger=logger)

    # import pdb; pdb.set_trace()  # debugging breakpoint
    for meta_train_batch in train_meta_task_data_loader.tasks():
        for train_task in meta_train_batch:
            print('In inner loop')
            continue

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        # Tensorboard writer
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        makedirs(save_dir)
        try:
            writer = SummaryWriter(log_dir=save_dir)
        except:
            writer = SummaryWriter(logdir=save_dir)

        # Load/build model
        if args.checkpoint_paths is not None:
            debug(f'Loading model {model_idx} from {args.checkpoint_paths[model_idx]}')
            model = load_checkpoint(args.checkpoint_paths[model_idx], logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = MoleculeModel(args)

        debug(model)
        debug(f'Number of parameters = {param_count(model):,}')
        if args.cuda:
            debug('Moving model to cuda')
        model = model.to(args.device)

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

        # Optimizers
        optimizer = build_optimizer(model, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        for epoch in trange(args.epochs):
            debug(f'Epoch {epoch}')

            n_iter = train(
                model=model,
                data_loader=train_data_loader,
                loss_func=loss_func,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                logger=logger,
                writer=writer
            )
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()
            val_scores = evaluate(
                model=model,
                data_loader=val_data_loader,
                num_tasks=args.num_tasks,
                metric_func=metric_func,
                dataset_type=args.dataset_type,
                scaler=scaler,
                logger=logger
            )

            # Average validation score
            avg_val_score = np.nanmean(val_scores)
            debug(f'Validation {args.metric} = {avg_val_score:.6f}')
            writer.add_scalar(f'validation_{args.metric}', avg_val_score, n_iter)

            if args.show_individual_scores:
                # Individual validation scores
                for task_name, val_score in zip(args.task_names, val_scores):
                    debug(f'Validation {task_name} {args.metric} = {val_score:.6f}')
                    writer.add_scalar(f'validation_{task_name}_{args.metric}', val_score, n_iter)

            # Save model checkpoint if improved validation score
            if args.minimize_score and avg_val_score < best_score or \
                    not args.minimize_score and avg_val_score > best_score:
                best_score, best_epoch = avg_val_score, epoch
                save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)        

        # Evaluate on test set using model with best validation score
        info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'), device=args.device, logger=logger)
        
        test_preds = predict(
            model=model,
            data_loader=test_data_loader,
            scaler=scaler
        )
        test_scores = evaluate_predictions(
            preds=test_preds,
            targets=test_targets,
            num_tasks=args.num_tasks,
            metric_func=metric_func,
            dataset_type=args.dataset_type,
            logger=logger
        )

        if len(test_preds) != 0:
            sum_test_preds += np.array(test_preds)

        # Average test score
        avg_test_score = np.nanmean(test_scores)
        info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')
        writer.add_scalar(f'test_{args.metric}', avg_test_score, 0)

        if args.show_individual_scores:
            # Individual test scores
            for task_name, test_score in zip(args.task_names, test_scores):
                info(f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}')
                writer.add_scalar(f'test_{task_name}_{args.metric}', test_score, n_iter)
        writer.close()

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(
        preds=avg_test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        dataset_type=args.dataset_type,
        logger=logger
    )

    # Average ensemble score
    avg_ensemble_test_score = np.nanmean(ensemble_scores)
    info(f'Ensemble test {args.metric} = {avg_ensemble_test_score:.6f}')

    # Individual ensemble scores
    if args.show_individual_scores:
        for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
            info(f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}')

    return ensemble_scores