示例#1
0
def get_dataset(cfg, trans, train=True):
    """Create the HCP20 dataset together with a graph data loader over it.

    Args:
        cfg: configuration mapping. Keys read here: ``sub_list_train`` /
            ``sub_list_val``, ``batch_size`` and ``shuffling`` (training only),
            ``dataset_dir``, ``same_size``, ``return_edges``, ``n_workers``
            and ``dataset`` (used only in the log message).
        trans: list of transforms, chained via ``transforms.Compose``.
        train: if truthy, use the training subject list with the configured
            batch size and shuffle flag; otherwise use the validation subject
            list with a batch size of 1 and shuffling disabled.

    Returns:
        Tuple ``(dataset, dataloader)``.
    """
    if train:
        subjects = cfg['sub_list_train']
        # batch_size is cast with int(); NOTE(review): presumably the config
        # may store it as a string — confirm against the config loader.
        loader_batch = int(cfg['batch_size'])
        do_shuffle = cfg['shuffling']
    else:
        # Validation pass: one sample per batch, in a fixed order.
        subjects = cfg['sub_list_val']
        loader_batch = 1
        do_shuffle = False

    data_root = cfg['dataset_dir']
    keep_same_size = cfg['same_size']
    pipeline = transforms.Compose(trans)
    with_edges = cfg['return_edges']
    dataset = ds.HCP20Dataset(subjects,
                              data_root,
                              same_size=keep_same_size,
                              transform=pipeline,
                              return_edges=with_edges,
                              load_one_full_subj=False)

    worker_count = int(cfg['n_workers'])
    dataloader = gDataLoader(dataset,
                             batch_size=loader_batch,
                             shuffle=do_shuffle,
                             num_workers=worker_count,
                             pin_memory=True)

    summary = "Dataset %s loaded, found %d samples" % (cfg['dataset'],
                                                       len(dataset))
    print(summary)
    return dataset, dataloader
示例#2
0
    compound_config = config_dict['compound']
    protein_config = config_dict['protein']

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print('used device ', device)

    print('Creating model....')
    net = CPIGnnCnnModel(compound_config, protein_config).to(device)
    net.double()

    print('loading trained model')
    net.load_state_dict(
        torch.load(config.get_value('model'), map_location=device))

    print('Loading testing data ....')
    test_dataset = create_dataset(root, config.get_value('test'),
                                  compound_config, protein_config)
    test_dataloader = gDataLoader(test_dataset,
                                  batch_size=batch_size,
                                  shuffle=False)

    output_file = None
    if config.get_value('output'):
        output_file = config.get_value('output')
    explain_cpi_prediction(test_dataloader,
                           net,
                           device,
                           max_length=1000,
                           file_name=output_file)
    # a bit of loading speed, but we are sure we load correctly streamlines
    cfg['same_size'] = False

    # check available memory to decide how many streams sample
    curr_device = torch.cuda.current_device()
    cfg['fixed_size'] = get_max_batchsize(curr_device)

    dataset = TractDataset(trk_fn,
                           transform=TestSampling(cfg['fixed_size']),
                           return_edges=True,
                           split_obj=True,
                           same_size=cfg['same_size'])

    dataloader = gDataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=0,
                             pin_memory=True)

    classifier = get_model(cfg)

    if DEVICE == 'cuda':
        torch.cuda.set_device(DEVICE)
        torch.cuda.current_device()

    if cfg['weights_path'] == '':
        cfg['weights_path'] = glob.glob(cfg['exp_path'] + '/models/best*')[0]
    state = torch.load(cfg['weights_path'], map_location=DEVICE)

    classifier.load_state_dict(state)
    classifier.to(DEVICE)
示例#4
0
    print('dropout', protein_config.drop_out)
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print('used device ', device)

    print('Creating model....')
    net = CPIGnnCnnModel(compound_config, protein_config).to(device)
    net = net.double()

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

    print('Loading training data ....')
    train_dataset = create_dataset(root, config.get_value('train'),
                                   compound_config, protein_config)
    train_dataloader = gDataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)

    val_dataloader = None
    training_stopper = None
    if config.get_value('val'):
        print('Loading validation data ....')
        training_stopper = EarlyStopping(patience=10,
                                         verbose=True,
                                         delta=0,
                                         path=config.get_value('output'))
        val_dataset = create_dataset(root, config.get_value('val'),
                                     compound_config, protein_config)
        val_dataloader = gDataLoader(val_dataset,
                                     batch_size=batch_size,
                                     shuffle=False)