Beispiel #1
0
def init_model(model_cls,
               log_dir_base,
               fold_no,
               device_ids=None,
               use_gpu=False,
               dp=False,
               ddp=False,
               tb_dir='runs',
               lr=1e-3,
               weight_decay=1e-2):
    """Create a model together with its TensorBoard writer and optimizer.

    :param model_cls: callable invoked as ``model_cls(writer)`` to build the model
    :param log_dir_base: sub-directory of ``tb_dir`` used as the writer log dir
    :param fold_no: accepted for interface compatibility; not used here
    :param device_ids: optional sequence of devices; the first hosts the model
    :param use_gpu: move the model to a GPU when True
    :param dp: wrap the model in ``DataParallel`` (only applied with ``use_gpu``)
    :param ddp: accepted for interface compatibility; not used here
    :param tb_dir: root directory of the TensorBoard logs
    :param lr: AdamW learning rate
    :param weight_decay: AdamW weight decay
    :return: tuple ``(model, optimizer, writer, device_count)`` where
        ``device_count`` is 1 unless ``dp`` is set, in which case it is
        ``len(device_ids)`` or, without explicit ids, the CUDA device count
    """
    writer = SummaryWriter(log_dir=osp.join(tb_dir, log_dir_base))

    model = model_cls(writer)
    writer.add_text('model_summary', repr(model))

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=lr,
                                  betas=(0.9, 0.999),
                                  eps=1e-08,
                                  weight_decay=weight_decay,
                                  amsgrad=False)

    if use_gpu:
        # Without explicit device ids fall back to the default CUDA device;
        # indexing ``device_ids`` unconditionally raised TypeError for the
        # plain single-GPU call (use_gpu=True, dp=False, device_ids=None).
        if device_ids is None:
            model = model.cuda()
        else:
            model = model.to(device_ids[0])
        if dp:
            model = DataParallel(model, device_ids=device_ids)

    if not dp:
        device_count = 1
    elif device_ids is not None:
        device_count = len(device_ids)
    else:
        device_count = torch.cuda.device_count()

    return model, optimizer, writer, device_count
Beispiel #2
0
def _tally_predictions(cleaned_output, ground_truth2):
    """Count predictions and hits, overall and per predicted class (1 / 0).

    Returns ``(total, total_ones, total_zeros, correct, correct_ones,
    correct_zeros)`` as plain Python numbers.
    """
    # Comparing against the scalars 1 and 0 happens on the tensor's own device
    # and avoids building constant tensors from per-element Python lists.
    predicted_one = cleaned_output == 1
    predicted_zero = cleaned_output == 0
    hit = cleaned_output == ground_truth2
    return (
        cleaned_output.size(0),
        predicted_one.sum().item(),
        predicted_zero.sum().item(),
        hit.sum().item(),
        (hit & predicted_one).sum().item(),
        (hit & predicted_zero).sum().item(),
    )


def _add_tallies(left, right):
    """Element-wise sum of two tallies produced by ``_tally_predictions``."""
    return tuple(a + b for a, b in zip(left, right))


def _accuracy_columns(tally):
    """Format the accuracy columns of a report line, or '' if a class is absent."""
    total, total_ones, total_zeros, correct, correct_ones, correct_zeros = tally
    # An all-zero tally (accuracy not measured this epoch) also lands here.
    if total_ones == 0 or total_zeros == 0:
        return ''
    return ' Total_Accuracy: %.3f | Ones_Accuracy: %.3f | Zeros_Accuracy: %.3f |' % (
        100 * correct / total,
        100 * correct_ones / total_ones,
        100 * correct_zeros / total_zeros)


def _emit(message):
    """Write ``message`` both to stdout and to the root logger."""
    print(message)
    logging.info(message)


def model_training(data_list_train, data_list_test, epochs, acc_epoch, acc_epoch2, save_model_epochs, validation_epoch, batchsize, logfilename, load_checkpoint= None):
    """Train ``completeNet`` on CUDA with periodic accuracy, checkpoints and validation.

    :param data_list_train: data handed to the training ``DataListLoader``
    :param data_list_test: data handed to the validation ``DataListLoader``
    :param epochs: number of the last epoch to run (inclusive)
    :param acc_epoch: training accuracy is measured on epochs divisible by this
    :param acc_epoch2: validation accuracy is measured on epochs divisible by this
    :param save_model_epochs: a checkpoint is written to ``./models`` on epochs
        divisible by this
    :param validation_epoch: validation runs on epochs divisible by this
    :param batchsize: batch size of both loaders
    :param logfilename: name of the log file created under ``./logfiles``
    :param load_checkpoint: optional path of a checkpoint (as written by this
        function) to resume from; training continues at the stored epoch + 1
    :return: None

    Training stops early, after printing "Error", when the loss becomes NaN.
    """
    import math
    import os

    # logging
    log_path = './logfiles/'+logfilename
    # basicConfig cannot open the file when the directory is missing.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    logging.basicConfig(level=logging.DEBUG, filename=log_path, filemode="w+",
                        format="%(message)s")

    trainloader = DataListLoader(data_list_train, batch_size=batchsize, shuffle=True)
    testloader = DataListLoader(data_list_test, batch_size=batchsize, shuffle=True)
    device = torch.device('cuda')
    complete_net = DataParallel(completeNet()).to(device)

    # train parameters
    weights = [10, 1]
    optimizer = torch.optim.Adam(complete_net.parameters(), lr=0.001, weight_decay=0.001)

    # resume training
    initial_epoch = 1
    if load_checkpoint is not None:
        checkpoint = torch.load(load_checkpoint)
        complete_net.load_state_dict(checkpoint['model_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initial_epoch = checkpoint['epoch'] + 1

    complete_net.train()

    for epoch in range(initial_epoch, epochs + 1):
        measure_train_acc = epoch % acc_epoch == 0 and epoch != 0
        epoch_tally = (0,) * 6
        running_loss = 0
        batches_num = 0
        diverged = False

        for batch in trainloader:
            batches_num += 1
            # Forward-Backpropagation
            output, output2, ground_truth, ground_truth2, det_num, tracklet_num = complete_net(batch)
            optimizer.zero_grad()
            loss = weighted_binary_cross_entropy(output, ground_truth, weights)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()

            # Accuracy
            batch_tally = (0,) * 6
            if measure_train_acc:
                # Hungarian method, clean up
                cleaned_output = hungarian(output2, ground_truth2, det_num, tracklet_num)
                batch_tally = _tally_predictions(cleaned_output, ground_truth2)
                epoch_tally = _add_tallies(epoch_tally, batch_tally)

            if math.isnan(loss_value):
                print("Error")
                diverged = True
                break

            _emit('Epoch: [%d] | Batch: [%d] | Training_Loss: %.3f |' %
                  (epoch, batches_num, loss_value) + _accuracy_columns(batch_tally))
            running_loss += loss_value

        if diverged:
            # A NaN loss aborts the whole training run, not just the epoch.
            print("Error")
            break

        _emit('Epoch: [%d] | Training_Loss: %.3f |' %
              (epoch, running_loss / batches_num) + _accuracy_columns(epoch_tally))

        # save model
        if epoch % save_model_epochs == 0 and epoch != 0:
            os.makedirs('./models', exist_ok=True)
            torch.save({
                        'epoch': epoch,
                        'model_state_dict': complete_net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': running_loss,
                        }, './models/epoch_'+str(epoch)+'.pth')

        # validation
        if epoch % validation_epoch == 0 and epoch != 0:
            measure_val_acc = epoch % acc_epoch2 == 0
            val_tally = (0,) * 6
            val_loss = 0
            val_batches = 0
            # Evaluate in eval mode so that layers which behave differently
            # while training (dropout, batch norm) do not distort the
            # validation numbers; train mode is restored right afterwards.
            complete_net.eval()
            with torch.no_grad():
                for batch in testloader:
                    val_batches += 1
                    output, output2, ground_truth, ground_truth2, det_num, tracklet_num = complete_net(batch)
                    loss = weighted_binary_cross_entropy(output, ground_truth, weights)
                    val_loss += loss.item()
                    # Accuracy
                    if measure_val_acc:
                        # Hungarian method, clean up
                        cleaned_output = hungarian(output2, ground_truth2, det_num, tracklet_num)
                        val_tally = _add_tallies(
                            val_tally, _tally_predictions(cleaned_output, ground_truth2))
            complete_net.train()

            _emit('Epoch: [%d] | Validation_Loss: %.3f |' %
                  (epoch, val_loss / val_batches) + _accuracy_columns(val_tally))
Beispiel #3
0
def _identity_collate(data_list):
    """Return the sampled items unchanged so every batch stays a plain list."""
    return data_list


def _gather_targets(data_list, attr, dtype, device):
    """Concatenate the flattened ``attr`` tensors of all items on ``device``."""
    parts = [torch.tensor([], dtype=dtype, device=device)]
    parts.extend(getattr(data, attr).view(-1).to(device) for data in data_list)
    return torch.cat(parts)


def train_cross_validation(model_cls,
                           dataset,
                           dropout=0.0,
                           lr=1e-3,
                           weight_decay=1e-2,
                           num_epochs=200,
                           n_splits=10,
                           use_gpu=True,
                           dp=False,
                           ddp=False,
                           comment='',
                           tb_service_loc='192.168.192.57:6007',
                           batch_size=1,
                           num_workers=0,
                           pin_memory=False,
                           cuda_device=None,
                           tb_dir='runs',
                           model_save_dir='saved_models',
                           res_save_dir='res',
                           fold_no=None,
                           saved_model_path=None,
                           device_ids=None,
                           patience=20,
                           seed=None,
                           fold_seed=None,
                           save_model=False,
                           is_reg=True,
                           live_loss=True,
                           domain_cls=True,
                           final_cls=True):
    """
    Train one model per cross-validation fold and log metrics to TensorBoard.

    Each fold runs up to ``num_epochs`` epochs of a train and a validation
    phase. A fold stops early once the validation NLL loss has not improved
    for ``patience`` consecutive epochs; the remaining folds still run.

    :param model_cls: pytorch Module cls, instantiated as ``model_cls(writer)``
    :param dataset: instance; ``dataset.data`` must provide ``y``, ``site_id``
        and ``subject_id``
    :param dropout: float (not used by this function)
    :param lr: float, AdamW learning rate
    :param weight_decay: AdamW weight decay
    :param num_epochs: maximum number of epochs per fold
    :param n_splits: number of kFolds
    :param use_gpu: bool
    :param dp: bool, wrap the model in DataParallel (applied with ``use_gpu``)
    :param ddp: DDP (not used by this function)
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location, only printed
    :param batch_size: per-device batch size; multiplied by the device count
    :param num_workers: int, DataLoader args
    :param pin_memory: bool, DataLoader args
    :param cuda_device: device used instead of the automatically chosen one
    :param tb_dir: root directory of the TensorBoard logs
    :param model_save_dir: directory under which model files are written
    :param res_save_dir: not used by this function
    :param fold_no: int, train only this (1-based) fold when given
    :param saved_model_path: state dict loaded into every fold's model
    :param device_ids: devices for DataParallel
    :param patience: for early stopping
    :param seed: random seed; derived from the clock when None
    :type fold_seed: int
    :param fold_seed: random state of the fold split; folds are shuffled only
        when it is truthy
    :param save_model: bool, save the model after every validation phase
    :param is_reg: bool, add the model's regularisation term to the loss
    :param live_loss: bool, draw live loss plots
    :param domain_cls: bool, train on the domain (site) classification loss
    :param final_cls: bool, train on the label classification loss
    :raises ValueError: if both ``domain_cls`` and ``final_cls`` are False
    :return: None
    """
    # Snapshot of the call arguments; must stay the first statement so that no
    # other local has been bound yet.
    saved_args = dict(locals())

    import os

    if not (domain_cls or final_cls):
        # Previously this surfaced as a NameError deep inside the batch loop.
        raise ValueError(
            'at least one of domain_cls and final_cls must be True')

    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    model_name = model_cls.__name__

    if cuda_device:
        device = cuda_device
    elif device_ids and dp:
        device = device_ids[0]
    else:
        device = torch.device(
            'cuda' if torch.cuda.is_available() and use_gpu else 'cpu')

    if not dp:
        device_count = 1
    elif device_ids is not None:
        device_count = len(device_ids)
    else:
        device_count = torch.cuda.device_count()

    batch_size = batch_size * device_count

    # TensorBoard
    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".
              format(log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    # model
    criterion = nn.NLLLoss()

    print("Training {0} {1} models for cross validation...".format(
        n_splits, model_name))

    fold_splits = multi_site_cv_split(dataset.data.y,
                                      dataset.data.site_id,
                                      dataset.data.subject_id,
                                      n_splits,
                                      random_state=fold_seed,
                                      shuffle=bool(fold_seed))

    for fold, (train_idx, val_idx) in enumerate(
            tqdm_notebook(fold_splits, desc='CV', leave=False), start=1):
        # for a specific fold
        if fold_no is not None and fold != fold_no:
            continue

        liveloss = PlotLosses() if live_loss else None

        # Honour the caller's directories; they used to be ignored in favour
        # of the hard-coded 'runs' and 'saved_models'.
        writer = SummaryWriter(
            log_dir=osp.join(tb_dir, log_dir_base + str(fold)))
        fold_save_prefix = osp.join(model_save_dir, log_dir_base + str(fold))

        print("creating dataloader tor fold {}".format(fold))

        train_dataset, val_dataset = norm_train_val(dataset, train_idx,
                                                    val_idx)

        model = model_cls(writer)

        # A module-level collate function (instead of a lambda) stays
        # picklable when worker processes are spawned.
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      collate_fn=_identity_collate,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
        val_dataloader = DataLoader(val_dataset,
                                    shuffle=False,
                                    batch_size=batch_size,
                                    collate_fn=_identity_collate,
                                    num_workers=num_workers,
                                    pin_memory=pin_memory)

        if fold == 1 or fold_no is not None:
            print(model)
            writer.add_text('model_summary', repr(model))
            writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=lr,
                                      betas=(0.9, 0.999),
                                      eps=1e-08,
                                      weight_decay=weight_decay,
                                      amsgrad=False)
        # NOTE: the scheduler is never stepped in this function; it is only
        # read back below to log the optimizer's current learning rate.
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=10,
                                           total_epoch=5)
        if dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        if saved_model_path is not None:
            model.load_state_dict(torch.load(saved_model_path))

        best_map, patience_counter, best_score = 0.0, 0, np.inf
        for epoch in tqdm_notebook(range(1, num_epochs + 1),
                                   desc='Epoch',
                                   leave=False):
            logs = {}
            stop_fold = False

            for phase in ['train', 'validation']:
                is_train = phase == 'train'
                if is_train:
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = val_dataloader

                # Logging
                running_reg_loss = 0.0
                running_nll_loss = 0.0
                # Seeded with an empty tensor so the final concatenation is
                # well defined; chunks are joined once per phase.
                label_chunks = [torch.tensor([])]
                predicted_chunks = [torch.tensor([])]

                for data_list in tqdm_notebook(dataloader,
                                               desc=phase,
                                               leave=False):

                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids
                                             is not None else 'cuda'))

                    # No autograd graph is needed while validating.
                    with torch.set_grad_enabled(is_train):
                        y_hat, domain_yhat, reg = model(data_list)

                        y = _gather_targets(data_list, 'y',
                                            dataset.data.y.dtype, device)
                        domain_y = _gather_targets(
                            data_list, 'site_id',
                            dataset.data.site_id.dtype, device)

                        loss = criterion(y_hat, y)
                        domain_loss = criterion(domain_yhat, domain_y)

                        if final_cls:
                            _, predicted = torch.max(y_hat, 1)
                            label = y
                            if domain_cls:
                                total_loss = (loss + domain_loss).sum()
                            else:
                                total_loss = loss
                        else:
                            _, predicted = torch.max(domain_yhat, 1)
                            label = domain_y
                            total_loss = domain_loss

                        if is_reg:
                            # Out-of-place on purpose: ``total_loss`` may be
                            # the very same tensor as ``loss`` and an in-place
                            # add would leak the regulariser into the NLL
                            # loss that is logged below.
                            total_loss = total_loss + reg.sum()

                    if is_train:
                        optimizer.zero_grad()
                        total_loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    running_nll_loss += loss.item()
                    running_reg_loss += reg.sum().item()

                    label_chunks.append(label.detach().float().view(-1).cpu())
                    predicted_chunks.append(
                        predicted.detach().float().view(-1).cpu())

                epoch_label = torch.cat(label_chunks)
                epoch_predicted = torch.cat(predicted_chunks)

                accuracy = sklearn.metrics.accuracy_score(
                    epoch_label, epoch_predicted)
                n_batches = len(dataloader)
                epoch_nll_loss = running_nll_loss / n_batches
                epoch_reg_loss = running_reg_loss / n_batches

                writer.add_scalars(
                    'nll_loss', {'{}_nll_loss'.format(phase): epoch_nll_loss},
                    epoch)
                writer.add_scalars('accuracy',
                                   {'{}_accuracy'.format(phase): accuracy},
                                   epoch)
                if epoch_reg_loss != 0:
                    writer.add_scalars(
                        'reg_loss',
                        {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)

                # Save Model & Early Stopping
                if phase == 'validation':
                    model_save_path = fold_save_prefix + '-{}-{}-{:.3f}-{:.3f}'.format(
                        model_name, epoch, accuracy, epoch_nll_loss)
                    # best score
                    if accuracy > best_map:
                        best_map = accuracy
                        model_save_path = model_save_path + '-best'

                    score = epoch_nll_loss
                    if score < best_score:
                        patience_counter = 0
                        best_score = score
                    else:
                        patience_counter += 1

                    if save_model:
                        # Tag the file with the first accuracy band reached.
                        for th, pfix in zip(
                            [0.8, 0.75, 0.7, 0.5, 0.0],
                            ['-perfect', '-great', '-good', '-bad', '-miss']):
                            if accuracy >= th:
                                model_save_path += pfix
                                break

                        os.makedirs(osp.dirname(model_save_path) or '.',
                                    exist_ok=True)
                        torch.save(model.state_dict(), model_save_path)

                    writer.add_scalars('best_val_accuracy',
                                       {'{}_accuracy'.format(phase): best_map},
                                       epoch)
                    writer.add_scalars(
                        'best_nll_loss',
                        {'{}_nll_loss'.format(phase): best_score}, epoch)

                    writer.add_scalars('learning_rate', {
                        'learning_rate':
                        scheduler.optimizer.param_groups[0]['lr']
                    }, epoch)

                    if patience_counter >= patience:
                        print("Stopped at epoch {}".format(epoch))
                        # End this fold only; returning here used to abort
                        # all remaining folds of the cross validation.
                        stop_fold = True
                        break

                if live_loss:
                    prefix = 'val_' if phase == 'validation' else ''
                    logs[prefix + 'log loss'] = epoch_nll_loss
                    logs[prefix + 'accuracy'] = accuracy

            if stop_fold:
                break

            if live_loss:
                liveloss.update(logs)
                liveloss.draw()

        # Flush and release this fold's event file before the next fold.
        writer.close()

    print("Done !")
Beispiel #4
0
def _collect_reconstructions(model, loader):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns ``(input_fts, reco_fts)``: the concatenated input features and the
    concatenated model outputs (moved to the CPU).
    """
    input_fts = []
    reco_fts = []
    model.eval()
    with torch.no_grad():
        for t in loader:
            if isinstance(t, list):
                # list batches are passed to the model as they are
                input_fts.extend(d.x for d in t)
            else:
                input_fts.append(t.x)
                # Rebind so this works whether ``to`` moves in place or
                # returns a moved copy.
                t = t.to(device)
            reco_out = model(t)
            if isinstance(reco_out, tuple):
                reco_out = reco_out[0]
            reco_fts.append(reco_out.cpu().detach())
    return torch.cat(input_fts), torch.cat(reco_fts)


def main(args):
    """Train (or resume) a graph autoencoder and plot its reconstructions.

    Reads ``batch_size``, ``mod_name``, ``output_dir``, ``input_dir``,
    ``box_num``, ``num_data``, ``loss``, ``emd_model_name``, ``lat_dim``,
    ``lr``, ``patience`` and ``model`` from ``args``, and relies on the
    module-level ``multi_gpu`` and ``device`` settings.

    The data is shuffled with a fixed seed and split 80:10:10 into train,
    validation and test sets. Training runs for up to 200 epochs with early
    stopping after ``args.patience`` epochs without a better validation loss.
    The best model and the loss history are stored under
    ``<output_dir>/<mod_name>`` and picked up again on the next run.

    :raises SystemExit: if ``multi_gpu`` is set and the batch size is smaller
        than the number of CUDA devices
    """
    batch_size = args.batch_size
    model_fname = args.mod_name

    if multi_gpu and batch_size < torch.cuda.device_count():
        # Same effect as exit(...), without relying on the site builtin.
        raise SystemExit('Batch size too small')

    # make a folder for the graphs of this model
    save_dir = osp.join(args.output_dir, model_fname)
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # get dataset and split
    gdata = GraphDataset(root=args.input_dir, bb=args.box_num)
    # merge data from separate files into one contiguous array
    bag = []
    for g in gdata:
        bag.extend(g)
    random.Random(0).shuffle(bag)
    bag = bag[:args.num_data]
    # temporary patch to use px, py, pz
    for d in bag:
        d.x = d.x[:, :3]
    # 80:10:10 split datasets
    fulllen = len(bag)
    train_len = int(0.8 * fulllen)
    tv_len = int(0.10 * fulllen)
    train_dataset = bag[:train_len]
    valid_dataset = bag[train_len:train_len + tv_len]
    test_dataset = bag[train_len + tv_len:]
    train_samples = len(train_dataset)
    valid_samples = len(valid_dataset)

    # list-based loader for multi-GPU runs, regular loader otherwise
    loader_cls = DataListLoader if multi_gpu else DataLoader
    train_loader = loader_cls(train_dataset,
                              batch_size=batch_size,
                              pin_memory=True,
                              shuffle=True)
    valid_loader = loader_cls(valid_dataset,
                              batch_size=batch_size,
                              pin_memory=True,
                              shuffle=False)
    test_loader = loader_cls(test_dataset,
                             batch_size=batch_size,
                             pin_memory=True,
                             shuffle=False)

    # specify loss function
    loss_ftn_obj = LossFunction(args.loss,
                                emd_modname=args.emd_model_name,
                                device=device)

    # create model
    input_dim = 3
    big_dim = 32
    hidden_dim = args.lat_dim
    lr = args.lr
    patience = args.patience

    if args.model == 'MetaLayerGAE':
        model = models.GNNAutoEncoder()
    elif args.model[-3:] == 'EMD':
        model = getattr(models, args.model)(input_dim=input_dim,
                                            big_dim=big_dim,
                                            hidden_dim=hidden_dim,
                                            emd_modname=args.emd_model_name)
    else:
        model = getattr(models, args.model)(input_dim=input_dim,
                                            big_dim=big_dim,
                                            hidden_dim=hidden_dim)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=4)

    valid_losses = []
    train_losses = []
    start_epoch = 0
    n_epochs = 200

    # load in model
    modpath = osp.join(save_dir, model_fname + '.best.pth')
    try:
        model.load_state_dict(torch.load(modpath))
        train_losses, valid_losses, start_epoch = torch.load(
            osp.join(save_dir, 'losses.pt'))
        print('Loaded model')
        # NOTE(review): this evaluation runs before the model is moved to
        # ``device``; confirm that ``test`` copes with that.
        best_valid_loss = test(model, valid_loader, valid_samples, batch_size,
                               loss_ftn_obj)
        print(f'Saved model valid loss: {best_valid_loss}')
    except Exception as exc:
        # Best effort: any failure to resume falls back to a fresh start, but
        # Ctrl-C / SystemExit are no longer swallowed and the reason is shown.
        print(f'Could not resume from {modpath}: {exc!r}')
        print('Creating new model')
        best_valid_loss = 9999999
    if multi_gpu:
        model = DataParallel(model)
    model.to(torch.device(device))

    # Training loop
    if multi_gpu:
        train_fn, eval_fn = train_parallel, test_parallel
    else:
        train_fn, eval_fn = train, test

    stale_epochs = 0
    # Keeps ``epoch`` defined for the summary below even when a resumed run
    # has no epochs left to train.
    epoch = start_epoch - 1
    for epoch in range(start_epoch, n_epochs):

        loss = train_fn(model, optimizer, train_loader, train_samples,
                        batch_size, loss_ftn_obj)
        valid_loss = eval_fn(model, valid_loader, valid_samples, batch_size,
                             loss_ftn_obj)

        scheduler.step(valid_loss)
        train_losses.append(loss)
        valid_losses.append(valid_loss)
        print('Epoch: {:02d}, Training Loss:   {:.4f}'.format(epoch, loss))
        print('               Validation Loss: {:.4f}'.format(valid_loss))

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            print('New best model saved to:', modpath)
            # Store the bare module's weights so the file can be loaded
            # without the DataParallel wrapper.
            if multi_gpu:
                torch.save(model.module.state_dict(), modpath)
            else:
                torch.save(model.state_dict(), modpath)
            torch.save((train_losses, valid_losses, epoch + 1),
                       osp.join(save_dir, 'losses.pt'))
            stale_epochs = 0
        else:
            stale_epochs += 1
            print(
                f'Stale epoch: {stale_epochs}\nBest: {best_valid_loss}\nCurr: {valid_loss}'
            )
        if stale_epochs >= patience:
            print('Early stopping after %i stale epochs' % patience)
            break

    # model training done
    train_epochs = list(range(epoch + 1))
    early_stop_epoch = epoch - stale_epochs
    loss_curves(train_epochs, early_stop_epoch, train_losses, valid_losses,
                save_dir)

    # compare input and reconstructions
    # The checkpoint holds the bare module's weights (see the save above), so
    # it has to be loaded into ``model.module`` when the model is wrapped;
    # loading it into the wrapper fails on the missing "module." key prefix.
    best_state = torch.load(modpath)
    if multi_gpu:
        model.module.load_state_dict(best_state)
    else:
        model.load_state_dict(best_state)

    input_fts, reco_fts = _collect_reconstructions(model, valid_loader)
    plot_reco_difference(
        input_fts, reco_fts, model_fname,
        osp.join(save_dir, 'reconstruction_post_train', 'valid'))

    input_fts, reco_fts = _collect_reconstructions(model, test_loader)
    plot_reco_difference(
        input_fts, reco_fts, model_fname,
        osp.join(save_dir, 'reconstruction_post_train', 'test'))
    print('Completed')
    args = parser.parse_args()

    # Load dataset

    dataset = PygNodePropPredDataset(name=args.dataset_name)
    data = dataset[0]
    dataset_test(data)

    if args.multi_gpu:
        # Unit test: GPU number verification

        # Prepare model
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model = parse_model_name(args.model, dataset)
        model = DataParallel(model)
        model = model.to(device)

        #Split graph into subgraphs
        if args.subgraph_scheme == 'cluster':
            # Split data into subgraphs using cluster methods
            data_list = list(ClusterData(data, num_parts=args.num_parts))
        elif args.subgraph_scheme == 'neighbor':
            data_list = list(
                NeighborSubgraphLoader(data,
                                       batch_size=args.neighbor_batch_size))
            print(
                f'Using neighbor sampling | number of subgraphs: {len(data_list)}'
            )

        # Run the model for each batch size setups
        batch_sizes = np.array(list(range(1, 65))) * 4
Beispiel #6
0
def train():
    """Train a PotentialNet regressor and keep track of the best checkpoint.

    All settings are read from the module-level ``args`` namespace: feature
    type, train/validation data files, batch size, network hyper-parameters,
    learning rate, number of epochs and the checkpoint options.

    Per-batch metrics (loss, r2, Pearson/Spearman correlation, MAE, prediction
    statistics and the two learned neighbor thresholds) are reported through
    ``tqdm.write``.

    When ``args.checkpoint`` is set, ``checkpoint_model`` is called every
    ``args.checkpoint_iter`` steps, at the end of every epoch and once more
    after training.  The checkpoint with the highest ``validate_dict["r2"]``
    is saved to ``<args.checkpoint_dir>/best_checkpoint.pth``.
    """

    def _load_datasets(data_files):
        """Build one PDBBindDataset per input file, sharing the CLI options."""
        return [
            PDBBindDataset(
                data_file=data_file,
                dataset_name=args.dataset_name,
                feature_type=args.feature_type,
                preprocessing_type=args.preprocessing_type,
                output_info=True,
                use_docking=args.use_docking,
            ) for data_file in data_files
        ]

    # set the input channel dims based on featurization type
    feature_size = 20 if args.feature_type == "pybel" else 75

    print("found {} datasets in input train-data".format(len(args.train_data)))

    train_dataset = ConcatDataset(_load_datasets(args.train_data))
    val_dataset = ConcatDataset(_load_datasets(args.val_data))

    # drop_last keeps every batch the same size; note that shuffling is
    # disabled for both loaders.
    train_dataloader = DataListLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )

    val_dataloader = DataListLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )

    tqdm.write("{} complexes in training dataset".format(len(train_dataset)))
    tqdm.write("{} complexes in validation dataset".format(len(val_dataset)))

    model = GeometricDataParallel(
        PotentialNetParallel(
            in_channels=feature_size,
            out_channels=1,
            covalent_gather_width=args.covalent_gather_width,
            non_covalent_gather_width=args.non_covalent_gather_width,
            covalent_k=args.covalent_k,
            non_covalent_k=args.non_covalent_k,
            covalent_neighbor_threshold=args.covalent_threshold,
            non_covalent_neighbor_threshold=args.non_covalent_threshold,
        )).float()

    model.train()
    model.to(0)
    tqdm.write(str(model))
    tqdm.write("{} trainable parameters.".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    tqdm.write("{} total parameters.".format(
        sum(p.numel() for p in model.parameters())))

    criterion = nn.MSELoss().float()
    optimizer = Adam(model.parameters(), lr=args.lr)

    best_checkpoint_dict = None
    best_checkpoint_epoch = 0
    best_checkpoint_step = 0
    best_checkpoint_r2 = -9e9

    def _checkpoint_and_track(epoch, step):
        """Checkpoint the model and remember it if its r2 beats the best so far."""
        nonlocal best_checkpoint_dict, best_checkpoint_epoch
        nonlocal best_checkpoint_step, best_checkpoint_r2

        checkpoint_dict = checkpoint_model(
            model,
            val_dataloader,
            epoch,
            step,
            args.checkpoint_dir +
            "/model-epoch-{}-step-{}.pth".format(epoch, step),
        )
        checkpoint_r2 = checkpoint_dict["validate_dict"]["r2"]
        if checkpoint_r2 > best_checkpoint_r2:
            best_checkpoint_step = step
            best_checkpoint_epoch = epoch
            best_checkpoint_r2 = checkpoint_r2
            best_checkpoint_dict = checkpoint_dict

    step = 0
    # Pre-bind ``epoch`` so the final checkpoint below does not hit a
    # NameError when ``args.epochs`` is 0 and the loop body never runs.
    epoch = 0
    for epoch in range(args.epochs):
        losses = []
        for batch in tqdm(train_dataloader):
            batch = [x for x in batch if x is not None]
            if len(batch) < 1:
                print("empty batch, skipping to next batch")
                continue
            optimizer.zero_grad()

            # the third element of every sample is the graph handed to the model
            data = [x[2] for x in batch]
            y_ = model(data)
            y = torch.cat([d.y for d in data])

            loss = criterion(y.float(), y_.cpu().float())
            losses.append(loss.cpu().data.item())
            loss.backward()

            y_true = y.cpu().data.numpy()
            y_pred = y_.cpu().data.numpy()

            r2 = r2_score(y_true=y_true, y_pred=y_pred)
            mae = mean_absolute_error(y_true=y_true, y_pred=y_pred)

            pearsonr = stats.pearsonr(y_true.reshape(-1), y_pred.reshape(-1))
            spearmanr = stats.spearmanr(y_true.reshape(-1), y_pred.reshape(-1))

            tqdm.write(
                "epoch: {}\tloss:{:0.4f}\tr2: {:0.4f}\t pearsonr: {:0.4f}\tspearmanr: {:0.4f}\tmae: {:0.4f}\tpred stdev: {:0.4f}"
                "\t pred mean: {:0.4f} \tcovalent_threshold: {:0.4f} \tnon covalent threshold: {:0.4f}"
                .format(
                    epoch,
                    loss.cpu().data.numpy(),
                    r2,
                    float(pearsonr[0]),
                    float(spearmanr[0]),
                    float(mae),
                    np.std(y_pred),
                    np.mean(y_pred),
                    model.module.covalent_neighbor_threshold.t.cpu().data.item(
                    ),
                    model.module.non_covalent_neighbor_threshold.t.cpu().data.
                    item(),
                ))

            # NOTE: the periodic checkpoint is taken before optimizer.step(),
            # i.e. it captures the weights prior to this batch's update.
            if args.checkpoint and step % args.checkpoint_iter == 0:
                _checkpoint_and_track(epoch, step)

            optimizer.step()
            step += 1

        if args.checkpoint:
            _checkpoint_and_track(epoch, step)

    if args.checkpoint:
        # once broken out of the loop, save last model
        _checkpoint_and_track(epoch, step)
        torch.save(best_checkpoint_dict,
                   args.checkpoint_dir + "/best_checkpoint.pth")
    print("best training checkpoint epoch {}/step {} with r2: {}".format(
        best_checkpoint_epoch, best_checkpoint_step, best_checkpoint_r2))
Beispiel #7
0
def train_cross_validation(model_cls,
                           dataset,
                           num_clusters,
                           dropout=0.0,
                           lr=1e-4,
                           weight_decay=1e-2,
                           num_epochs=200,
                           n_splits=10,
                           use_gpu=True,
                           dp=False,
                           ddp=True,
                           comment='',
                           tb_service_loc='192.168.192.57:6006',
                           batch_size=1,
                           num_workers=0,
                           pin_memory=False,
                           cuda_device=None,
                           fold_no=None,
                           saved_model_path=None,
                           device_ids=None,
                           patience=50,
                           seed=None,
                           save_model=True,
                           c_reg=0,
                           base_log_dir='runs',
                           base_model_save_dir='saved_models'):
    """
    Train one classifier per stratified cross-validation fold.

    The first stratified split of ``dataset`` is set aside (it is not used any
    further in this function); the remaining samples are split again into
    ``n_splits`` stratified train/validation folds.  For every fold a fresh
    model is built, trained with Adam on cross-entropy plus ``c_reg`` times the
    negated regularisation term returned by the model, and its metrics are
    logged to a per-fold TensorBoard writer.  Early stopping ends the current
    fold only; the remaining folds are still trained.

    :param model_cls: pytorch Module cls, called as
        ``model_cls(writer, num_clusters=..., in_dim=..., out_dim=..., dropout=...)``
    :param dataset: pytorch Dataset; must expose ``data.x``, ``data.y`` and ``__indexing__``
    :param num_clusters: forwarded to ``model_cls``
    :param dropout: forwarded to ``model_cls``
    :param lr: Adam learning rate
    :param weight_decay: Adam weight decay
    :param num_epochs: maximum number of epochs per fold
    :param n_splits: number of kFolds
    :param use_gpu: bool
    :param dp: bool, wrap the model in DataParallel
    :param ddp: bool, wrap the model in DistributedDataParallel (single process, rank 0)
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location (only used in the printed URL)
    :param batch_size: Dataset args not DataLoader; used to derive the ``n_splits``
        passed to ``dataset_gather``.  The DataLoader batch size is the device count.
    :param num_workers: DataLoader args
    :param pin_memory: DataLoader args https://devblogs.nvidia.com/how-optimize-data-transfers-cuda-cc/
    :param cuda_device: device to use; when falsy it is derived from ``device_ids``/``use_gpu``
    :param fold_no: when given, only this (1-based) fold is trained
    :param saved_model_path: optional state dict loaded into the (wrapped) model before training
    :param device_ids: for dp / ddp
    :param patience: for early stopping; number of epochs without validation
        nll-loss improvement tolerated before a fold is stopped
    :param seed: random seed; derived from the current time when None
    :param save_model: bool, save the model state dict after each validation phase past epoch 10
    :param c_reg: weight of the regularisation term in the total loss
    :param base_log_dir: root directory of the TensorBoard logs
    :param base_model_save_dir: root directory of the saved models
    :return: None
    """
    # Must stay the first statement so that only the call arguments are captured.
    saved_args = locals()
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    if ddp and not torch.distributed.is_initialized():  # initialize ddp
        dist.init_process_group('nccl',
                                init_method='tcp://localhost:{}'.format(
                                    find_open_port()),
                                world_size=1,
                                rank=0)

    model_name = model_cls.__name__

    if not cuda_device:
        if device_ids and (ddp or dp):
            device = device_ids[0]
        else:
            device = torch.device(
                'cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
    else:
        device = cuda_device

    device_count = torch.cuda.device_count() if dp else 1
    device_count = len(device_ids) if (device_ids is not None and
                                       (dp or ddp)) else device_count
    if device_count > 1:
        print("Let's use", device_count, "GPUs!")

    # batch_size = batch_size * device_count

    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".
              format(log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    criterion = nn.CrossEntropyLoss()

    # get test set: the first stratified split is held out, the rest is used
    # for the train/validation folds below
    folds = StratifiedKFold(n_splits=n_splits, shuffle=False)
    train_val_idx, test_idx = list(
        folds.split(np.zeros(len(dataset)), dataset.data.y.numpy()))[0]
    test_dataset = dataset.__indexing__(test_idx)  # held out, unused here
    train_val_dataset = dataset.__indexing__(train_val_idx)

    print("Training {0} {1} models for cross validation...".format(
        n_splits, model_name))
    folds = StratifiedKFold(n_splits=n_splits, shuffle=False)
    # named ``fold_splits`` rather than ``iter`` to avoid shadowing the builtin
    fold_splits = folds.split(np.zeros(len(train_val_dataset)),
                              train_val_dataset.data.y.numpy())
    fold = 0

    for train_idx, val_idx in tqdm_notebook(fold_splits,
                                            desc='CV',
                                            leave=False):

        fold += 1
        if fold_no is not None:
            if fold != fold_no:
                continue

        writer = SummaryWriter(log_dir=osp.join(base_log_dir, log_dir_base +
                                                str(fold)))
        model_save_dir = osp.join(base_model_save_dir,
                                  log_dir_base + str(fold))

        print("creating dataloader tor fold {}".format(fold))

        model = model_cls(writer,
                          num_clusters=num_clusters,
                          in_dim=dataset.data.x.shape[1],
                          out_dim=int(dataset.data.y.max() + 1),
                          dropout=dropout)

        # My Batch

        train_dataset = train_val_dataset.__indexing__(train_idx)
        val_dataset = train_val_dataset.__indexing__(val_idx)

        train_dataset = dataset_gather(
            train_dataset,
            seed=0,
            n_repeat=1,
            n_splits=int(len(train_dataset) / batch_size) + 1)
        val_dataset = dataset_gather(
            val_dataset,
            seed=0,
            n_repeat=1,
            n_splits=int(len(val_dataset) / batch_size) + 1)

        # The identity collate_fn hands the model a plain list of samples,
        # one per device.
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=device_count,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
        val_dataloader = DataLoader(val_dataset,
                                    shuffle=False,
                                    batch_size=device_count,
                                    collate_fn=lambda data_list: data_list,
                                    num_workers=num_workers,
                                    pin_memory=pin_memory)

        # if fold == 1 or fold_no is not None:
        print(model)
        writer.add_text('model_summary', model.__repr__())
        writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=weight_decay,
                                     amsgrad=False)
        # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
        if ddp:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=device_ids)
        elif dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        if saved_model_path is not None:
            model.load_state_dict(torch.load(saved_model_path))

        best_map, patience_counter, best_score = 0.0, 0, -np.inf
        early_stop = False
        for epoch in tqdm_notebook(range(1, num_epochs + 1),
                                   desc='Epoch',
                                   leave=False):

            for phase in ['train', 'validation']:

                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = val_dataloader

                # Logging
                running_total_loss = 0.0
                running_corrects = 0
                running_reg_loss = 0.0
                running_nll_loss = 0.0
                epoch_yhat_0, epoch_yhat_1 = torch.tensor([]), torch.tensor([])
                epoch_label, epoch_predicted = torch.tensor([]), torch.tensor(
                    [])

                for data_list in tqdm_notebook(dataloader,
                                               desc=phase,
                                               leave=False):

                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids
                                             is not None else 'cuda'))

                    y_hat, reg = model(data_list)
                    # y_hat = y_hat.reshape(batch_size, -1)

                    # concatenate the labels of all samples in this batch
                    y = torch.tensor([],
                                     dtype=dataset.data.y.dtype,
                                     device=device)
                    for data in data_list:
                        y = torch.cat([y, data.y.view(-1).to(device)])

                    loss = criterion(y_hat, y)
                    reg_loss = -reg
                    # .sum() reduces the total to a scalar even when ``reg``
                    # holds more than one element
                    total_loss = (loss + reg_loss * c_reg).sum()

                    if phase == 'train':
                        # print(torch.autograd.grad(y_hat.sum(), model.saved_x, retain_graph=True))
                        optimizer.zero_grad()
                        total_loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    _, predicted = torch.max(y_hat, 1)
                    label = y

                    running_nll_loss += loss.item()
                    running_total_loss += total_loss.item()
                    running_reg_loss += reg.sum().item()
                    running_corrects += (predicted == label).sum().item()

                    epoch_yhat_0 = torch.cat(
                        [epoch_yhat_0, y_hat[:, 0].detach().view(-1).cpu()])
                    epoch_yhat_1 = torch.cat(
                        [epoch_yhat_1, y_hat[:, 1].detach().view(-1).cpu()])
                    epoch_label = torch.cat(
                        [epoch_label,
                         label.detach().cpu().float()])
                    epoch_predicted = torch.cat(
                        [epoch_predicted,
                         predicted.detach().cpu().float()])

                # precision = sklearn.metrics.precision_score(epoch_label, epoch_predicted, average='micro')
                # recall = sklearn.metrics.recall_score(epoch_label, epoch_predicted, average='micro')
                # f1_score = sklearn.metrics.f1_score(epoch_label, epoch_predicted, average='micro')
                accuracy = sklearn.metrics.accuracy_score(
                    epoch_label, epoch_predicted)
                # total/nll losses are averaged per batch, the regularisation
                # term per dataset element
                epoch_total_loss = running_total_loss / dataloader.__len__()
                epoch_nll_loss = running_nll_loss / dataloader.__len__()
                epoch_reg_loss = running_reg_loss / dataloader.dataset.__len__(
                )

                writer.add_scalars(
                    'nll_loss', {'{}_nll_loss'.format(phase): epoch_nll_loss},
                    epoch)
                writer.add_scalars('accuracy',
                                   {'{}_accuracy'.format(phase): accuracy},
                                   epoch)
                # writer.add_scalars('{}_APRF'.format(phase),
                #                    {
                #                        'accuracy': accuracy,
                #                        'precision': precision,
                #                        'recall': recall,
                #                        'f1_score': f1_score
                #                    },
                #                    epoch)
                if epoch_reg_loss != 0:
                    writer.add_scalars(
                        'reg_loss',
                        {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)
                # writer.add_histogram('hist/{}_yhat_0'.format(phase),
                #                      epoch_yhat_0,
                #                      epoch)
                # writer.add_histogram('hist/{}_yhat_1'.format(phase),
                #                      epoch_yhat_1,
                #                      epoch)

                # Save Model & Early Stopping
                if phase == 'validation':
                    model_save_path = model_save_dir + '-{}-{}-{:.3f}-{:.3f}'.format(
                        model_name, epoch, accuracy, epoch_nll_loss)
                    if accuracy > best_map:
                        best_map = accuracy
                        model_save_path = model_save_path + '-best'

                    # early stopping watches the validation nll loss
                    score = -epoch_nll_loss
                    if score > best_score:
                        patience_counter = 0
                        best_score = score
                    else:
                        patience_counter += 1

                    # skip 10 epoch
                    # best_score = best_score if epoch > 10 else -np.inf

                    if save_model:
                        # tag the file name with the first accuracy band reached
                        for th, pfix in zip(
                            [0.8, 0.75, 0.7, 0.5, 0.0],
                            ['-perfect', '-great', '-good', '-bad', '-miss']):
                            if accuracy >= th:
                                model_save_path += pfix
                                break
                        if epoch > 10:
                            torch.save(model.state_dict(), model_save_path)

                    writer.add_scalars('best_val_accuracy',
                                       {'{}_accuracy'.format(phase): best_map},
                                       epoch)
                    writer.add_scalars(
                        'best_nll_loss',
                        {'{}_nll_loss'.format(phase): -best_score}, epoch)

                    if patience_counter >= patience:
                        print("Stopped at epoch {}".format(epoch))
                        early_stop = True

            # 'validation' is the last phase, so the flag can be checked here.
            # Stop this fold only; the remaining folds must still be trained.
            if early_stop:
                break

        # flush and release this fold's event file
        writer.close()

    print("Done !")
Beispiel #8
0
def train_cummunity_detection(model_cls,
                              dataset,
                              dropout=0.0,
                              lr=1e-3,
                              weight_decay=1e-2,
                              num_epochs=200,
                              n_splits=10,
                              use_gpu=True,
                              dp=False,
                              ddp=False,
                              comment='',
                              tb_service_loc='192.168.192.57:6006',
                              batch_size=1,
                              num_workers=0,
                              pin_memory=False,
                              cuda_device=None,
                              ddp_port='23456',
                              fold_no=None,
                              device_ids=None,
                              patience=20,
                              seed=None,
                              save_model=False,
                              supervised=False):
    """
    Train one community-detection model per K-fold split of ``dataset``.

    For every fold a fresh model is built and trained with Adam.  In supervised
    mode the loss is ``permutation_invariant_loss(y_hat, y)``; otherwise it is
    the negated regularisation term returned by the model.  The per-epoch
    regularisation term (and, in supervised mode, the normalized overlap) is
    logged to a per-fold TensorBoard writer under ``runs/``.

    :param model_cls: pytorch Module cls, called as ``model_cls(writer, dropout=...)``
    :param dataset: pytorch Dataset; must expose ``data.y`` and ``__indexing__``
    :param dropout: forwarded to ``model_cls``
    :param lr: Adam learning rate
    :param weight_decay: Adam weight decay
    :param num_epochs: number of epochs per fold
    :param n_splits: number of kFolds
    :param use_gpu: bool
    :param dp: bool, wrap the model in DataParallel
    :param ddp: bool, wrap the model in DistributedDataParallel (single process, rank 0)
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location (only used in the printed URL)
    :param batch_size: Dataset args not DataLoader; used to derive the ``n_splits``
        passed to ``dataset_gather``.  The DataLoader batch size is the device count.
    :param num_workers: DataLoader args
    :param pin_memory: DataLoader args https://devblogs.nvidia.com/how-optimize-data-transfers-cuda-cc/
    :param cuda_device: device to use; when falsy it is derived from ``device_ids``/``use_gpu``
    :param ddp_port: TCP port on localhost used to initialise the ddp process group
    :param fold_no: when given, only this (1-based) fold is trained
    :param device_ids: for dp / ddp
    :param patience: for early stopping (currently unused in this function)
    :param seed: random seed; derived from the current time when None
    :param save_model: bool (currently unused in this function)
    :param supervised: bool, train on the permutation-invariant label loss
        instead of the regularisation term
    :return: None
    """

    # Must stay the first statement so that only the call arguments are captured.
    saved_args = locals()
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed

    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    if ddp and not torch.distributed.is_initialized():  # initialize ddp
        dist.init_process_group(
            'nccl',
            init_method='tcp://localhost:{}'.format(ddp_port),
            world_size=1,
            rank=0)

    model_name = model_cls.__name__

    if not cuda_device:
        if device_ids and (ddp or dp):
            device = device_ids[0]
        else:
            device = torch.device(
                'cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
    else:
        device = cuda_device

    device_count = torch.cuda.device_count() if dp else 1
    device_count = len(device_ids) if (device_ids is not None and
                                       (dp or ddp)) else device_count
    if device_count > 1:
        print("Let's use", device_count, "GPUs!")

    # batch_size = batch_size * device_count

    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".
              format(log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    print("Training {0} {1} models for cross validation...".format(
        n_splits, model_name))
    # random_state is deliberately not passed: it has no effect without
    # shuffling, and recent scikit-learn versions raise a ValueError when it
    # is combined with shuffle=False.
    folds, fold = KFold(n_splits=n_splits, shuffle=False), 0
    print(dataset.__len__())

    for train_idx, test_idx in tqdm_notebook(folds.split(
            list(range(dataset.__len__())), list(range(dataset.__len__()))),
                                             desc='models',
                                             leave=False):
        fold += 1
        if fold_no is not None:
            if fold != fold_no:
                continue

        writer = SummaryWriter(log_dir=osp.join('runs', log_dir_base +
                                                str(fold)))
        # currently unused: this function never saves the model
        model_save_dir = osp.join('saved_models', log_dir_base + str(fold))

        print("creating dataloader tor fold {}".format(fold))

        model = model_cls(writer, dropout=dropout)

        # My Batch
        train_dataset = dataset.__indexing__(train_idx)
        test_dataset = dataset.__indexing__(test_idx)

        train_dataset = dataset_gather(train_dataset,
                                       n_repeat=1,
                                       n_splits=int(
                                           len(train_dataset) / batch_size))

        # The identity collate_fn hands the model a plain list of samples,
        # one per device.
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=True,
                                      batch_size=device_count,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
        test_dataloader = DataLoader(test_dataset,
                                     shuffle=False,
                                     batch_size=device_count,
                                     collate_fn=lambda data_list: data_list,
                                     num_workers=num_workers,
                                     pin_memory=pin_memory)

        print(model)
        writer.add_text('model_summary', model.__repr__())
        writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=weight_decay,
                                     amsgrad=False)
        # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
        if ddp:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=device_ids)
        elif dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(
                device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        for epoch in tqdm_notebook(range(1, num_epochs + 1),
                                   desc='Epoch',
                                   leave=False):

            for phase in ['train', 'validation']:

                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = test_dataloader

                # Logging
                running_total_loss = 0.0
                running_reg_loss = 0.0
                running_overlap = 0.0

                for data_list in tqdm_notebook(dataloader,
                                               desc=phase,
                                               leave=False):

                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids
                                             is not None else 'cuda'))

                    y_hat, reg = model(data_list)

                    # concatenate the labels of all samples in this batch
                    y = torch.tensor([],
                                     dtype=dataset.data.y.dtype,
                                     device=device)
                    for data in data_list:
                        y = torch.cat([y, data.y.view(-1).to(device)])

                    if supervised:
                        loss = permutation_invariant_loss(y_hat, y)
                        # criterion = nn.NLLLoss()
                        # loss = criterion(y_hat, y)
                    else:
                        loss = -reg
                    # Reduce to a scalar, as train_cross_validation does:
                    # backward() and .item() both need a single-element
                    # tensor, and ``reg`` may hold more than one element.
                    total_loss = loss.sum()

                    if phase == 'train':
                        # print(torch.autograd.grad(y_hat.sum(), model.saved_x, retain_graph=True))
                        optimizer.zero_grad()
                        total_loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    _, predicted = torch.max(y_hat, 1)
                    label = y

                    if supervised:
                        overlap_score = normalized_overlap(
                            label.int().cpu().numpy(),
                            predicted.int().cpu().numpy(), 0.25)
                        # overlap_score = overlap(label.int().cpu().numpy(), predicted.int().cpu().numpy())
                        running_overlap += overlap_score
                        print(reg, overlap_score, loss)

                    running_total_loss += total_loss.item()
                    running_reg_loss += reg.sum().item()

                # total loss / overlap are averaged per batch, the
                # regularisation term per dataset element
                epoch_total_loss = running_total_loss / dataloader.__len__()
                epoch_reg_loss = running_reg_loss / dataloader.dataset.__len__(
                )
                if supervised:
                    epoch_overlap = running_overlap / dataloader.__len__()
                    writer.add_scalars(
                        'overlap',
                        {'{}_overlap'.format(phase): epoch_overlap}, epoch)

                writer.add_scalars(
                    'reg_loss',
                    {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)

        # flush and release this fold's event file
        writer.close()

    print("Done !")