Example #1
from collections import Counter

import tqdm

# ProteinnetDataset and get_data_path come from the surrounding project.
def contact_map_distribution(max_len):
    d = ProteinnetDataset(get_data_path(), 'train')
    c = Counter()
    for row in tqdm.tqdm(d):
        contact_map = row[-2][:max_len, :max_len].flatten()
        c.update(contact_map)
    return c
Example #2
from collections import Counter

import tqdm

# ProteinModificationDataset and get_data_path come from the surrounding project.
def protein_modification_distribution(max_len):
    d = ProteinModificationDataset(get_data_path(), 'train')
    c = Counter()
    for row in tqdm.tqdm(d):
        mod_indic = row[-1]
        c.update(mod_indic[:max_len])
    return c
Example #3
from collections import Counter

import tqdm

# BindingSiteDataset and get_data_path come from the surrounding project.
def binding_site_distribution(max_len):
    d = BindingSiteDataset(get_data_path(), 'train')
    c = Counter()
    for row in tqdm.tqdm(d):
        site_indic = row[-1]
        c.update(site_indic[:max_len])
    return c
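
The three helpers above share one pattern: iterate over a training split, slice each label array to max_len, and tally label frequencies in a Counter. A minimal usage sketch, assuming the project package is importable; max_len=512 is an arbitrary choice:

dist = binding_site_distribution(max_len=512)
total = sum(dist.values())
for label, count in dist.most_common():
    print(f'label {label}: {count} ({count / total:.1%})')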
Example #4

    # Excerpt from the script's main block; `parser` is an argparse parser defined earlier in the file.
    args = parser.parse_args()
    print(args)

    if args.model_version and args.model_dir:
        raise ValueError('Cannot specify both model version and directory')

    if args.num_sequences is not None and not args.shuffle:
        print(
            'WARNING: You are using a subset of sequences and you are not shuffling the data. This may result '
            'in a skewed sample.')
    cuda = not args.no_cuda

    torch.manual_seed(args.seed)

    if args.dataset == 'proteinnet':
        dataset = ProteinnetDataset(get_data_path(), 'train')
    elif args.dataset == 'secondary':
        dataset = SecondaryStructureDataset(get_data_path(), 'train')
    elif args.dataset == 'binding_sites':
        dataset = BindingSiteDataset(get_data_path(), 'train')
    elif args.dataset == 'protein_modifications':
        dataset = ProteinModificationDataset(get_data_path(), 'train')
    else:
        raise ValueError(f"Invalid dataset id: {args.dataset}")

    if not args.num_sequences:
        raise NotImplementedError

    if args.model == 'bert':
        if args.model_dir:
            model_version = args.model_dir
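
The Example #4 excerpt begins after its argparse setup was cut off. A hypothetical reconstruction of the flags it implies (names inferred from the args.* accesses above; defaults are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
                    choices=['proteinnet', 'secondary', 'binding_sites',
                             'protein_modifications'])
parser.add_argument('--model', default='bert')
parser.add_argument('--model-version', default=None)
parser.add_argument('--model-dir', default=None)
parser.add_argument('--num-sequences', type=int, default=None)
parser.add_argument('--shuffle', action='store_true')
parser.add_argument('--no-cuda', action='store_true')
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()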
Example #5
def run_train(task: str,
              num_hidden_layers: int,
              one_vs_all_label: typing.Optional[str] = None,
              attention_probe: bool = False,
              label_scheme: typing.Optional[str] = None,
              learning_rate: float = 1e-4,
              batch_size: int = 1024,
              num_train_epochs: int = 10,
              num_log_iter: int = 20,
              fp16: bool = False,
              warmup_steps: int = 10000,
              gradient_accumulation_steps: int = 1,
              loss_scale: int = 0,
              max_grad_norm: float = 1.0,
              exp_name: typing.Optional[str] = None,
              log_dir: str = './logs',
              eval_freq: int = 1,
              save_freq: typing.Union[int, str] = 1,
              no_cuda: bool = False,
              seed: int = 42,
              local_rank: int = -1,
              num_workers: int = 0,
              debug: bool = False,
              log_level: typing.Union[str, int] = logging.INFO,
              patience: int = -1,
              max_seq_len: typing.Optional[int] = None) -> None:
    # SETUP AND LOGGING CODE #
    input_args = locals()
    device, n_gpu, is_master = utils.setup_distributed(local_rank, no_cuda)

    data_dir = get_data_path()
    output_dir = data_dir / 'probing'
    exp_dir = f'{(exp_name + "_") if exp_name else ""}{task}_{(one_vs_all_label + "_") if one_vs_all_label else ""}' \
              f'{"attn_" if attention_probe else ""}{num_hidden_layers}'
    save_path = Path(output_dir) / exp_dir

    if is_master:
        # save all the hidden parameters.
        save_path.mkdir(parents=True, exist_ok=True)
        with (save_path / 'args.json').open('w') as f:
            json.dump(input_args, f)

    utils.barrier_if_distributed()
    utils.setup_logging(local_rank, save_path, log_level)
    utils.set_random_seeds(seed, n_gpu)

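    # TASK-SPECIFIC MODEL AND DATA SETUP #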
    if task == 'secondary':
        num_labels = 2
        if attention_probe:
            model = ProteinBertForLinearSequenceToSequenceProbingFromAttention.from_pretrained(
                'bert-base',
                num_hidden_layers=num_hidden_layers,
                num_labels=num_labels)
        else:
            model = ProteinBertForLinearSequenceToSequenceProbing.from_pretrained(
                'bert-base',
                num_hidden_layers=num_hidden_layers,
                num_labels=num_labels)
        if label_scheme == 'ss4':
            label = int(one_vs_all_label)
        else:
            label = one_vs_all_label
        train_dataset = SecondaryStructureOneVsAllDataset(
            data_dir, 'train', label_scheme, label)
        valid_dataset = SecondaryStructureOneVsAllDataset(
            data_dir, 'valid', label_scheme, label)
    elif task == 'binding_sites':
        num_labels = 2
        if attention_probe:
            model = ProteinBertForLinearSequenceToSequenceProbingFromAttention.from_pretrained(
                'bert-base',
                num_hidden_layers=num_hidden_layers,
                num_labels=num_labels)
        else:
            model = ProteinBertForLinearSequenceToSequenceProbing.from_pretrained(
                'bert-base',
                num_hidden_layers=num_hidden_layers,
                num_labels=num_labels)
        train_dataset = BindingSiteDataset(data_dir, 'train')
        valid_dataset = BindingSiteDataset(data_dir, 'valid')
    elif task == 'contact_map':
        num_labels = 2
        if attention_probe:
            model = ProteinBertForContactPredictionFromAttention.from_pretrained(
                'bert-base', num_hidden_layers=num_hidden_layers)
        else:
            model = ProteinBertForContactProbing.from_pretrained(
                'bert-base', num_hidden_layers=num_hidden_layers)
        train_dataset = ProteinnetDataset(data_dir,
                                          'train',
                                          max_seq_len=max_seq_len)
        valid_dataset = ProteinnetDataset(data_dir,
                                          'valid',
                                          max_seq_len=max_seq_len)
    else:
        raise NotImplementedError

    model = model.to(device)
    optimizer = utils.setup_optimizer(model, learning_rate)
    # viz = visualization.get(log_dir, exp_dir, local_rank, debug=debug)
    # viz.log_config(input_args)
    # viz.log_config(model.config.to_dict())
    # viz.watch(model)
    viz = None

    train_loader = utils.setup_loader(train_dataset, batch_size, local_rank,
                                      n_gpu, gradient_accumulation_steps,
                                      num_workers)
    valid_loader = utils.setup_loader(valid_dataset, batch_size, local_rank,
                                      n_gpu, gradient_accumulation_steps,
                                      num_workers)

    num_train_optimization_steps = utils.get_num_train_optimization_steps(
        train_dataset, batch_size, num_train_epochs)

    logger.info(f"device: {device} "
                f"n_gpu: {n_gpu}, "
                f"distributed_training: {local_rank != -1}, "
                f"16-bits training: {fp16}")

    runner = BackwardRunner(model, optimizer, gradient_accumulation_steps,
                            device, n_gpu, fp16, local_rank, max_grad_norm,
                            warmup_steps, num_train_optimization_steps)

    runner.initialize_fp16()

    start_epoch = 0
    runner.initialize_distributed_model()

    num_train_optimization_steps = utils.get_num_train_optimization_steps(
        train_dataset, batch_size, num_train_epochs)
    is_master = local_rank in (-1, 0)

    if isinstance(save_freq, str) and save_freq != 'improvement':
        raise ValueError(
            f"Only recongized string value for save_freq is 'improvement'"
            f", received: {save_freq}")

    if save_freq == 'improvement' and eval_freq <= 0:
        raise ValueError(
            "Cannot set save_freq to 'improvement' and eval_freq < 0")

    num_trainable_parameters = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Batch size = %d", batch_size)
    logger.info("  Num epochs = %d", num_train_epochs)
    logger.info("  Num train steps = %d", num_train_optimization_steps)
    logger.info("  Num parameters = %d", num_trainable_parameters)

    best_val_loss = float('inf')
    num_evals_no_improvement = 0
    # Initialize so the final json.dump of results cannot hit a NameError
    # if no evaluation epoch ever runs.
    metrics_to_save: typing.Dict[str, typing.Any] = {}

    def do_save(epoch_id: int, num_evals_no_improvement: int) -> bool:
        if not is_master:
            return False
        if isinstance(save_freq, int):
            return ((epoch_id + 1) % save_freq == 0) or ((epoch_id + 1)
                                                         == num_train_epochs)
        else:
            return num_evals_no_improvement == 0

    utils.barrier_if_distributed()

    metrics = ['accuracy', 'precision', 'recall', 'f1']
    metric_functions = [accuracy, precision, recall, f1]

    # ACTUAL TRAIN/EVAL LOOP #
    with utils.wrap_cuda_oom_error(local_rank, batch_size, n_gpu,
                                   gradient_accumulation_steps):
        for epoch_id in range(start_epoch, num_train_epochs):
            run_train_epoch(epoch_id, train_loader, runner, viz, num_log_iter,
                            gradient_accumulation_steps)
            if eval_freq > 0 and (epoch_id + 1) % eval_freq == 0:
                val_loss, metric = run_valid_epoch(epoch_id, valid_loader,
                                                   runner, viz, is_master)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    num_evals_no_improvement = 0
                    if task == 'contact_map':
                        outputs, seq_lens = run_eval_epoch(
                            valid_loader, runner, get_sequence_lengths=True)
                    else:
                        outputs = run_eval_epoch(valid_loader,
                                                 runner,
                                                 get_sequence_lengths=False)
                    target = [el['target'] for el in outputs]
                    prediction = [el['prediction'] for el in outputs]

                    if task == 'contact_map':
                        # Reshape 2d to 1d
                        # Shape batch_size, seq_len, seq_len
                        prediction = [
                            torch.tensor(prediction_matrix).view(-1,
                                                                 2).tolist()
                            for prediction_matrix in prediction
                        ]
                        target = [
                            torch.tensor(target_matrix).view(-1).tolist()
                            for target_matrix in target
                        ]

                    metrics_to_save = {
                        name: metric(target, prediction)
                        for name, metric in zip(metrics, metric_functions)
                    }
                    if task == 'contact_map':
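                        # precision@(L/5): k is one fifth of each sequence
                        # length, the standard contact-prediction cutoff.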
                        ks = [int(round(seq_len / 5)) for seq_len in seq_lens]
                        metrics_to_save['precision_at_k'] = precision_at_ks(
                            ks, target, prediction)
                    elif task == 'binding_sites':
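                        # Here k is one twentieth of the unpadded length;
                        # label -1 marks padded positions.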
                        seq_lens = []
                        for target_array in target:
                            mask = target_array != -1
                            seq_lens.append(mask.sum())
                        ks = [int(round(seq_len / 20)) for seq_len in seq_lens]
                        metrics_to_save['precision_at_k'] = precision_at_ks(
                            ks, target, prediction)
                    print(metrics_to_save)
                    metrics_to_save['loss'] = val_loss
                else:
                    num_evals_no_improvement += 1

            # Save trained model
            if do_save(epoch_id, num_evals_no_improvement):
                logger.info("** ** * Saving trained model ** ** * ")
                # Only save the model itself
                runner.save_state(save_path, epoch_id)
                logger.info(f"Saving model checkpoint to {save_path}")

            utils.barrier_if_distributed()
            if patience > 0 and num_evals_no_improvement >= patience:
                logger.info(
                    f"Finished training at epoch {epoch_id} because no "
                    f"improvement for {num_evals_no_improvement} epochs.")
                logger.log(35, f"Best Val Loss: {best_val_loss}")
                if local_rank != -1:
                    # If training is distributed, raise this error. It sends a signal to
                    # the master process which lets it kill other processes and terminate
                    # without actually reporting an error. See utils/distributed_utils.py
                    # for the signal handling code.
                    raise errors.EarlyStopping
                else:
                    break
    logger.info(f"Finished training after {num_train_epochs} epochs.")
    if best_val_loss != float('inf'):
        logger.log(35, f"Best Val Loss: {best_val_loss}")

    with open(save_path / 'results.json', 'w') as outfile:
        json.dump(metrics_to_save, outfile)

    del model
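
run_train ties the probing pipeline together: it builds a task-specific probe model, trains with optional early stopping on validation loss, and writes the best epoch's metrics to results.json. A minimal invocation sketch, assuming the module is importable; every argument value below is illustrative:

run_train(task='binding_sites',
          num_hidden_layers=6,
          batch_size=64,
          num_train_epochs=10,
          patience=3)

The precision_at_ks helper used above is defined elsewhere in the project. A hypothetical per-sequence sketch of the idea it computes (rank positions by predicted positive probability, keep the top k, and measure the fraction that are true positives):

def precision_at_k(k, labels, probs):
    # labels are 0/1 with -1 marking padding; probs are [p_negative, p_positive]
    # pairs, mirroring the reshaping code above. Names here are assumptions.
    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0][1], reverse=True)
    top = [label for _, label in ranked if label != -1][:k]
    return sum(top) / max(len(top), 1)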
Example #6
        # Excerpt from a plotting loop; `scores`, `ax`, `i`, `feature`,
        # `report_dir`, and `filetype` are defined earlier in the function.
        diffs = [scores[i] - scores[i - 1] for i in range(1, 12)]
        ax[i].bar(list(range(11)), diffs)
        ax[i].tick_params(labelsize=6)
        ax[i].set_ylabel(feature.replace('Contact Map', 'Contact'), fontsize=8)
        ax[i].yaxis.tick_right()
    plt.xticks(list(range(11)), list(range(2, 13)))
    plt.xlabel('Layer', fontsize=8)
    fname = report_dir / f'multichart_layer_delta_probing.{filetype}'
    print('Saving', fname)
    plt.savefig(fname, format=filetype, bbox_inches='tight')
    plt.close()


if __name__ == "__main__":

    data_path = get_data_path()

    features = []
    feature_scores = []

    # Probing sec struct results
    ss_cds = [0, 1, 2]
    ss_names = ss4_names
    for ss_cd in ss_cds:
        feature = ss_names[ss_cd]
        features.append(feature)
        scores = [0] * 12
        for num_layers in list(range(1, 13)):
            fname = data_path / 'probing' / f'secondary_{ss_cd}_{num_layers}/results.json'
            try:
                with open(fname) as infile: