Exemple #1
0
def _binary_match_counts(prediction, target):
    """Count predictions and hits, overall and split by predicted value.

    Returns a 6-tuple of Python ints:
    ``(total, predicted_ones, predicted_zeros, correct, correct_ones,
    correct_zeros)``. ``total`` is ``prediction.size(0)``; the ones/zeros
    figures are keyed on the *predicted* value, not on the target.

    NOTE(review): assumes ``prediction`` is a 1-D tensor of 0/1 entries with
    the same shape as ``target`` - confirm against ``hungarian``.
    """
    is_one = prediction == 1
    is_zero = prediction == 0
    is_correct = prediction == target
    return (
        prediction.size(0),
        is_one.sum().item(),
        is_zero.sum().item(),
        is_correct.sum().item(),
        (is_correct & is_one).sum().item(),
        (is_correct & is_zero).sum().item(),
    )


def _accuracy_suffix(counts):
    """Format the accuracy part of a progress line, or '' if not reportable.

    ``counts`` is a 6-sequence as returned by ``_binary_match_counts``. The
    per-class accuracies divide by the number of predicted ones/zeros, so
    nothing is reported unless both are non-zero (this also covers the case
    where accuracy was not measured and all counts are still zero).
    """
    total, total_ones, total_zeros, correct, correct_ones, correct_zeros = counts
    if total_ones == 0 or total_zeros == 0:
        return ''
    return ' Total_Accuracy: %.3f | Ones_Accuracy: %.3f | Zeros_Accuracy: %.3f |' % (
        100 * correct / total,
        100 * correct_ones / total_ones,
        100 * correct_zeros / total_zeros)


def model_training(data_list_train, data_list_test, epochs, acc_epoch, acc_epoch2, save_model_epochs, validation_epoch, batchsize, logfilename, load_checkpoint= None):
    """Train ``completeNet`` on CUDA, with periodic accuracy, checkpoints and validation.

    Every progress line is printed and also written to
    ``./logfiles/<logfilename>``. Training stops early if the loss becomes NaN.

    Args:
        data_list_train: data passed to ``DataListLoader`` for training.
        data_list_test: data passed to ``DataListLoader`` for validation.
        epochs: index of the last epoch to run (inclusive).
        acc_epoch: training accuracy is computed on epochs divisible by this.
        acc_epoch2: validation accuracy is computed on epochs divisible by this.
        save_model_epochs: a checkpoint is written to
            ``./models/epoch_<epoch>.pth`` on epochs divisible by this.
        validation_epoch: validation runs on epochs divisible by this.
        batchsize: batch size for both loaders.
        logfilename: name of the log file created under ``./logfiles/``
            (opened with mode ``w+``).
        load_checkpoint: optional path given to ``torch.load``; the checkpoint
            must provide ``model_state_dict``, ``optimizer_state_dict`` and
            ``epoch``. Training resumes at ``epoch + 1``.

    Returns:
        None.
    """
    import math
    import os

    def report(message):
        # Every progress line goes to both stdout and the log file.
        print(message)
        logging.info(message)

    # Create the output folders up front so neither the log handler nor a
    # checkpoint save (possibly many epochs in) fails on a missing directory.
    os.makedirs('./logfiles', exist_ok=True)
    os.makedirs('./models', exist_ok=True)

    #logging
    logging.basicConfig(level=logging.DEBUG, filename='./logfiles/'+logfilename, filemode="w+",
                        format="%(message)s")
    trainloader = DataListLoader(data_list_train, batch_size=batchsize, shuffle=True)
    testloader = DataListLoader(data_list_test, batch_size=batchsize, shuffle=True)
    device = torch.device('cuda')
    complete_net = completeNet()
    complete_net = DataParallel(complete_net)
    complete_net = complete_net.to(device)

    #train parameters
    weights = [10, 1]
    optimizer = torch.optim.Adam(complete_net.parameters(), lr=0.001, weight_decay=0.001)

    #resume training
    initial_epoch = 1
    if load_checkpoint is not None:
        checkpoint = torch.load(load_checkpoint)
        complete_net.load_state_dict(checkpoint['model_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initial_epoch = checkpoint['epoch'] + 1

    complete_net.train()

    for epoch in range(initial_epoch, epochs+1):
        measure_train_acc = epoch % acc_epoch == 0 and epoch != 0
        epoch_counts = [0] * 6
        running_loss = 0
        batches_num = 0
        loss_is_nan = False
        for batch in trainloader:
            batches_num += 1
            # Forward-Backpropagation
            output, output2, ground_truth, ground_truth2, det_num, tracklet_num = complete_net(batch)
            optimizer.zero_grad()
            loss = weighted_binary_cross_entropy(output, ground_truth, weights)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            ##Accuracy
            batch_counts = (0,) * 6
            if measure_train_acc:
                # Hungarian method, clean up
                cleaned_output = hungarian(output2, ground_truth2, det_num, tracklet_num)
                batch_counts = _binary_match_counts(cleaned_output, ground_truth2)
                epoch_counts = [acc + new for acc, new in zip(epoch_counts, batch_counts)]
            if math.isnan(loss_value):
                print("Error")
                loss_is_nan = True
                break
            report(('Epoch: [%d] | Batch: [%d] | Training_Loss: %.3f |' %
                    (epoch, batches_num, loss_value)) + _accuracy_suffix(batch_counts))
            running_loss += loss_value
        # A NaN loss aborts the whole run, not just the current epoch.
        if loss_is_nan:
            print("Error")
            break
        report(('Epoch: [%d] | Training_Loss: %.3f |' %
                (epoch, running_loss / batches_num)) + _accuracy_suffix(epoch_counts))
        # save model
        if epoch % save_model_epochs == 0 and epoch != 0:
            torch.save({
                        'epoch': epoch,
                        'model_state_dict': complete_net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': running_loss,
                        }, './models/epoch_'+str(epoch)+'.pth')

        #validation
        if epoch % validation_epoch == 0 and epoch != 0:
            measure_val_acc = epoch % acc_epoch2 == 0
            val_counts = [0] * 6
            val_loss_sum = 0
            val_batches = 0
            # Evaluate in eval mode so layers such as dropout / batch norm do
            # not behave as in training, then always restore train mode.
            complete_net.eval()
            try:
                with torch.no_grad():
                    for batch in testloader:
                        val_batches += 1
                        output, output2, ground_truth, ground_truth2, det_num, tracklet_num = complete_net(batch)
                        loss = weighted_binary_cross_entropy(output, ground_truth, weights)
                        val_loss_sum += loss.item()
                        ##Accuracy
                        if measure_val_acc:
                            # Hungarian method, clean up
                            cleaned_output = hungarian(output2, ground_truth2, det_num, tracklet_num)
                            val_counts = [acc + new for acc, new in zip(
                                val_counts, _binary_match_counts(cleaned_output, ground_truth2))]
            finally:
                complete_net.train()
            report(('Epoch: [%d] | Validation_Loss: %.3f |' %
                    (epoch, val_loss_sum / val_batches)) + _accuracy_suffix(val_counts))
]

#print('Can we use GPU? ',torch.cuda.is_available())
# Select which GPUs we can see
#os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"

# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

myGCN = multiViewGCN(2, 4, device)

# With more than one visible GPU, wrap the model in DataParallel.
if not torch.cuda.device_count() <= 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    myGCN = DataParallel(myGCN)

# Adam without weight decay.
optimizer = torch.optim.Adam(myGCN.parameters(), lr=0.0005)

nEpochs = 20


def train():
    myGCN.train()
    loss_all = 0
    loss_func = torch.nn.CrossEntropyLoss()

    for data0, data1, data2 in zip(train_loader[0], train_loader[1],
                                   train_loader[2]):
        data0 = data0.to(device)
        data1 = data1.to(device)
        data2 = data2.to(device)
Exemple #3
0
def train_cross_validation(model_cls,
                           dataset,
                           dropout=0.0,
                           lr=1e-3,
                           weight_decay=1e-2,
                           num_epochs=200,
                           n_splits=10,
                           use_gpu=True,
                           dp=False,
                           ddp=False,
                           comment='',
                           tb_service_loc='192.168.192.57:6007',
                           batch_size=1,
                           num_workers=0,
                           pin_memory=False,
                           cuda_device=None,
                           tb_dir='runs',
                           model_save_dir='saved_models',
                           res_save_dir='res',
                           fold_no=None,
                           saved_model_path=None,
                           device_ids=None,
                           patience=20,
                           seed=None,
                           fold_seed=None,
                           save_model=False,
                           is_reg=True,
                           live_loss=True,
                           domain_cls=True,
                           final_cls=True):
    """
    Train and validate one model per cross-validation fold.

    For every split yielded by ``multi_site_cv_split`` a fresh
    ``model_cls(writer)`` is trained with AdamW for at most ``num_epochs``
    epochs. The model is called on a list of samples and must return
    ``(y_hat, domain_yhat, reg)``. Per-epoch NLL loss, accuracy,
    regularisation loss and learning rate are written to TensorBoard. A fold
    stops early once the validation NLL loss has not improved for
    ``patience`` consecutive epochs; the remaining folds are still trained.

    :param model_cls: model class, instantiated per fold as ``model_cls(writer)``
    :param dataset: dataset instance; ``dataset.data.y``, ``dataset.data.site_id``
        and ``dataset.data.subject_id`` are read for splitting and label dtypes
    :param dropout: float; not used by this function (only recorded in the logged args)
    :param lr: float, AdamW learning rate
    :param weight_decay: float, AdamW weight decay
    :param num_epochs: maximum number of epochs per fold
    :param n_splits: number of folds
    :param use_gpu: bool; seed CUDA and move the model to the selected device
    :param dp: bool; wrap the model in ``DataParallel``
    :param ddp: not used by this function
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location (only printed)
    :param batch_size: per-device batch size; multiplied by the device count when ``dp``
    :param num_workers: int, DataLoader args
    :param pin_memory: bool, DataLoader args
    :param cuda_device: explicit device used as-is when not ``None``
    :param tb_dir: root directory for the TensorBoard event files
    :param model_save_dir: root directory for saved model weights
    :param res_save_dir: not used by this function
    :param fold_no: int; when given, only this (1-based) fold is trained
    :param saved_model_path: optional state-dict path loaded into each fold's model
    :param device_ids: device ids for ``DataParallel``
    :param patience: for early stopping
    :param seed: random seed; derived from the clock when ``None``
    :param fold_seed: int; seeds the fold split, shuffling is enabled only when truthy
    :param save_model: bool; save the model weights after every validation phase
    :param is_reg: bool; add ``reg.sum()`` to the optimised loss
    :param live_loss: bool; plot live losses with ``PlotLosses``
    :param domain_cls: bool; optimise the domain (site) classification loss
    :param final_cls: bool; optimise the final classification loss
    :raises ValueError: if both ``domain_cls`` and ``final_cls`` are false
    :return: None
    """
    # Must stay the first statement so that only the call arguments are captured.
    saved_args = locals()
    if not (domain_cls or final_cls):
        raise ValueError(
            "at least one of domain_cls and final_cls must be True")
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    model_name = model_cls.__name__

    # `is not None` rather than truthiness, so that device index 0 is honoured.
    if cuda_device is not None:
        device = cuda_device
    elif device_ids and dp:
        device = device_ids[0]
    else:
        device = torch.device(
            'cuda' if torch.cuda.is_available() and use_gpu else 'cpu')

    if dp and device_ids is not None:
        device_count = len(device_ids)
    elif dp:
        device_count = torch.cuda.device_count()
    else:
        device_count = 1

    batch_size = batch_size * device_count

    # TensorBoard
    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".
              format(log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    # model
    criterion = nn.NLLLoss()

    print("Training {0} {1} models for cross validation...".format(
        n_splits, model_name))
    fold_splits = multi_site_cv_split(dataset.data.y,
                                      dataset.data.site_id,
                                      dataset.data.subject_id,
                                      n_splits,
                                      random_state=fold_seed,
                                      shuffle=bool(fold_seed))

    fold = 0
    for train_idx, val_idx in tqdm_notebook(fold_splits,
                                            desc='CV',
                                            leave=False):
        fold += 1
        liveloss = PlotLosses() if live_loss else None

        # for a specific fold
        if fold_no is not None and fold != fold_no:
            continue

        # Honour the caller-supplied roots (the defaults match the previously
        # hard-coded 'runs' / 'saved_models').
        writer = SummaryWriter(
            log_dir=osp.join(tb_dir, log_dir_base + str(fold)))
        fold_save_prefix = osp.join(model_save_dir, log_dir_base + str(fold))

        print("creating dataloader tor fold {}".format(fold))

        train_dataset, val_dataset = norm_train_val(dataset, train_idx,
                                                    val_idx)

        model = model_cls(writer)

        # Batches are kept as plain lists of samples (identity collate).
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
        val_dataloader = DataLoader(val_dataset,
                                    shuffle=False,
                                    batch_size=batch_size,
                                    collate_fn=lambda data_list: data_list,
                                    num_workers=num_workers,
                                    pin_memory=pin_memory)

        if fold == 1 or fold_no is not None:
            print(model)
            writer.add_text('model_summary', model.__repr__())
            writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=lr,
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=weight_decay,
                                      amsgrad=False)
        # NOTE(review): scheduler.step() is never called in this function; the
        # scheduler is only used below to read the optimizer's current lr.
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=10,
                                           total_epoch=5)
        if dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        if saved_model_path is not None:
            model.load_state_dict(torch.load(saved_model_path))

        best_map, patience_counter, best_score = 0.0, 0, np.inf
        for epoch in tqdm_notebook(range(1, num_epochs + 1),
                                   desc='Epoch',
                                   leave=False):
            logs = {}
            stop_early = False

            for phase in ['train', 'validation']:

                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = val_dataloader

                # Logging
                running_total_loss = 0.0
                running_reg_loss = 0.0
                running_nll_loss = 0.0
                label_chunks, predicted_chunks = [], []

                for data_list in tqdm_notebook(dataloader,
                                               desc=phase,
                                               leave=False):

                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids
                                             is not None else 'cuda'))

                    y_hat, domain_yhat, reg = model(data_list)

                    # Gather the class and site labels of every sample in the batch.
                    y = torch.tensor([],
                                     dtype=dataset.data.y.dtype,
                                     device=device)
                    domain_y = torch.tensor([],
                                            dtype=dataset.data.site_id.dtype,
                                            device=device)
                    for data in data_list:
                        y = torch.cat([y, data.y.view(-1).to(device)])
                        domain_y = torch.cat(
                            [domain_y,
                             data.site_id.view(-1).to(device)])

                    loss = criterion(y_hat, y)
                    domain_loss = criterion(domain_yhat, domain_y)

                    if domain_cls and final_cls:
                        total_loss = (loss + domain_loss).sum()
                        _, predicted = torch.max(y_hat, 1)
                        label = y
                    elif final_cls:
                        total_loss = loss
                        _, predicted = torch.max(y_hat, 1)
                        label = y
                    else:
                        total_loss = domain_loss
                        _, predicted = torch.max(domain_yhat, 1)
                        label = domain_y

                    if is_reg:
                        # Out-of-place on purpose: `total_loss` may be the very
                        # same tensor as `loss`, and an in-place add would leak
                        # the regulariser into the logged NLL loss.
                        total_loss = total_loss + reg.sum()

                    if phase == 'train':
                        optimizer.zero_grad()
                        total_loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    running_nll_loss += loss.item()
                    running_total_loss += total_loss.item()
                    running_reg_loss += reg.sum().item()

                    label_chunks.append(label.detach().float().view(-1).cpu())
                    predicted_chunks.append(
                        predicted.detach().float().view(-1).cpu())

                epoch_label = torch.cat(label_chunks)
                epoch_predicted = torch.cat(predicted_chunks)
                accuracy = sklearn.metrics.accuracy_score(
                    epoch_label, epoch_predicted)
                num_batches = len(dataloader)
                epoch_total_loss = running_total_loss / num_batches
                epoch_nll_loss = running_nll_loss / num_batches
                epoch_reg_loss = running_reg_loss / num_batches

                writer.add_scalars(
                    'nll_loss', {'{}_nll_loss'.format(phase): epoch_nll_loss},
                    epoch)
                writer.add_scalars('accuracy',
                                   {'{}_accuracy'.format(phase): accuracy},
                                   epoch)
                if epoch_reg_loss != 0:
                    writer.add_scalars(
                        'reg_loss',
                        {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)

                # Save Model & Early Stopping
                if phase == 'validation':
                    model_save_path = fold_save_prefix + '-{}-{}-{:.3f}-{:.3f}'.format(
                        model_name, epoch, accuracy, epoch_nll_loss)
                    # best score
                    if accuracy > best_map:
                        best_map = accuracy
                        model_save_path = model_save_path + '-best'

                    # Early stopping tracks the validation NLL loss.
                    score = epoch_nll_loss
                    if score < best_score:
                        patience_counter = 0
                        best_score = score
                    else:
                        patience_counter += 1

                    if save_model:
                        # Tag the file name with the first accuracy tier reached.
                        for th, pfix in zip(
                            [0.8, 0.75, 0.7, 0.5, 0.0],
                            ['-perfect', '-great', '-good', '-bad', '-miss']):
                            if accuracy >= th:
                                model_save_path += pfix
                                break

                        torch.save(model.state_dict(), model_save_path)

                    writer.add_scalars('best_val_accuracy',
                                       {'{}_accuracy'.format(phase): best_map},
                                       epoch)
                    writer.add_scalars(
                        'best_nll_loss',
                        {'{}_nll_loss'.format(phase): best_score}, epoch)

                    writer.add_scalars('learning_rate', {
                        'learning_rate':
                        scheduler.optimizer.param_groups[0]['lr']
                    }, epoch)

                    if patience_counter >= patience:
                        print("Stopped at epoch {}".format(epoch))
                        stop_early = True

                if live_loss:
                    prefix = ''
                    if phase == 'validation':
                        prefix = 'val_'

                    logs[prefix + 'log loss'] = epoch_nll_loss
                    logs[prefix + 'accuracy'] = accuracy
            if live_loss:
                liveloss.update(logs)
                liveloss.draw()

            # Early stopping ends this fold only; the next fold still runs.
            if stop_early:
                break

        # Flush this fold's pending TensorBoard events.
        writer.close()

    print("Done !")
                                       batch_size=args.neighbor_batch_size))
            print(
                f'Using neighbor sampling | number of subgraphs: {len(data_list)}'
            )

        # Run the model for each batch size setups
        batch_sizes = np.array(list(range(1, 65))) * 4
        batch_running_time = []
        for batch_size in batch_sizes:
            batch_size = int(batch_size)
            loader = DataListLoader(data_list,
                                    batch_size=batch_size,
                                    shuffle=True)

            # Model hyperparameters
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
            lr_scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)

            temp_batch_times = []
            for input_list in loader:
                batch_start = time.time()
                #forward
                output = model(input_list)

                model_output_test(output, input_list)

                _, predicted = torch.max(output, 1)
                y = torch.cat([data.y for data in input_list
                               ]).to(output.device).squeeze()
                loss = F.nll_loss(output, y.long())
Exemple #5
0
class Trainer():
    def __init__(self,
                 option,
                 model,
                 train_dataset,
                 valid_dataset,
                 test_dataset=None):
        """Set up model, data loaders, optimiser, scheduler and output paths.

        Args:
            option: dict of settings. Keys read here: ``cuda_devices``,
                ``parallel``, ``train_batch``, ``lr``,
                ``lr_scheduler_patience`` and ``task``.
            model: the network to train; wrapped in ``DataParallel`` when
                ``option['parallel']`` is truthy.
            train_dataset: dataset for the training loader.
            valid_dataset: dataset for the validation loader (batch size 64).
            test_dataset: optional dataset for the test loader (batch size
                64). When it is missing or empty, ``self.test_dataloader``
                is ``None``.
        """
        self.option = option
        self.device = torch.device("cuda:{}".format(option['cuda_devices'][0]) \
                                       if torch.cuda.is_available() else "cpu")
        if option['parallel']:
            self.model = DataParallel(
                model, device_ids=self.option['cuda_devices']).to(self.device)
        else:
            self.model = model.to(self.device)

        # Setting the train valid and test data loader; the parallel model
        # consumes lists of samples, hence the different loader class.
        loader_cls = DataListLoader if self.option['parallel'] else DataLoader
        self.train_dataloader = loader_cls(
            train_dataset, batch_size=self.option['train_batch'])
        self.valid_dataloader = loader_cls(valid_dataset, batch_size=64)
        # Always define the attribute so later code can test it for None.
        self.test_dataloader = loader_cls(
            test_dataset, batch_size=64) if test_dataset else None

        # Setting the Adam optimizer with hyper-param
        self.criterion = torch.nn.L1Loss()
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.option['lr'])
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.7,
            patience=self.option['lr_scheduler_patience'],
            min_lr=0.0000001)

        # other
        self.start = time.time()
        # Random 4-character tag that keeps the outputs of different runs apart.
        self.save_id = ''.join(
            random.sample('zyxwvutsrqponmlkjihgfedcba1234567890', 4))
        self.abs_file_dir = os.path.dirname(os.path.abspath(__file__))
        self.ckpt_save_dir = os.path.join(
            self.abs_file_dir, 'ckpt',
            'ckpts_task{}_{}'.format(self.option['task'], self.save_id))
        self.log_save_path = os.path.join(
            self.abs_file_dir, 'log',
            'log_task{}_{}.txt'.format(self.option['task'], self.save_id))
        self.record_save_path = os.path.join(
            self.abs_file_dir, 'record',
            'record_task{}_{}.csv'.format(self.option['task'], self.save_id))
        # Create the directories the paths above actually point to (next to
        # this file, not in the current working directory), without a shell.
        for directory in (self.ckpt_save_dir,
                          os.path.dirname(self.log_save_path),
                          os.path.dirname(self.record_save_path)):
            os.makedirs(directory, exist_ok=True)
        self.records = {
            'trn_record': [],
            'val_record': [],
            'val_losses': [],
            'best_ckpt': None
        }
        self.log(
            msgs=['\t{}:{}\n'.format(k, v) for k, v in self.option.items()])
        self.log('save id: {}'.format(self.save_id))
        # test_dataset is optional, so report 0 instead of calling len(None).
        test_set_num = len(test_dataset) if test_dataset is not None else 0
        self.log('train set num:{} valid set num:{} test set num: {}'.format(
            len(train_dataset), len(valid_dataset), test_set_num))
        self.log("Total Parameters:" +
                 str(sum([p.nelement() for p in self.model.parameters()])))

    def train_iterations(self, epoch):
        self.model.train()
        losses = []
        for i, data in enumerate(self.train_dataloader):
            self.optimizer.zero_grad()
            if self.option['parallel']:
                sample_list = data  # data will be sample_list in parallel model
                output = self.model(sample_list)
                y = torch.cat([sample.y
                               for sample in sample_list]).to(output.device)
                loss = self.criterion(output, y)
            else:
                data = data.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, data.y)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
            if i % 100 == 0:
                self.log('\tbatch {} training loss: {:.5f}'.format(
                    i, loss.item()),
                         with_time=True)
        trn_loss = np.array(losses).mean()
        return trn_loss

    def valid_iterations(self, epoch=None, mode='valid'):
        """Evaluate the model on the validation or test loader.

        All batch outputs and targets are concatenated and the criterion is
        applied once to the full set, so the result is not an average of
        per-batch losses.

        Args:
            epoch: Unused; accepted so callers can pass the epoch index.
            mode: ``'valid'`` to use ``self.valid_dataloader`` or ``'test'`` to
                use ``self.test_dataloader`` (the test loss is also logged).

        Returns:
            The loss over the whole loader as a Python float.

        Raises:
            ValueError: If ``mode`` is neither ``'valid'`` nor ``'test'``.
        """
        # Reject unknown modes up front; previously this surfaced later as an
        # UnboundLocalError on ``dataloader``.
        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'valid':
            dataloader = self.valid_dataloader
        else:
            raise ValueError(
                "mode must be 'valid' or 'test', got {!r}".format(mode))

        self.model.eval()
        outputs = []
        ys = []
        with torch.no_grad():
            for data in dataloader:
                if self.option['parallel']:
                    # data is a list of samples in parallel mode.
                    output = self.model(data)
                    y = torch.cat([sample.y for sample in data
                                   ]).to(output.device)
                else:
                    data = data.to(self.device)
                    output = self.model(data)
                    y = data.y
                outputs.append(output)
                ys.append(y)
        val_loss = self.criterion(torch.cat(outputs), torch.cat(ys)).item()
        if mode == 'test':
            self.log('Test loss: {:.5f}'.format(val_loss))
        return val_loss

    def train(self):
        """Run the full training loop with LR scheduling and early stopping.

        Each epoch trains, validates, steps the scheduler on the validation
        loss and appends to ``self.records``. A checkpoint is written whenever
        the validation loss equals the minimum seen so far; a final checkpoint
        is always written after the loop ends.
        """
        self.log('Training start...')
        stale_epochs = 0
        for epoch in tqdm(range(self.option['train_epoch'])):
            trn_loss = self.train_iterations(epoch)
            val_loss = self.valid_iterations(epoch)
            self.scheduler.step(val_loss)
            lr_cur = self.scheduler.optimizer.param_groups[0]['lr']
            summary = 'Epoch:{} trn_loss:{:.5f} val_loss:{:.5f} lr_cur:{:.7f}'.format(
                epoch, trn_loss, val_loss, lr_cur)
            self.log(summary, with_time=True)

            self.records['val_losses'].append(val_loss)
            self.records['val_record'].append([epoch, val_loss, lr_cur])
            self.records['trn_record'].append([epoch, trn_loss, lr_cur])

            is_best = val_loss == np.array(self.records['val_losses']).min()
            if not is_best:
                stale_epochs += 1
            else:
                self.save_model_and_records(epoch, trn_loss, val_loss)
                stale_epochs = 0

            # A non-positive patience disables early stopping.
            patience = self.option['early_stop_patience']
            if patience > 0 and stale_epochs > patience:
                self.log('Early stop hitted!')
                break
        self.save_model_and_records(epoch, trn_loss, val_loss, final_save=True)

    def save_model_and_records(self,
                               epoch,
                               trn_loss,
                               val_loss,
                               final_save=False):
        """Write a checkpoint containing options, records and model weights.

        Args:
            epoch: Epoch index, used in the file name and the log line.
            trn_loss: Training loss, formatted into the file name.
            val_loss: Validation loss, formatted into the file name.
            final_save: When true, the loss records are also dumped via
                ``save_loss_records`` and the file gets a ``Final_save``
                prefix. Otherwise the file name is stored as
                ``records['best_ckpt']``.
        """
        if final_save:
            self.save_loss_records()
            prefix = 'Final_save'
        else:
            prefix = 'task'
        file_name = prefix + '{}_{}_{:.5f}_{:.5f}.ckpt'.format(
            self.option['task'], epoch, trn_loss, val_loss)
        if not final_save:
            self.records['best_ckpt'] = file_name

        target = os.path.join(self.ckpt_save_dir, file_name)
        with open(target, 'wb') as handle:
            payload = {
                'option': self.option,
                'records': self.records,
                'model_state_dict': self.model.state_dict(),
            }
            torch.save(payload, handle)
        self.log('Model saved at epoch {}'.format(epoch))

    def save_loss_records(self):
        """Write the per-epoch train/validation losses to ``self.record_save_path``.

        Returns:
            The DataFrame that was written, with the columns ``Epoch``,
            ``Traning MAE Loss`` and ``Validation MAE Loss``.
        """
        train_frame = pd.DataFrame(self.records['trn_record'],
                                   columns=['epoch', 'trn_loss', 'lr'])
        valid_frame = pd.DataFrame(self.records['val_record'],
                                   columns=['epoch', 'val_loss', 'lr'])
        # The epoch column is taken from the training records; the two frames
        # are combined by row index.
        columns = {
            'Epoch': train_frame['epoch'],
            'Traning MAE Loss': train_frame['trn_loss'],
            'Validation MAE Loss': valid_frame['val_loss'],
        }
        summary = pd.DataFrame(columns)
        summary.to_csv(self.record_save_path)
        return summary

    def load_best_ckpt(self):
        """Reload the checkpoint whose name is stored in ``records['best_ckpt']``."""
        best_path = '/'.join((self.ckpt_save_dir, self.records['best_ckpt']))
        self.log('The best ckpt is {}'.format(best_path))
        self.load_ckpt(best_path)

    def load_ckpt(self, ckpt_path):
        """Restore options, records and model weights from a checkpoint file.

        Args:
            ckpt_path: Path to a file holding a dict with the keys
                ``'option'``, ``'records'`` and ``'model_state_dict'``.
        """
        self.log('Ckpt loading: {}'.format(ckpt_path))
        # Remap stored tensors onto this trainer's device so a checkpoint
        # written on a GPU can still be loaded where that GPU is unavailable.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.option = ckpt['option']
        self.records = ckpt['records']
        self.model.load_state_dict(ckpt['model_state_dict'])

    def log(self, msg=None, msgs=None, with_time=False):
        """Append text to the log file at ``self.log_save_path`` and echo it.

        Args:
            msg: A single line; a newline is added when it is written.
            msgs: A sequence of strings written verbatim (no newline added),
                emitted before ``msg``.
            with_time: When true, the time elapsed since ``self.start`` is
                appended to ``msg``.
        """
        if with_time:
            elapsed = time.time() - self.start
            suffix = ' Time elapsed {:.2f} hrs ({:.1f} mins)'.format(
                elapsed / 3600., elapsed / 60.)
            msg = msg + suffix
        with open(self.log_save_path, 'a+') as log_file:
            if msgs:
                log_file.writelines(msgs)
                print(*msgs, sep='', end='')
            if msg:
                log_file.write(msg + '\n')
                print(msg)
Exemple #6
0
def test(args):
    """Run a trained checkpoint over the test data and dump results to HDF5.

    The model is rebuilt from the hyper-parameters stored under ``"args"`` in
    the checkpoint, then applied to every non-``None`` item of the test
    dataset one complex at a time. For each (name, pose) pair the output file
    gets a group holding the attributes ``y_true`` and ``y_pred`` and a
    ``hidden_features`` dataset (the intermediate feature tensors concatenated
    along axis 1).

    Args:
        args: Parsed command-line namespace. This function reads
            ``checkpoint``, ``test_data``, ``dataset_name``, ``feature_type``,
            ``preprocessing_type``, ``use_docking``, ``batch_size``,
            ``print_model``, ``output`` and ``output_file_name``.
    """
    model_train_dict = torch.load(args.checkpoint)
    train_args = model_train_dict["args"]

    # Use the same input width rule as train(): 20 channels for "pybel"
    # features, 75 otherwise. A fixed 20 could not load a 75-channel model.
    feature_size = 20 if args.feature_type == "pybel" else 75

    model = GeometricDataParallel(
        PotentialNetParallel(
            in_channels=feature_size,
            out_channels=1,
            covalent_gather_width=train_args["covalent_gather_width"],
            non_covalent_gather_width=train_args["non_covalent_gather_width"],
            covalent_k=train_args["covalent_k"],
            non_covalent_k=train_args["non_covalent_k"],
            covalent_neighbor_threshold=train_args["covalent_threshold"],
            non_covalent_neighbor_threshold=train_args[
                "non_covalent_threshold"],
        )).float()

    model.load_state_dict(model_train_dict["model_state_dict"])

    # The script accepts several data files; build one dataset per file and
    # concatenate them into a single dataset object.
    dataset_list = [
        PDBBindDataset(
            data_file=data,
            dataset_name=args.dataset_name,
            feature_type=args.feature_type,
            preprocessing_type=args.preprocessing_type,
            output_info=True,
            cache_data=False,
            use_docking=args.use_docking,
        ) for data in args.test_data
    ]

    dataset = ConcatDataset(dataset_list)
    print("{} complexes in dataset".format(len(dataset)))

    dataloader = DataListLoader(dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

    model.eval()
    model.cuda()

    if args.print_model:
        print(model)
    print("{} total parameters.".format(
        sum(p.numel() for p in model.parameters())))

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output, exist_ok=True)

    output_f = "{}/{}".format(args.output, args.output_file_name)

    # Inference only: no autograd graph is needed for any of the outputs.
    with h5py.File(output_f, "w") as f, torch.no_grad():

        for batch in tqdm(dataloader):

            batch = [x for x in batch if x is not None]
            if len(batch) < 1:
                continue

            for item in batch:
                # With output_info=True each item is indexed as
                # (name, pose, data).
                name = item[0]
                pose = item[1]
                data = item[2]

                name_grp = f.require_group(str(name))
                name_pose_grp = name_grp.require_group(str(pose))

                name_pose_grp.attrs["y_true"] = data.y

                # Call the wrapped module directly on a one-element batch.
                (
                    covalent_feature,
                    non_covalent_feature,
                    pool_feature,
                    fc0_feature,
                    fc1_feature,
                    y_,
                ) = model.module(Batch().from_data_list([data]),
                                 return_hidden_feature=True)

                name_pose_grp.attrs["y_pred"] = y_.cpu().data.numpy()
                hidden_features = np.concatenate(
                    (
                        covalent_feature.cpu().data.numpy(),
                        non_covalent_feature.cpu().data.numpy(),
                        pool_feature.cpu().data.numpy(),
                        fc0_feature.cpu().data.numpy(),
                        fc1_feature.cpu().data.numpy(),
                    ),
                    axis=1,
                )

                name_pose_grp.create_dataset(
                    "hidden_features",
                    (hidden_features.shape[0], hidden_features.shape[1]),
                    data=hidden_features,
                )
Exemple #7
0
def _track_best_checkpoint(best, checkpoint_dict, epoch, step):
    """Update ``best`` in place when ``checkpoint_dict`` has a higher validation r2."""
    r2 = checkpoint_dict["validate_dict"]["r2"]
    if r2 > best["r2"]:
        best["r2"] = r2
        best["epoch"] = epoch
        best["step"] = step
        best["checkpoint"] = checkpoint_dict


def train():
    """Train the model on ``args.train_data`` and track the best checkpoint.

    Reads its whole configuration from the module-level ``args``. Per batch it
    computes the MSE loss plus r2 / MAE / Pearson / Spearman statistics and
    writes them with ``tqdm.write``. When ``args.checkpoint`` is set, the
    model is checkpointed every ``args.checkpoint_iter`` steps, at the end of
    every epoch and once more after the loop; the checkpoint with the highest
    validation r2 is saved as ``best_checkpoint.pth`` in
    ``args.checkpoint_dir``.
    """

    # set the input channel dims based on featurization type
    feature_size = 20 if args.feature_type == "pybel" else 75

    def build_datasets(data_files):
        # One dataset per input file, all built with the same settings.
        return [
            PDBBindDataset(
                data_file=data,
                dataset_name=args.dataset_name,
                feature_type=args.feature_type,
                preprocessing_type=args.preprocessing_type,
                output_info=True,
                use_docking=args.use_docking,
            ) for data in data_files
        ]

    print("found {} datasets in input train-data".format(len(args.train_data)))
    train_dataset = ConcatDataset(build_datasets(args.train_data))
    val_dataset = ConcatDataset(build_datasets(args.val_data))

    # drop_last keeps batch sizes even.
    # NOTE(review): the original comment said "since shuffling is used", but
    # shuffle is False for the training loader -- confirm the intent.
    train_dataloader = DataListLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )

    val_dataloader = DataListLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )

    tqdm.write("{} complexes in training dataset".format(len(train_dataset)))
    tqdm.write("{} complexes in validation dataset".format(len(val_dataset)))

    model = GeometricDataParallel(
        PotentialNetParallel(
            in_channels=feature_size,
            out_channels=1,
            covalent_gather_width=args.covalent_gather_width,
            non_covalent_gather_width=args.non_covalent_gather_width,
            covalent_k=args.covalent_k,
            non_covalent_k=args.non_covalent_k,
            covalent_neighbor_threshold=args.covalent_threshold,
            non_covalent_neighbor_threshold=args.non_covalent_threshold,
        )).float()

    model.train()
    model.to(0)
    tqdm.write(str(model))
    tqdm.write("{} trainable parameters.".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    tqdm.write("{} total parameters.".format(
        sum(p.numel() for p in model.parameters())))

    criterion = nn.MSELoss().float()
    optimizer = Adam(model.parameters(), lr=args.lr)

    best = {"checkpoint": None, "epoch": 0, "step": 0, "r2": -9e9}

    def checkpoint_and_track(epoch, step):
        # Save/validate the current model and remember it if it is the best.
        checkpoint_dict = checkpoint_model(
            model,
            val_dataloader,
            epoch,
            step,
            args.checkpoint_dir +
            "/model-epoch-{}-step-{}.pth".format(epoch, step),
        )
        _track_best_checkpoint(best, checkpoint_dict, epoch, step)

    step = 0
    # Pre-bind epoch so the post-loop checkpoint does not raise NameError
    # when args.epochs is 0.
    epoch = 0
    for epoch in range(args.epochs):
        for batch in tqdm(train_dataloader):
            batch = [x for x in batch if x is not None]
            if len(batch) < 1:
                print("empty batch, skipping to next batch")
                continue
            optimizer.zero_grad()

            # Each item is indexed as (name, pose, data); only data is fed in.
            data = [x[2] for x in batch]
            y_ = model(data)
            y = torch.cat([x[2].y for x in batch])

            # (input, target) order; MSE is symmetric so the value is the
            # same as before, but the target no longer carries the gradient.
            loss = criterion(y_.cpu().float(), y.float())
            loss_value = loss.item()
            loss.backward()

            y_true = y.cpu().data.numpy()
            y_pred = y_.cpu().data.numpy()

            r2 = r2_score(y_true=y_true, y_pred=y_pred)
            mae = mean_absolute_error(y_true=y_true, y_pred=y_pred)

            pearsonr = stats.pearsonr(y_true.reshape(-1), y_pred.reshape(-1))
            spearmanr = stats.spearmanr(y_true.reshape(-1), y_pred.reshape(-1))

            tqdm.write(
                "epoch: {}\tloss:{:0.4f}\tr2: {:0.4f}\t pearsonr: {:0.4f}\tspearmanr: {:0.4f}\tmae: {:0.4f}\tpred stdev: {:0.4f}"
                "\t pred mean: {:0.4f} \tcovalent_threshold: {:0.4f} \tnon covalent threshold: {:0.4f}"
                .format(
                    epoch,
                    loss_value,
                    r2,
                    float(pearsonr[0]),
                    float(spearmanr[0]),
                    float(mae),
                    np.std(y_pred),
                    np.mean(y_pred),
                    model.module.covalent_neighbor_threshold.t.cpu().data.item(
                    ),
                    model.module.non_covalent_neighbor_threshold.t.cpu().data.
                    item(),
                ))

            # Kept before optimizer.step() as in the original, so a periodic
            # checkpoint holds the weights prior to this batch's update.
            if args.checkpoint and step % args.checkpoint_iter == 0:
                checkpoint_and_track(epoch, step)

            optimizer.step()
            step += 1

        if args.checkpoint:
            checkpoint_and_track(epoch, step)

    if args.checkpoint:
        # once out of the loop, save the last model and the best one seen
        checkpoint_and_track(epoch, step)
        torch.save(best["checkpoint"],
                   args.checkpoint_dir + "/best_checkpoint.pth")
    print("best training checkpoint epoch {}/step {} with r2: {}".format(
        best["epoch"], best["step"], best["r2"]))
Exemple #8
0
# Script setup: load the structure dataset, split off a validation tail,
# build the model/optimizer and open a TensorBoard-style summary writer.
# `p`, `device` and `learn_rate` are defined outside this excerpt.
print('Importing structures.')
trainset = Structures(root='./datasets/{}_train/'.format(p.dataset),
                      prefix=p.dataset)
samples = len(trainset)
# The first (1 - validation_split) fraction is kept for training; the
# remaining tail becomes the validation set.
cutoff = int(np.floor(samples*(1-p.validation_split)))
validset = trainset[cutoff:]
trainset = trainset[:cutoff]


# Shuffling happens after the split, so it only reorders the training part.
if p.shuffle_dataset:
    trainset = trainset.shuffle()
# Feature width of the first training sample.
# NOTE(review): not used within this excerpt -- the model below is built with
# a literal 6; confirm whether n_features was meant to be passed instead.
n_features = trainset.get(0).x.shape[1]
print('Setting up model...')
model = p.model_type(6, heads=p.heads).to(device)
model = DataParallel(model).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate, weight_decay=p.weight_decay)
# LR scheduler is currently disabled:
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
#                                                       factor=p.lr_decay,
#                                                       patience=p.patience)

# The run's hyper-parameters are encoded in the writer's comment string.
writer = SummaryWriter(comment='model:{}_lr:{}_lr_decay:{}_shuffle:{}_seed:{}'.format(
                       p.version,
                       learn_rate,
                       p.lr_decay,
                       p.shuffle_dataset,
                       p.random_seed))


# axes = [0, 1, 2]
# Starting value for the ROC-AUC tracker; presumably updated by the training
# loop outside this excerpt.
max_roc_auc = 0
Exemple #9
0
class Trainer():
    """Train, evaluate and checkpoint a graph classifier.

    The model output is passed through a softmax; the class-1 probability is
    used for ROC-AUC and the argmax for precision/recall. Depending on
    ``option['parallel']`` the model is wrapped in ``DataParallel`` and fed
    lists of samples (``DataListLoader``) or is fed collated batches
    (``DataLoader``).

    All log output is appended to ``<option['exp_path']>/log.txt``.
    """

    def __init__(self,
                 option,
                 model,
                 train_dataset=None,
                 valid_dataset=None,
                 test_dataset=None,
                 weight=None,
                 tasks_num=1):
        """Build data loaders, loss functions, optimizer and scheduler.

        Args:
            option: Dict of settings. Keys read here: ``exp_path``, ``cpu``,
                ``parallel``, ``batch_size``, ``focalloss``, ``lr``,
                ``weight_decay`` and ``lr_scheduler_patience``.
            model: The network to train or evaluate.
            train_dataset: Training data. When it or ``valid_dataset`` is
                missing, only a test loader is built.
            valid_dataset: Validation data.
            test_dataset: Test data.
            weight: One class-weight list per task; defaults to
                ``[[1.0, 1.0]]``.
            tasks_num: Number of tasks (one criterion is used per task).
        """
        # A None default avoids sharing one mutable list between calls.
        if weight is None:
            weight = [[1.0, 1.0]]
        self.option = option
        self.tasks_num = tasks_num

        self.save_path = self.option['exp_path']

        use_cuda = torch.cuda.is_available() and not option['cpu']
        self.device = torch.device("cuda:{}".format(0) if use_cuda else "cpu")
        self.model = DataParallel(model).to(self.device) \
            if option['parallel'] else model.to(self.device)

        # Setting the train, valid and test data loaders.
        batch_size = self.option['batch_size']
        if self.option['parallel']:
            loader_cls, loader_kwargs = DataListLoader, {}
        else:
            loader_cls, loader_kwargs = DataLoader, {'num_workers': 4}

        if train_dataset and valid_dataset:
            self.train_dataloader = loader_cls(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               **loader_kwargs)
            self.valid_dataloader = loader_cls(valid_dataset,
                                               batch_size=batch_size,
                                               **loader_kwargs)
            if test_dataset:
                self.test_dataloader = loader_cls(test_dataset,
                                                  batch_size=batch_size,
                                                  **loader_kwargs)
        else:
            self.test_dataset = test_dataset
            self.test_dataloader = loader_cls(test_dataset,
                                              batch_size=batch_size,
                                              **loader_kwargs)

        # One loss function per entry in ``weight``.
        if not option['focalloss']:
            self.criterion = [
                torch.nn.CrossEntropyLoss(torch.Tensor(w).to(self.device),
                                          reduction='mean') for w in weight
            ]
        else:
            self.log('Using FocalLoss')
            self.criterion = [FocalLoss(alpha=1 / w[0])
                              for w in weight]  #alpha 0.965

        # Setting the Adam optimizer with hyper-param
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.option['lr'],
                                          weight_decay=option['weight_decay'])
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.7,
            patience=self.option['lr_scheduler_patience'],
            min_lr=1e-6)
        self.start = time.time()
        # The precision/recall/ckpt keys are pre-seeded so that train() and
        # load_best_ckpt() cannot hit a KeyError before the first save.
        self.records = {
            'best_epoch': None,
            'val_auc': [],
            'best_val_auc': 0.,
            'best_trn_auc': 0.,
            'best_test_auc': 0.,
            'best_test_precision': 0.,
            'best_test_recall': 0.,
            'best_ckpt': None,
        }
        self.log(
            msgs=['\t{}:{}\n'.format(k, v) for k, v in self.option.items()],
            show=False)
        if train_dataset:
            # A missing dataset is reported as size 0 instead of crashing on
            # len(None).
            self.log(
                'train set num:{}    valid set num:{}    test set num: {}'.
                format(self._dataset_size(train_dataset),
                       self._dataset_size(valid_dataset),
                       self._dataset_size(test_dataset)))
        self.log("total parameters:" +
                 str(sum([p.nelement() for p in self.model.parameters()])))
        self.log(msgs=str(model).split('\n'), show=False)

    @staticmethod
    def _dataset_size(dataset):
        """Return ``len(dataset)``, or 0 when the dataset is ``None``."""
        return 0 if dataset is None else len(dataset)

    def _prepare_batch(self, data):
        """Move a collated batch to the device and densify ``data.protein``.

        Used on the non-parallel path only. A per-row group index is built
        from ``data.pr_len`` (entry ``i`` contributes ``pr_len[i]`` rows with
        value ``i``) and handed to ``to_dense_batch``.
        """
        data = data.to(self.device)
        target_len = data.pr_len
        y_idx = torch.zeros(target_len[0]).long()
        for i, e in enumerate(target_len):
            if i > 0:
                y_idx = torch.cat(
                    [y_idx, torch.full((e.item(), ), i).long()])
        y_idx = y_idx.to(self.device)
        data.protein = torch_geometric.utils.to_dense_batch(
            data.protein, y_idx)[0]
        return data

    def _forward_with_labels(self, data):
        """Run the model on one loader item; return ``(output, batch)``.

        ``batch`` is an object on ``self.device`` whose ``.y`` holds the
        labels matching ``output``.
        """
        if not self.option['parallel']:
            data = self._prepare_batch(data)
            return self.model(data), data
        # Parallel mode: the model consumes the raw sample list; the list is
        # collated afterwards only to read the labels.
        output = self.model(data)
        return output, Batch.from_data_list(data).to(self.device)

    @staticmethod
    def _record_predictions(task, output, labels, y_label_list, y_pred_list,
                            y_out_list):
        """Append one batch's labels, class-1 scores and argmax for ``task``."""
        probs = F.softmax(output.detach().cpu(), dim=-1)
        y_out_list[task].extend(probs.argmax(dim=1).numpy())
        y_pred_list[task].extend(probs[:, 1].view(-1).numpy())
        y_label_list[task].extend(labels.cpu().numpy())

    def _summarize(self, y_label_list, y_pred_list, y_out_list):
        """Return ``(mean ROC-AUC over tasks, precision, recall)``.

        Precision and recall are computed for task 0 only.
        """
        roc = [
            metrics.roc_auc_score(y_label_list[i], y_pred_list[i])
            for i in range(self.tasks_num)
        ]
        precision = metrics.precision_score(y_label_list[0],
                                            np.array(y_out_list[0]))
        recall = metrics.recall_score(y_label_list[0],
                                      np.array(y_out_list[0]))
        return np.array(roc).mean(), precision, recall

    def _new_metric_lists(self):
        """Return three fresh ``{task: []}`` dicts (labels, scores, argmax)."""
        return tuple({i: [] for i in range(self.tasks_num)} for _ in range(3))

    def train_iterations(self):
        """Run one training epoch.

        Returns:
            ``(mean batch loss, mean ROC-AUC, precision, recall)`` measured on
            the training data.
        """
        self.model.train()
        losses = []
        y_label_list, y_pred_list, y_out_list = self._new_metric_lists()
        for data in tqdm(self.train_dataloader):
            self.optimizer.zero_grad()
            output, data = self._forward_with_labels(data)

            # The per-task losses are summed into a single scalar.
            loss = 0
            for i in range(self.tasks_num):
                loss += self.criterion[i](output, data.y)
                # Hard predictions are kept per task so they always line up
                # with that task's labels, also when tasks_num > 1.
                self._record_predictions(i, output, data.y, y_label_list,
                                         y_pred_list, y_out_list)

            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())

        trn_roc, trn_precision, trn_recall = self._summarize(
            y_label_list, y_pred_list, y_out_list)
        trn_loss = np.array(losses).mean()

        return trn_loss, trn_roc, trn_precision, trn_recall

    def valid_iterations(self, mode='valid'):
        """Evaluate on the validation or test loader without gradients.

        Args:
            mode: ``'valid'`` for the validation loader; ``'test'`` or
                ``'eval'`` for the test loader.

        Returns:
            ``(mean loss, mean ROC-AUC, precision, recall)``.

        Raises:
            ValueError: If ``mode`` is not one of the values above.
        """
        if mode in ('test', 'eval'):
            dataloader = self.test_dataloader
        elif mode == 'valid':
            dataloader = self.valid_dataloader
        else:
            raise ValueError(
                "mode must be 'valid', 'test' or 'eval', got {!r}".format(
                    mode))

        self.model.eval()
        losses = []
        y_label_list, y_pred_list, y_out_list = self._new_metric_lists()
        with torch.no_grad():
            for data in tqdm(dataloader):
                output, data = self._forward_with_labels(data)
                for i in range(self.tasks_num):
                    # Unlike training, each task's loss is recorded on its
                    # own rather than summed.
                    loss = self.criterion[i](output, data.y)
                    self._record_predictions(i, output, data.y, y_label_list,
                                             y_pred_list, y_out_list)
                    losses.append(loss.item())

        val_roc, val_precision, val_recall = self._summarize(
            y_label_list, y_pred_list, y_out_list)
        val_loss = np.array(losses).mean()

        return val_loss, val_roc, val_precision, val_recall

    def train(self):
        """Run the training loop with LR scheduling and early stopping.

        Every epoch is evaluated on both the validation and the test loader.
        The model is saved whenever the validation ROC-AUC equals the best
        seen so far; the test metrics of that epoch are stored in
        ``self.records``.

        Returns:
            ``self.records['best_test_auc']``.
        """
        self.log('Training start...')
        early_stop_cnt = 0
        for epoch in range(self.option['epochs']):
            trn_stats = self.train_iterations()
            val_stats = self.valid_iterations()
            test_stats = self.valid_iterations(mode='test')
            val_loss, val_roc = val_stats[0], val_stats[1]
            test_roc, test_pre, test_recall = test_stats[1:]

            self.scheduler.step(val_loss)
            lr_cur = self.scheduler.optimizer.param_groups[0]['lr']
            for tag, (loss, roc, pre, recall) in (('trn', trn_stats),
                                                  ('val', val_stats),
                                                  ('test', test_stats)):
                self.log(
                    'Epoch:{} {} {tag}_loss:{:.3f} {tag}_roc:{:.3f} {tag}_precision:{:.3f} {tag}_recall:{:.3f} lr_cur:{:.5f}'
                    .format(epoch,
                            self.option['dataset'],
                            loss,
                            roc,
                            pre,
                            recall,
                            lr_cur,
                            tag=tag),
                    with_time=True)

            self.records['val_auc'].append(val_roc)
            if val_roc == np.array(self.records['val_auc']).max():
                self.save_model_and_records(epoch,
                                            test_pre,
                                            test_recall,
                                            test_roc,
                                            final_save=False)
                early_stop_cnt = 0
            else:
                early_stop_cnt += 1
            # A non-positive patience disables early stopping.
            patience = self.option['early_stop_patience']
            if patience > 0 and early_stop_cnt > patience:
                self.log('Early stop hitted!')
                break
        self.log(
            'The best epoch is {}, test_presion: {}, test_recall: {}, test_auc: {}'
            .format(self.records['best_epoch'],
                    self.records['best_test_precision'],
                    self.records['best_test_recall'],
                    self.records['best_test_auc']))
        return self.records['best_test_auc']

    def predict(self):
        """Return the class-1 probability for every sample of the test loader."""
        self.model.eval()
        ret = []
        with torch.no_grad():
            for data in tqdm(self.test_dataloader):
                if not self.option['parallel']:
                    data = self._prepare_batch(data)
                output = self.model(data)

                output = F.softmax(output.detach().cpu(),
                                   dim=-1)[:, 1].view(-1).numpy()
                ret.extend(output)

        return ret

    def save_model_and_records(self,
                               epoch,
                               test_pre,
                               test_recall,
                               test_auc=None,
                               final_save=False):
        """Save the model weights to ``best_model_<seed>.ckpt``.

        Args:
            epoch: Epoch index, stored as ``best_epoch`` and logged.
            test_pre: Test precision to store in the records.
            test_recall: Test recall to store in the records.
            test_auc: Test ROC-AUC to store in the records.
            final_save: When true, only the weights are written and the
                ``best_*`` records are left untouched.
        """
        file_name = 'best_model_{}.ckpt'.format(self.option['seed'])
        if not final_save:
            self.records['best_epoch'] = epoch
            self.records['best_test_precision'] = test_pre
            self.records['best_test_recall'] = test_recall
            self.records['best_test_auc'] = test_auc
            # Remembered so load_best_ckpt() can locate the file.
            self.records['best_ckpt'] = file_name

        # In parallel mode the wrapped module's weights are saved, so the
        # state-dict keys do not depend on the DataParallel wrapper.
        if not self.option['parallel']:
            model_dic = self.model.state_dict()
        else:
            model_dic = self.model.module.state_dict()

        with open(os.path.join(self.save_path, file_name), 'wb') as f:
            torch.save(model_dic, f)
        self.log('Model saved at epoch {}'.format(epoch))

    def save_loss_records(self):
        """Write the per-epoch loss/AUC records to ``<save_path>/record.csv``.

        NOTE(review): this reads ``records['trn_record']`` and
        ``records['val_record']``, which no method of this class fills in;
        the caller has to provide them.

        Returns:
            The DataFrame that was written.
        """
        trn_record = pd.DataFrame(
            self.records['trn_record'],
            columns=['epoch', 'trn_loss', 'trn_auc', 'trn_acc', 'lr'])
        val_record = pd.DataFrame(
            self.records['val_record'],
            columns=['epoch', 'val_loss', 'val_auc', 'val_acc', 'lr'])
        ret = pd.DataFrame({
            'epoch': trn_record['epoch'],
            'trn_loss': trn_record['trn_loss'],
            'val_loss': val_record['val_loss'],
            'trn_auc': trn_record['trn_auc'],
            'val_auc': val_record['val_auc'],
            'trn_lr': trn_record['lr'],
            'val_lr': val_record['lr']
        })
        ret.to_csv(self.save_path + '/record.csv')
        return ret

    def load_best_ckpt(self):
        """Reload the checkpoint named by ``records['best_ckpt']``."""
        ckpt_path = self.save_path + '/' + self.records['best_ckpt']
        self.log('The best ckpt is {}'.format(ckpt_path))
        self.load_ckpt(ckpt_path)

    def load_ckpt(self, ckpt_path):
        """Load model weights (and, if present, options/records) from a file.

        Two layouts are accepted: a dict with the keys ``'option'``,
        ``'records'`` and ``'model_state_dict'``, or the bare state dict that
        ``save_model_and_records`` writes.
        """
        self.log('Ckpt loading: {}'.format(ckpt_path))
        # Remap onto this trainer's device so GPU checkpoints load on CPU.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
            self.option = ckpt['option']
            self.records = ckpt['records']
            self.model.load_state_dict(ckpt['model_state_dict'])
            return
        # Bare state dict: it was taken from the unwrapped module, so load it
        # into the same object.
        target = self.model.module if self.option['parallel'] else self.model
        target.load_state_dict(ckpt)

    def log(self, msg=None, msgs=None, with_time=False, show=True):
        """Append text to ``<save_path>/log.txt`` and optionally print it.

        Args:
            msg: A single line; written after ``msgs``.
            msgs: A list of lines, framed by two 80-character ``#`` banners.
                Newlines are appended when the first entry has none.
            with_time: Append the time elapsed since ``self.start`` to
                ``msg``.
            show: Also print what is written. ``False`` silences the banners
                too.
        """
        if with_time:
            elapsed = time.time() - self.start
            msg = msg + ' time elapsed {:.2f} hrs ({:.1f} mins)'.format(
                elapsed / 3600., elapsed / 60.)
        with open(self.save_path + '/log.txt', 'a+') as f:
            if msgs:
                banner = '#' * 80 + '\n'
                if '\n' not in msgs[0]:
                    msgs = [m + '\n' for m in msgs]
                # Everything goes through this one handle so the banners
                # really enclose the messages in the file. Writing them
                # through a second handle put both banners first.
                block = [banner] + list(msgs) + [banner]
                f.writelines(block)
                if show:
                    for x in block:
                        print(x, end='')
            if msg:
                f.write(msg + '\n')
                if show:
                    print(msg)
Exemple #10
0
class Trainer():
    """Train, validate and checkpoint a multi-task binary classifier.

    The model output is read as ``tasks_num`` consecutive pairs of logits
    (columns ``2 * i`` and ``2 * i + 1`` belong to task ``i``) and the label of
    task ``i`` is column ``i`` of the batch's ``y``. Only labels equal to 0 or
    1 take part in the loss and in the metrics.

    Keys read from ``option``: ``gpu``, ``parallel``, ``batch_size``,
    ``exp_path``, ``focalloss``, ``lr``, ``weight_decay``,
    ``lr_scheduler_patience``, ``epochs``, ``early_stop_patience``,
    ``dataset`` and, in ``'eval'`` mode, ``seed``.
    """

    def __init__(self, option, model, train_dataset, valid_dataset, test_dataset=None, weight=None,
                 tasks_num=17):
        """
        :param option: dict of hyper-parameters and paths (keys listed on the class)
        :param model: the network to train; wrapped in ``DataParallel`` when
            ``option['parallel']`` is truthy
        :param train_dataset: dataset used for optimisation
        :param valid_dataset: dataset used for model selection and LR scheduling
        :param test_dataset: optional dataset evaluated after every epoch
        :param weight: one ``[w0, w1]`` pair per task; defaults to ``[[1.0, 1.0]]``.
            A single pair is reused for every task
        :param tasks_num: number of binary tasks predicted by the model
        """
        if weight is None:
            weight = [[1.0, 1.0]]
        # One criterion is indexed per task below, so a single weight pair
        # (including the default) has to be replicated for every task.
        if len(weight) == 1 and tasks_num > 1:
            weight = [weight[0]] * tasks_num

        # Most important variable
        self.option = option
        self.device = torch.device("cuda:{}".format(option['gpu'][0]) if torch.cuda.is_available() else "cpu")
        self.model = DataParallel(model).to(self.device) if option['parallel'] else model.to(self.device)

        # Setting the train valid and test data loader
        loader_cls = DataListLoader if self.option['parallel'] else DataLoader
        batch_size = self.option['batch_size']
        self.train_dataloader = loader_cls(train_dataset, batch_size=batch_size, shuffle=True)
        self.valid_dataloader = loader_cls(valid_dataset, batch_size=batch_size)
        # Always defined, so the rest of the class can test for "no test set".
        self.test_dataloader = loader_cls(test_dataset, batch_size=batch_size) if test_dataset else None
        self.save_path = self.option['exp_path']

        # Setting the Adam optimizer with hyper-param
        if option['focalloss']:
            self.log('Using FocalLoss')
            self.criterion = [FocalLoss(alpha=1 / w[0]) for w in weight]  # alpha 0.965
        else:
            self.criterion = [torch.nn.CrossEntropyLoss(torch.Tensor(w).to(self.device), reduction='mean') for w in
                              weight]
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.option['lr'],
                                          weight_decay=option['weight_decay'])
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.7,
            patience=self.option['lr_scheduler_patience'], min_lr=1e-6
        )

        # other
        self.start = time.time()
        self.tasks_num = tasks_num

        self.records = {'trn_record': [], 'val_record': [], 'val_losses': [],
                        'best_ckpt': None, 'val_roc': [], 'val_prc': []}
        self.log(msgs=['\t{}:{}\n'.format(k, v) for k, v in self.option.items()], show=False)
        self.log('train set num:{}    valid set num:{}    test set num: {}'.format(
            len(train_dataset), len(valid_dataset),
            len(test_dataset) if test_dataset is not None else 0))
        self.log("total parameters:" + str(sum([p.nelement() for p in self.model.parameters()])))
        self.log(msgs=str(model).split('\n'), show=False)

    def _run_model(self, data):
        """Run the model on one loader batch and return ``(output, labels)``."""
        if isinstance(data, (list, tuple)):
            # DataListLoader yields a plain list of graphs that DataParallel
            # scatters itself; labels are gathered next to the output.
            # NOTE(review): assumes every element's ``y`` stacks along dim 0
            # the same way a collated batch does - confirm.
            output = self.model(data)
            labels = torch.cat([d.y for d in data], dim=0).to(output.device)
        else:
            data = data.to(self.device)
            output = self.model(data)
            labels = data.y
        return output, labels

    def _task_losses(self, output, labels, y_label_list, y_pred_list):
        """Return the loss of every task that has valid labels in this batch.

        As a side effect the valid labels and the class-1 probabilities are
        appended to ``y_label_list[i]`` / ``y_pred_list[i]``.
        """
        task_losses = []
        for i in range(self.tasks_num):
            y_pred = output[:, i * 2:(i + 1) * 2]
            # reshape(-1) keeps a batch of one sample 1-D instead of 0-D.
            y_label = labels[:, i].reshape(-1)
            valid = (y_label == 0) | (y_label == 1)
            if not valid.any():
                continue

            y_pred = y_pred[valid]
            y_label = y_label[valid]
            task_losses.append(self.criterion[i](y_pred, y_label))

            prob = F.softmax(y_pred.detach().cpu(), dim=-1)[:, 1].view(-1).numpy()
            y_label_list.setdefault(i, []).extend(y_label.cpu().numpy())
            y_pred_list.setdefault(i, []).extend(prob)
        return task_losses

    def _summarize_metrics(self, y_label_list, y_pred_list):
        """Return ``(mean ROC-AUC, mean PR-AUC)`` over the tasks that have labels."""
        rocs, prcs = [], []
        for i in range(self.tasks_num):
            if i not in y_label_list:
                continue  # this task never had a valid label
            rocs.append(metrics.roc_auc_score(y_label_list[i], y_pred_list[i]))
            precision, recall, _ = precision_recall_curve(y_label_list[i], y_pred_list[i])
            prcs.append(metrics.auc(recall, precision))
        # Plain floats keep numpy scalar types out of the checkpointed records.
        return float(np.mean(rocs)), float(np.mean(prcs))

    def train_iterations(self):
        """Run one optimisation epoch.

        :return: ``(mean batch loss, mean ROC-AUC, mean PR-AUC)``; the batch
            loss is the sum of the per-task losses
        """
        self.model.train()
        losses = []
        y_pred_list = {}
        y_label_list = {}
        for data in tqdm(self.train_dataloader):
            self.optimizer.zero_grad()

            output, labels = self._run_model(data)
            task_losses = self._task_losses(output, labels, y_label_list, y_pred_list)
            if not task_losses:
                # No valid label in the whole batch: nothing to back-propagate.
                continue

            loss = sum(task_losses)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())

        trn_roc, trn_prc = self._summarize_metrics(y_label_list, y_pred_list)
        return float(np.mean(losses)), trn_roc, trn_prc

    def valid_iterations(self, mode='valid'):
        """Evaluate the model without gradient tracking.

        :param mode: ``'valid'`` uses the validation loader; ``'test'`` and
            ``'eval'`` use the test loader, and ``'eval'`` additionally logs a
            summary line
        :return: ``(mean loss, mean ROC-AUC, mean PR-AUC)``. Unlike training,
            the loss is averaged over individual per-task losses
        :raises ValueError: for an unknown ``mode`` or a missing test loader
        """
        self.model.eval()
        if mode in ('test', 'eval'):
            dataloader = self.test_dataloader
        elif mode == 'valid':
            dataloader = self.valid_dataloader
        else:
            raise ValueError("mode must be 'valid', 'test' or 'eval', got {!r}".format(mode))
        if dataloader is None:
            raise ValueError("mode {!r} needs a test dataset, but none was given".format(mode))

        losses = []
        y_pred_list = {}
        y_label_list = {}
        with torch.no_grad():
            for data in tqdm(dataloader):
                output, labels = self._run_model(data)
                task_losses = self._task_losses(output, labels, y_label_list, y_pred_list)
                losses.extend(task_loss.item() for task_loss in task_losses)

        val_roc, val_prc = self._summarize_metrics(y_label_list, y_pred_list)
        val_loss = float(np.mean(losses))
        if mode == 'eval':
            self.log('SEED {} DATASET {}  The best test_loss:{:.3f} test_roc:{:.3f} test_prc:{:.3f}.'
                     .format(self.option['seed'], self.option['dataset'], val_loss, val_roc, val_prc))

        return val_loss, val_roc, val_prc

    def train(self):
        """Full training loop with LR scheduling, best-model saving and early stopping."""
        self.log('Training start...')
        early_stop_cnt = 0
        for epoch in range(self.option['epochs']):
            reports = [('trn', self.train_iterations())]
            val_loss, val_roc, val_prc = self.valid_iterations()
            reports.append(('val', (val_loss, val_roc, val_prc)))
            if self.test_dataloader is not None:
                reports.append(('test', self.valid_iterations(mode='test')))

            self.scheduler.step(val_loss)
            lr_cur = self.scheduler.optimizer.param_groups[0]['lr']

            for tag, (loss, roc, prc) in reports:
                self.log('Epoch:{} {} {}_loss:{:.3f} {}_roc:{:.3f} {}_prc:{:.3f} lr_cur:{:.5f}'.
                         format(epoch, self.option['dataset'], tag, loss, tag, roc, tag, prc, lr_cur),
                         with_time=True)

            trn_loss, trn_roc, trn_prc = reports[0][1]
            self.records['val_roc'].append(val_roc)
            self.records['val_prc'].append(val_prc)
            self.records['val_record'].append([epoch, val_loss, val_roc, val_prc, lr_cur])
            self.records['trn_record'].append([epoch, trn_loss, trn_roc, trn_prc, lr_cur])
            # An epoch counts as "best" when it ties or beats either metric.
            if val_roc == np.array(self.records['val_roc']).max() or val_prc == np.array(self.records['val_prc']).max():
                self.save_model_and_records(epoch)
                early_stop_cnt = 0
            else:
                early_stop_cnt += 1
            if self.option['early_stop_patience'] > 0 and early_stop_cnt > self.option['early_stop_patience']:
                self.log('Early stop hitted!')
                break
        self.save_model_and_records(epoch, final_save=True)

    def save_model_and_records(self, epoch, final_save=False):
        """Write a checkpoint holding the options, the records and the weights.

        :param epoch: epoch number, used in the log message only
        :param final_save: ``False`` stores the current weights as
            ``best_model.ckpt``. ``True`` also writes ``record.csv`` and stores
            the weights as ``final_model.ckpt``, so the last-epoch weights no
            longer overwrite the best checkpoint
        """
        if final_save:
            self.save_loss_records()
            # Keep the old file name only when no best checkpoint exists yet.
            file_name = 'final_model.ckpt' if self.records['best_ckpt'] else 'best_model.ckpt'
        else:
            file_name = 'best_model.ckpt'
            self.records['best_ckpt'] = file_name

        with open(os.path.join(self.save_path, file_name), 'wb') as f:
            torch.save({
                'option': self.option,
                'records': self.records,
                'model_state_dict': self.model.state_dict(),
            }, f)
        self.log('Model saved at epoch {}'.format(epoch))

    def save_loss_records(self):
        """Write epoch, losses, ROC-AUC and learning rates to ``record.csv``.

        :return: the ``pandas.DataFrame`` that was written
        """
        # The fourth value of each record is the PR-AUC; it is not exported.
        trn_record = pd.DataFrame(self.records['trn_record'],
                                  columns=['epoch', 'trn_loss', 'trn_auc', 'trn_prc', 'lr'])
        val_record = pd.DataFrame(self.records['val_record'],
                                  columns=['epoch', 'val_loss', 'val_auc', 'val_prc', 'lr'])
        ret = pd.DataFrame({
            'epoch': trn_record['epoch'],
            'trn_loss': trn_record['trn_loss'],
            'val_loss': val_record['val_loss'],
            'trn_auc': trn_record['trn_auc'],
            'val_auc': val_record['val_auc'],
            'trn_lr': trn_record['lr'],
            'val_lr': val_record['lr']
        })
        ret.to_csv(self.save_path + '/record.csv')
        return ret

    def load_best_ckpt(self):
        """Reload the checkpoint recorded in ``self.records['best_ckpt']``."""
        ckpt_path = self.save_path + '/' + self.records['best_ckpt']
        self.log('The best ckpt is {}'.format(ckpt_path))
        self.load_ckpt(ckpt_path)

    def load_ckpt(self, ckpt_path):
        """Restore options, records and model weights from ``ckpt_path``."""
        self.log('Ckpt loading: {}'.format(ckpt_path))
        # map_location lets a checkpoint saved on another device be restored here.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.option = ckpt['option']
        self.records = ckpt['records']
        self.model.load_state_dict(ckpt['model_state_dict'])

    def log(self, msg=None, msgs=None, with_time=False, show=True):
        """Append messages to ``<save_path>/log.txt`` and optionally echo them.

        Everything is written through a single file handle so that the ``#``
        separators really frame the ``msgs`` block in the file.

        :param msg: single message; written after ``msgs`` when both are given
        :param msgs: sequence of lines written as one block framed by two
            80-character ``#`` separator lines; a newline is added to every
            line that does not already end with one
        :param with_time: append the time elapsed since ``self.start`` to
            ``msg`` (ignored when ``msg`` is ``None``)
        :param show: also print everything that is written, separators included
        """
        if with_time and msg is not None:
            elapsed = time.time() - self.start
            msg = msg + ' time elapsed {:.2f} hrs ({:.1f} mins)'.format(
                elapsed / 3600., elapsed / 60.)

        lines = []
        if msgs:
            separator = '#' * 80 + '\n'
            lines.append(separator)
            lines.extend(m if m.endswith('\n') else m + '\n' for m in msgs)
            lines.append(separator)
        if msg:
            lines.append(msg + '\n')

        with open(self.save_path + '/log.txt', 'a+') as f:
            f.writelines(lines)
        if show:
            for line in lines:
                print(line, end='')
Exemple #11
0
def train_cross_validation(model_cls,
                           dataset,
                           num_clusters,
                           dropout=0.0,
                           lr=1e-4,
                           weight_decay=1e-2,
                           num_epochs=200,
                           n_splits=10,
                           use_gpu=True,
                           dp=False,
                           ddp=True,
                           comment='',
                           tb_service_loc='192.168.192.57:6006',
                           batch_size=1,
                           num_workers=0,
                           pin_memory=False,
                           cuda_device=None,
                           fold_no=None,
                           saved_model_path=None,
                           device_ids=None,
                           patience=50,
                           seed=None,
                           save_model=True,
                           c_reg=0,
                           base_log_dir='runs',
                           base_model_save_dir='saved_models'):
    """
    Train one model per stratified cross-validation fold and log to TensorBoard.

    The first of ``n_splits`` stratified folds is split off as a test set and
    the remainder is split again into ``n_splits`` train/validation folds.
    Every fold gets a fresh model, Adam optimizer and ``SummaryWriter``. A fold
    stops early once the validation NLL loss has not improved for ``patience``
    consecutive epochs; training then continues with the next fold.

    :param c_reg: weight of the (negated) regularisation term returned by the model
    :param save_model: bool, save the model state dict after every validation
        phase from epoch 11 on
    :param seed: RNG seed for torch/numpy; derived from the clock when None
    :param patience: for early stopping
    :param device_ids: for ddp / dp
    :param saved_model_path: state dict loaded into every fold's model when given
    :param fold_no: when given, only this fold (1-based) is trained
    :param ddp: DDP; a single-process ``nccl`` group is initialised if needed
    :param cuda_device: explicit device, overrides the automatic choice
    :param pin_memory: DataLoader args https://devblogs.nvidia.com/how-optimize-data-transfers-cuda-cc/
    :param num_workers: DataLoader args
    :param model_cls: pytorch Module cls, called as
        ``model_cls(writer, num_clusters=, in_dim=, out_dim=, dropout=)``
    :param dataset: pytorch Dataset cls; must expose ``data.x``, ``data.y``
        and ``__indexing__``
    :param num_clusters: forwarded to ``model_cls``
    :param dropout: forwarded to ``model_cls``
    :param lr: Adam learning rate
    :param weight_decay: Adam weight decay
    :param num_epochs: maximum number of epochs per fold
    :param n_splits: number of kFolds
    :param use_gpu: bool
    :param dp: bool, wrap the model in ``DataParallel``
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location (only printed)
    :param batch_size: Dataset args not DataLoader
    :param base_log_dir: root directory of the TensorBoard logs
    :param base_model_save_dir: root directory of the saved models
    :return: None
    """
    # Must stay the first statement: it snapshots the call arguments only.
    saved_args = locals()
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    if ddp and not torch.distributed.is_initialized():  # initialize ddp
        dist.init_process_group('nccl',
                                init_method='tcp://localhost:{}'.format(
                                    find_open_port()),
                                world_size=1,
                                rank=0)

    model_name = model_cls.__name__

    if not cuda_device:
        if device_ids and (ddp or dp):
            device = device_ids[0]
        else:
            device = torch.device(
                'cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
    else:
        device = cuda_device

    device_count = torch.cuda.device_count() if dp else 1
    device_count = len(device_ids) if (device_ids is not None and
                                       (dp or ddp)) else device_count
    if device_count > 1:
        print("Let's use", device_count, "GPUs!")

    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".
              format(log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    criterion = nn.CrossEntropyLoss()

    # get test set (held out here; it is not evaluated in this function)
    folds = StratifiedKFold(n_splits=n_splits, shuffle=False)
    train_val_idx, test_idx = list(
        folds.split(np.zeros(len(dataset)), dataset.data.y.numpy()))[0]
    test_dataset = dataset.__indexing__(test_idx)
    train_val_dataset = dataset.__indexing__(train_val_idx)

    print("Training {0} {1} models for cross validation...".format(
        n_splits, model_name))
    folds = StratifiedKFold(n_splits=n_splits, shuffle=False)
    fold_splits = folds.split(np.zeros(len(train_val_dataset)),
                              train_val_dataset.data.y.numpy())

    for fold, (train_idx, val_idx) in enumerate(
            tqdm_notebook(fold_splits, desc='CV', leave=False), start=1):

        if fold_no is not None:
            if fold != fold_no:
                continue

        writer = SummaryWriter(log_dir=osp.join(base_log_dir, log_dir_base +
                                                str(fold)))
        model_save_dir = osp.join(base_model_save_dir,
                                  log_dir_base + str(fold))

        print("creating dataloader tor fold {}".format(fold))

        model = model_cls(writer,
                          num_clusters=num_clusters,
                          in_dim=dataset.data.x.shape[1],
                          out_dim=int(dataset.data.y.max() + 1),
                          dropout=dropout)

        # My Batch: batching is done by dataset_gather, so each DataLoader
        # item is already one batch and the loader hands out one per device.
        train_dataset = train_val_dataset.__indexing__(train_idx)
        val_dataset = train_val_dataset.__indexing__(val_idx)

        train_dataset = dataset_gather(
            train_dataset,
            seed=0,
            n_repeat=1,
            n_splits=int(len(train_dataset) / batch_size) + 1)
        val_dataset = dataset_gather(
            val_dataset,
            seed=0,
            n_repeat=1,
            n_splits=int(len(val_dataset) / batch_size) + 1)

        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=device_count,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
        val_dataloader = DataLoader(val_dataset,
                                    shuffle=False,
                                    batch_size=device_count,
                                    collate_fn=lambda data_list: data_list,
                                    num_workers=num_workers,
                                    pin_memory=pin_memory)

        print(model)
        writer.add_text('model_summary', model.__repr__())
        writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=weight_decay,
                                     amsgrad=False)
        if ddp:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=device_ids)
        elif dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        if saved_model_path is not None:
            model.load_state_dict(torch.load(saved_model_path))

        best_map, patience_counter, best_score = 0.0, 0, -np.inf
        for epoch in tqdm_notebook(range(1, num_epochs + 1),
                                   desc='Epoch',
                                   leave=False):

            early_stop = False
            for phase in ['train', 'validation']:

                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = val_dataloader

                # Logging
                running_total_loss = 0.0
                running_corrects = 0
                running_reg_loss = 0.0
                running_nll_loss = 0.0
                epoch_label, epoch_predicted = torch.tensor([]), torch.tensor(
                    [])

                for data_list in tqdm_notebook(dataloader,
                                               desc=phase,
                                               leave=False):

                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids
                                             is not None else 'cuda'))

                    y_hat, reg = model(data_list)

                    # Gather the labels of all graphs in the same order as
                    # the model output.
                    y = torch.tensor([],
                                     dtype=dataset.data.y.dtype,
                                     device=device)
                    for data in data_list:
                        y = torch.cat([y, data.y.view(-1).to(device)])

                    loss = criterion(y_hat, y)
                    reg_loss = -reg
                    total_loss = (loss + reg_loss * c_reg).sum()

                    if phase == 'train':
                        optimizer.zero_grad()
                        total_loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    _, predicted = torch.max(y_hat, 1)
                    label = y

                    running_nll_loss += loss.item()
                    running_total_loss += total_loss.item()
                    running_reg_loss += reg.sum().item()
                    running_corrects += (predicted == label).sum().item()

                    epoch_label = torch.cat(
                        [epoch_label,
                         label.detach().cpu().float()])
                    epoch_predicted = torch.cat(
                        [epoch_predicted,
                         predicted.detach().cpu().float()])

                accuracy = sklearn.metrics.accuracy_score(
                    epoch_label, epoch_predicted)
                epoch_total_loss = running_total_loss / len(dataloader)
                epoch_nll_loss = running_nll_loss / len(dataloader)
                # NOTE(review): normalised by dataset size while the losses
                # above are normalised by loader length - confirm intent.
                epoch_reg_loss = running_reg_loss / len(dataloader.dataset)

                writer.add_scalars(
                    'nll_loss', {'{}_nll_loss'.format(phase): epoch_nll_loss},
                    epoch)
                writer.add_scalars('accuracy',
                                   {'{}_accuracy'.format(phase): accuracy},
                                   epoch)
                if epoch_reg_loss != 0:
                    writer.add_scalars(
                        'reg_loss',
                        {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)

                # Save Model & Early Stopping
                if phase == 'validation':
                    model_save_path = model_save_dir + '-{}-{}-{:.3f}-{:.3f}'.format(
                        model_name, epoch, accuracy, epoch_nll_loss)
                    if accuracy > best_map:
                        best_map = accuracy
                        model_save_path = model_save_path + '-best'

                    # Early stopping tracks the validation NLL loss.
                    score = -epoch_nll_loss
                    if score > best_score:
                        patience_counter = 0
                        best_score = score
                    else:
                        patience_counter += 1

                    if save_model:
                        # Tag the file name with the first accuracy band reached.
                        for th, pfix in zip(
                            [0.8, 0.75, 0.7, 0.5, 0.0],
                            ['-perfect', '-great', '-good', '-bad', '-miss']):
                            if accuracy >= th:
                                model_save_path += pfix
                                break
                        if epoch > 10:
                            torch.save(model.state_dict(), model_save_path)

                    writer.add_scalars('best_val_accuracy',
                                       {'{}_accuracy'.format(phase): best_map},
                                       epoch)
                    writer.add_scalars(
                        'best_nll_loss',
                        {'{}_nll_loss'.format(phase): -best_score}, epoch)

                    if patience_counter >= patience:
                        print("Stopped at epoch {}".format(epoch))
                        early_stop = True

            # Early stopping ends this fold only; the remaining folds still
            # have to be trained.
            if early_stop:
                break

        # Flush and release this fold's event file before the next fold.
        writer.close()

    print("Done !")
Exemple #12
0
def train_cummunity_detection(model_cls,
                              dataset,
                              dropout=0.0,
                              lr=1e-3,
                              weight_decay=1e-2,
                              num_epochs=200,
                              n_splits=10,
                              use_gpu=True,
                              dp=False,
                              ddp=False,
                              comment='',
                              tb_service_loc='192.168.192.57:6006',
                              batch_size=1,
                              num_workers=0,
                              pin_memory=False,
                              cuda_device=None,
                              ddp_port='23456',
                              fold_no=None,
                              device_ids=None,
                              patience=20,
                              seed=None,
                              save_model=False,
                              supervised=False):
    """
    Train one ``model_cls`` instance per K-fold split of ``dataset`` and log
    per-epoch metrics for the 'train' and 'validation' phases to TensorBoard.

    :param model_cls: model class; instantiated as
        ``model_cls(writer, dropout=dropout)`` and called on a list of data
        items, returning ``(y_hat, reg)``
    :param dataset: dataset object; must support ``len()``,
        ``__indexing__(indices)`` and expose ``data.y``
    :param dropout: forwarded to the ``model_cls`` constructor
    :param lr: Adam learning rate
    :param weight_decay: Adam weight decay
    :param num_epochs: number of epochs run for every fold
    :param n_splits: number of kFolds
    :param use_gpu: bool; seed CUDA and move the model to the GPU
    :param dp: bool; wrap the model in DataParallel (only when ``use_gpu``)
    :param ddp: bool; wrap the model in DistributedDataParallel (a
        single-process 'nccl' group is initialised if none exists yet)
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location (only used to print
        the URL)
    :param batch_size: Dataset args not DataLoader; the training subset is
        gathered into ``int(len(train_dataset) / batch_size)`` splits, while
        the DataLoader batch size is the device count
    :param num_workers: DataLoader args
    :param pin_memory: DataLoader args https://devblogs.nvidia.com/how-optimize-data-transfers-cuda-cc/
    :param cuda_device: explicit device to use; ``None`` selects one
        automatically
    :param ddp_port: TCP port on localhost used to initialise ddp
    :param fold_no: if given, only the fold with this 1-based number is run
    :param device_ids: device ids for dp / ddp
    :param patience: for early stopping; not used by this function yet, only
        recorded in the logged training args
    :param seed: random seed; derived from the current time when ``None``
    :param save_model: bool; not used by this function yet, only recorded in
        the logged training args
    :param supervised: bool; when true the loss is
        ``permutation_invariant_loss(y_hat, y)`` and an overlap score is
        logged, otherwise the loss is the negated regulariser ``reg``
    :return: None
    """

    # Must stay the first statement so that only the call arguments are
    # captured; copied so later edits do not touch the frame's own dict.
    saved_args = dict(locals())
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    if ddp and not torch.distributed.is_initialized():  # initialize ddp
        dist.init_process_group(
            'nccl',
            init_method='tcp://localhost:{}'.format(ddp_port),
            world_size=1,
            rank=0)

    model_name = model_cls.__name__

    # `is not None` rather than truthiness: device index 0 is a valid choice.
    if cuda_device is not None:
        device = cuda_device
    elif device_ids and (ddp or dp):
        device = device_ids[0]
    else:
        device = torch.device(
            'cuda' if torch.cuda.is_available() and use_gpu else 'cpu')

    if device_ids is not None and (dp or ddp):
        device_count = len(device_ids)
    elif dp:
        device_count = torch.cuda.device_count()
    else:
        device_count = 1
    if device_count > 1:
        print("Let's use", device_count, "GPUs!")

    # batch_size = batch_size * device_count

    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".
              format(log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    print("Training {0} {1} models for cross validation...".format(
        n_splits, model_name))
    # random_state is deliberately not passed: it has no effect without
    # shuffling and recent scikit-learn releases reject the combination.
    folds = KFold(n_splits=n_splits, shuffle=False)
    n_samples = len(dataset)
    print(n_samples)
    sample_indices = list(range(n_samples))

    fold_iterator = tqdm_notebook(folds.split(sample_indices, sample_indices),
                                  desc='models',
                                  leave=False)
    for fold, (train_idx, test_idx) in enumerate(fold_iterator, start=1):
        if fold_no is not None and fold != fold_no:
            continue

        writer = SummaryWriter(log_dir=osp.join('runs', log_dir_base +
                                                str(fold)))

        print("creating dataloader tor fold {}".format(fold))

        model = model_cls(writer, dropout=dropout)

        # My Batch
        train_dataset = dataset.__indexing__(train_idx)
        test_dataset = dataset.__indexing__(test_idx)

        train_dataset = dataset_gather(train_dataset,
                                       n_repeat=1,
                                       n_splits=int(
                                           len(train_dataset) / batch_size))

        # The identity collate_fn hands the model a plain list of data items.
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=device_count,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=device_count,
                                     collate_fn=lambda data_list: data_list,
                                     num_workers=num_workers,
                                     pin_memory=pin_memory)

        print(model)
        writer.add_text('model_summary', repr(model))
        writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=weight_decay,
                                     amsgrad=False)
        if ddp:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=device_ids)
        elif dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        for epoch in tqdm_notebook(range(1, num_epochs + 1),
                                   desc='Epoch',
                                   leave=False):

            for phase in ['train', 'validation']:

                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = test_dataloader

                # Logging
                running_total_loss = 0.0
                running_reg_loss = 0.0
                running_overlap = 0.0

                # NOTE(review): the validation phase still builds autograd
                # graphs; wrapping the forward pass in
                # torch.set_grad_enabled(phase == 'train') would save memory
                # if the model does not need gradients internally - confirm.
                for data_list in tqdm_notebook(dataloader,
                                               desc=phase,
                                               leave=False):

                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids
                                             is not None else 'cuda'))

                    y_hat, reg = model(data_list)

                    if supervised:
                        # Labels of all items in the batch, flattened and
                        # concatenated in one call. The empty seed tensor
                        # keeps the result well defined for an empty batch.
                        y = torch.cat(
                            [torch.tensor([],
                                          dtype=dataset.data.y.dtype,
                                          device=device)] +
                            [data.y.view(-1).to(device)
                             for data in data_list])
                        loss = permutation_invariant_loss(y_hat, y)
                    else:
                        # sum() is a no-op for a scalar `reg` and keeps the
                        # loss a scalar if `reg` holds several values.
                        loss = -reg.sum()

                    if phase == 'train':
                        optimizer.zero_grad()
                        # NOTE(review): retain_graph=True is kept from the
                        # original; it looks unnecessary for a single
                        # backward pass per batch - confirm before removing.
                        loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    if supervised:
                        _, predicted = torch.max(y_hat, 1)
                        overlap_score = normalized_overlap(
                            y.int().cpu().numpy(),
                            predicted.int().cpu().numpy(), 0.25)
                        running_overlap += overlap_score
                        print(reg, overlap_score, loss)

                    running_total_loss += loss.item()
                    running_reg_loss += reg.sum().item()

                # max(..., 1) avoids a ZeroDivisionError on an empty loader.
                n_batches = max(len(dataloader), 1)
                n_items = max(len(dataloader.dataset), 1)
                epoch_total_loss = running_total_loss / n_batches
                epoch_reg_loss = running_reg_loss / n_items

                if supervised:
                    epoch_overlap = running_overlap / n_batches
                    writer.add_scalars(
                        'overlap',
                        {'{}_overlap'.format(phase): epoch_overlap}, epoch)

                writer.add_scalars(
                    'total_loss',
                    {'{}_total_loss'.format(phase): epoch_total_loss}, epoch)
                writer.add_scalars(
                    'reg_loss',
                    {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)

        # Flush pending events and release the event file of this fold.
        writer.close()

    print("Done !")