def init_model(model_cls, log_dir_base, fold_no, device_ids=None, use_gpu=False,
               dp=False, ddp=False, tb_dir='runs', lr=1e-3, weight_decay=1e-2):
    writer = SummaryWriter(log_dir=osp.join(tb_dir, log_dir_base))
    model = model_cls(writer)
    writer.add_text('model_summary', model.__repr__())
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999),
                                  eps=1e-08, weight_decay=weight_decay, amsgrad=False)
    # scheduler_reduce = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
    # scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=5)
    # scheduler = scheduler_reduce
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    if dp and use_gpu:
        model = model.cuda() if device_ids is None else model.to(device_ids[0])
        model = DataParallel(model, device_ids=device_ids)
    elif use_gpu:
        model = model.to(device_ids[0])
    device_count = torch.cuda.device_count() if dp else 1
    device_count = len(device_ids) if (device_ids is not None and dp) else device_count
    return model, optimizer, writer, device_count
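# Usage sketch for init_model (illustrative only): `ToyModel` is a hypothetical
# stand-in for a model class whose constructor accepts a SummaryWriter; the call
# simply mirrors the signature defined above and is not part of the project.
def _example_init_model(ToyModel):
    model, optimizer, writer, device_count = init_model(
        ToyModel, log_dir_base='toy_experiment', fold_no=1,
        device_ids=[0, 1], use_gpu=True, dp=True,
        tb_dir='runs', lr=1e-3, weight_decay=1e-2)
    return model, optimizer, writer, device_count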
def model_training(data_list_train, data_list_test, epochs, acc_epoch, acc_epoch2,
                   save_model_epochs, validation_epoch, batchsize, logfilename,
                   load_checkpoint=None):
    # logging
    logging.basicConfig(level=logging.DEBUG, filename='./logfiles/' + logfilename,
                        filemode="w+", format="%(message)s")
    trainloader = DataListLoader(data_list_train, batch_size=batchsize, shuffle=True)
    testloader = DataListLoader(data_list_test, batch_size=batchsize, shuffle=True)
    device = torch.device('cuda')
    complete_net = completeNet()
    complete_net = DataParallel(complete_net)
    complete_net = complete_net.to(device)

    # train parameters
    weights = [10, 1]
    optimizer = torch.optim.Adam(complete_net.parameters(), lr=0.001, weight_decay=0.001)

    # resume training
    initial_epoch = 1
    if load_checkpoint is not None:
        checkpoint = torch.load(load_checkpoint)
        complete_net.load_state_dict(checkpoint['model_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initial_epoch = checkpoint['epoch'] + 1
        loss = checkpoint['loss']

    complete_net.train()
    for epoch in range(initial_epoch, epochs + 1):
        epoch_total = 0
        epoch_total_ones = 0
        epoch_total_zeros = 0
        epoch_correct = 0
        epoch_correct_ones = 0
        epoch_correct_zeros = 0
        running_loss = 0
        batches_num = 0
        for batch in trainloader:
            batch_total = 0
            batch_total_ones = 0
            batch_total_zeros = 0
            batch_correct = 0
            batch_correct_ones = 0
            batch_correct_zeros = 0
            batches_num += 1

            # Forward-Backpropagation
            output, output2, ground_truth, ground_truth2, det_num, tracklet_num = complete_net(batch)
            optimizer.zero_grad()
            loss = weighted_binary_cross_entropy(output, ground_truth, weights)
            loss.backward()
            optimizer.step()

            # Accuracy
            if epoch % acc_epoch == 0 and epoch != 0:
                # Hungarian method, clean up
                cleaned_output = hungarian(output2, ground_truth2, det_num, tracklet_num)
                batch_total += cleaned_output.size(0)
                ones = torch.tensor([1 for x in cleaned_output]).to(device)
                zeros = torch.tensor([0 for x in cleaned_output]).to(device)
                batch_total_ones += (cleaned_output == ones).sum().item()
                batch_total_zeros += (cleaned_output == zeros).sum().item()
                batch_correct += (cleaned_output == ground_truth2).sum().item()
                temp1 = (cleaned_output == ground_truth2)
                temp2 = (cleaned_output == ones)
                batch_correct_ones += (temp1 & temp2).sum().item()
                temp3 = (cleaned_output == zeros)
                batch_correct_zeros += (temp1 & temp3).sum().item()
                epoch_total += batch_total
                epoch_total_ones += batch_total_ones
                epoch_total_zeros += batch_total_zeros
                epoch_correct += batch_correct
                epoch_correct_ones += batch_correct_ones
                epoch_correct_zeros += batch_correct_zeros

            if loss.item() != loss.item():  # NaN check
                print("Error")
                break

            if batch_total_ones != 0 and batch_total_zeros != 0 and epoch % acc_epoch == 0 and epoch != 0:
                msg = ('Epoch: [%d] | Batch: [%d] | Training_Loss: %.3f | Total_Accuracy: %.3f | '
                       'Ones_Accuracy: %.3f | Zeros_Accuracy: %.3f |'
                       % (epoch, batches_num, loss.item(),
                          100 * batch_correct / batch_total,
                          100 * batch_correct_ones / batch_total_ones,
                          100 * batch_correct_zeros / batch_total_zeros))
            else:
                msg = ('Epoch: [%d] | Batch: [%d] | Training_Loss: %.3f |'
                       % (epoch, batches_num, loss.item()))
            print(msg)
            logging.info(msg)

            running_loss += loss.item()
            if loss.item() != loss.item():  # NaN check
                print("Error")
                break

        if epoch_total_ones != 0 and epoch_total_zeros != 0 and epoch % acc_epoch == 0 and epoch != 0:
            msg = ('Epoch: [%d] | Training_Loss: %.3f | Total_Accuracy: %.3f | '
                   'Ones_Accuracy: %.3f | Zeros_Accuracy: %.3f |'
                   % (epoch, running_loss / batches_num,
                      100 * epoch_correct / epoch_total,
                      100 * epoch_correct_ones / epoch_total_ones,
                      100 * epoch_correct_zeros / epoch_total_zeros))
        else:
            msg = 'Epoch: [%d] | Training_Loss: %.3f |' % (epoch, running_loss / batches_num)
        print(msg)
        logging.info(msg)

        # save model
        if epoch % save_model_epochs == 0 and epoch != 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': complete_net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': running_loss,
            }, './models/epoch_' + str(epoch) + '.pth')

        # validation
        if epoch % validation_epoch == 0 and epoch != 0:
            with torch.no_grad():
                epoch_total = 0
                epoch_total_ones = 0
                epoch_total_zeros = 0
                epoch_correct = 0
                epoch_correct_ones = 0
                epoch_correct_zeros = 0
                running_loss = 0
                batches_num = 0
                for batch in testloader:
                    batch_total = 0
                    batch_total_ones = 0
                    batch_total_zeros = 0
                    batch_correct = 0
                    batch_correct_ones = 0
                    batch_correct_zeros = 0
                    batches_num += 1
                    output, output2, ground_truth, ground_truth2, det_num, tracklet_num = complete_net(batch)
                    loss = weighted_binary_cross_entropy(output, ground_truth, weights)
                    running_loss += loss.item()

                    # Accuracy
                    if epoch % acc_epoch2 == 0 and epoch != 0:
                        # Hungarian method, clean up
                        cleaned_output = hungarian(output2, ground_truth2, det_num, tracklet_num)
                        batch_total += cleaned_output.size(0)
                        ones = torch.tensor([1 for x in cleaned_output]).to(device)
                        zeros = torch.tensor([0 for x in cleaned_output]).to(device)
                        batch_total_ones += (cleaned_output == ones).sum().item()
                        batch_total_zeros += (cleaned_output == zeros).sum().item()
                        batch_correct += (cleaned_output == ground_truth2).sum().item()
                        temp1 = (cleaned_output == ground_truth2)
                        temp2 = (cleaned_output == ones)
                        batch_correct_ones += (temp1 & temp2).sum().item()
                        temp3 = (cleaned_output == zeros)
                        batch_correct_zeros += (temp1 & temp3).sum().item()
                        epoch_total += batch_total
                        epoch_total_ones += batch_total_ones
                        epoch_total_zeros += batch_total_zeros
                        epoch_correct += batch_correct
                        epoch_correct_ones += batch_correct_ones
                        epoch_correct_zeros += batch_correct_zeros

                if epoch_total_ones != 0 and epoch_total_zeros != 0 and epoch % acc_epoch2 == 0 and epoch != 0:
                    msg = ('Epoch: [%d] | Validation_Loss: %.3f | Total_Accuracy: %.3f | '
                           'Ones_Accuracy: %.3f | Zeros_Accuracy: %.3f |'
                           % (epoch, running_loss / batches_num,
                              100 * epoch_correct / epoch_total,
                              100 * epoch_correct_ones / epoch_total_ones,
                              100 * epoch_correct_zeros / epoch_total_zeros))
                else:
                    msg = 'Epoch: [%d] | Validation_Loss: %.3f |' % (epoch, running_loss / batches_num)
                print(msg)
                logging.info(msg)
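# model_training() above relies on weighted_binary_cross_entropy(output, ground_truth,
# weights) and hungarian(...), which are defined elsewhere in this project. The sketch
# below shows one common way such a class-weighted BCE can be written (weights =
# [w_positive, w_negative]); it is an assumption, not necessarily this project's own
# implementation.
def _weighted_bce_sketch(output, target, weights, eps=1e-7):
    """Class-weighted binary cross-entropy; `output` holds probabilities in (0, 1)."""
    output = output.clamp(eps, 1 - eps)
    loss = weights[0] * target * torch.log(output) + \
           weights[1] * (1 - target) * torch.log(1 - output)
    return torch.neg(torch.mean(loss))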
def train_cross_validation(model_cls, dataset, dropout=0.0, lr=1e-3,
                           weight_decay=1e-2, num_epochs=200, n_splits=10,
                           use_gpu=True, dp=False, ddp=False, comment='',
                           tb_service_loc='192.168.192.57:6007', batch_size=1,
                           num_workers=0, pin_memory=False, cuda_device=None,
                           tb_dir='runs', model_save_dir='saved_models',
                           res_save_dir='res', fold_no=None, saved_model_path=None,
                           device_ids=None, patience=20, seed=None, fold_seed=None,
                           save_model=False, is_reg=True, live_loss=True,
                           domain_cls=True, final_cls=True):
    """
    :type fold_seed: int
    :param live_loss: bool
    :param is_reg: bool
    :param save_model: bool
    :param seed:
    :param patience: for early stopping
    :param device_ids: for ddp
    :param saved_model_path:
    :param fold_no: int
    :param ddp: DDP
    :param cuda_device: list of int
    :param pin_memory: bool, DataLoader args
    :param num_workers: int, DataLoader args
    :param model_cls: pytorch Module cls
    :param dataset: instance
    :param dropout: float
    :param lr: float
    :param weight_decay:
    :param num_epochs:
    :param n_splits: number of kFolds
    :param use_gpu: bool
    :param dp: bool
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location
    :param batch_size: Dataset args not DataLoader
    :return:
    """
    saved_args = locals()
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

    model_name = model_cls.__name__

    if not cuda_device:
        if device_ids and dp:
            device = device_ids[0]
        else:
            device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
    else:
        device = cuda_device

    device_count = torch.cuda.device_count() if dp else 1
    device_count = len(device_ids) if (device_ids is not None and dp) else device_count
    batch_size = batch_size * device_count

    # TensorBoard
    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".format(
            log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    # model
    criterion = nn.NLLLoss()

    print("Training {0} {1} models for cross validation...".format(n_splits, model_name))
    # 1
    # folds, fold = KFold(n_splits=n_splits, shuffle=False, random_state=seed), 0
    # 2
    # folds = GroupKFold(n_splits=n_splits)
    # iter = folds.split(np.zeros(len(dataset)), groups=dataset.data.site_id)
    # 4
    # folds = StratifiedKFold(n_splits=n_splits, random_state=fold_seed,
    #                         shuffle=True if fold_seed else False)
    # iter = folds.split(np.zeros(len(dataset)), dataset.data.y.numpy(),
    #                    groups=dataset.data.subject_id)
    # 5
    fold = 0
    cv_iter = multi_site_cv_split(dataset.data.y, dataset.data.site_id,
                                  dataset.data.subject_id, n_splits,
                                  random_state=fold_seed,
                                  shuffle=True if fold_seed else False)

    for train_idx, val_idx in tqdm_notebook(cv_iter, desc='CV', leave=False):
        fold += 1
        liveloss = PlotLosses() if live_loss else None

        # for a specific fold
        if fold_no is not None:
            if fold != fold_no:
                continue

        writer = SummaryWriter(log_dir=osp.join('runs', log_dir_base + str(fold)))
        model_save_dir = osp.join('saved_models', log_dir_base + str(fold))

        print("creating dataloader for fold {}".format(fold))
        train_dataset, val_dataset = norm_train_val(dataset, train_idx, val_idx)

        model = model_cls(writer)

        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers, pin_memory=pin_memory)
        val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size,
                                    collate_fn=lambda data_list: data_list,
                                    num_workers=num_workers, pin_memory=pin_memory)

        if fold == 1 or fold_no is not None:
            print(model)

        writer.add_text('model_summary', model.__repr__())
        writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999),
                                      eps=1e-08, weight_decay=weight_decay, amsgrad=False)
        # scheduler_reduce = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=5)
        # scheduler = scheduler_reduce
        # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

        if dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        if saved_model_path is not None:
            model.load_state_dict(torch.load(saved_model_path))

        best_map, patience_counter, best_score = 0.0, 0, np.inf
        for epoch in tqdm_notebook(range(1, num_epochs + 1), desc='Epoch', leave=False):
            logs = {}
            # scheduler.step(epoch=epoch, metrics=best_score)

            for phase in ['train', 'validation']:
                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = val_dataloader

                # Logging
                running_total_loss = 0.0
                running_corrects = 0
                running_reg_loss = 0.0
                running_nll_loss = 0.0
                epoch_yhat_0, epoch_yhat_1 = torch.tensor([]), torch.tensor([])
                epoch_label, epoch_predicted = torch.tensor([]), torch.tensor([])
                logging_hist = True if phase == 'train' else False  # once per epoch

                for data_list in tqdm_notebook(dataloader, desc=phase, leave=False):
                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids is not None else 'cuda'))

                    y_hat, domain_yhat, reg = model(data_list)

                    y = torch.tensor([], dtype=dataset.data.y.dtype, device=device)
                    domain_y = torch.tensor([], dtype=dataset.data.site_id.dtype, device=device)
                    for data in data_list:
                        y = torch.cat([y, data.y.view(-1).to(device)])
                        domain_y = torch.cat([domain_y, data.site_id.view(-1).to(device)])

                    loss = criterion(y_hat, y)
                    domain_loss = criterion(domain_yhat, domain_y)
                    # domain_loss = -1e-7 * domain_loss
                    # print(domain_loss.item())
                    if domain_cls:
                        total_loss = domain_loss
                        _, predicted = torch.max(domain_yhat, 1)
                        label = domain_y
                    if final_cls:
                        total_loss = loss
                        _, predicted = torch.max(y_hat, 1)
                        label = y
                    if domain_cls and final_cls:
                        total_loss = (loss + domain_loss).sum()
                        _, predicted = torch.max(y_hat, 1)
                        label = y
                    if is_reg:
                        total_loss += reg.sum()

                    if phase == 'train':
                        # print(torch.autograd.grad(y_hat.sum(), model.saved_x, retain_graph=True))
                        optimizer.zero_grad()
                        total_loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    running_nll_loss += loss.item()
                    running_total_loss += total_loss.item()
                    running_reg_loss += reg.sum().item()
                    running_corrects += (predicted == label).sum().item()

                    epoch_yhat_0 = torch.cat([epoch_yhat_0, y_hat[:, 0].detach().view(-1).cpu()])
                    epoch_yhat_1 = torch.cat([epoch_yhat_1, y_hat[:, 1].detach().view(-1).cpu()])
                    epoch_label = torch.cat([epoch_label, label.detach().float().view(-1).cpu()])
                    epoch_predicted = torch.cat([epoch_predicted,
                                                 predicted.detach().float().view(-1).cpu()])

                # precision = sklearn.metrics.precision_score(epoch_label, epoch_predicted, average='micro')
                # recall = sklearn.metrics.recall_score(epoch_label, epoch_predicted, average='micro')
                # f1_score = sklearn.metrics.f1_score(epoch_label, epoch_predicted, average='micro')
                accuracy = sklearn.metrics.accuracy_score(epoch_label, epoch_predicted)

                epoch_total_loss = running_total_loss / len(dataloader)
                epoch_nll_loss = running_nll_loss / len(dataloader)
                epoch_reg_loss = running_reg_loss / len(dataloader)

                # print('epoch {} {}_nll_loss: {}'.format(epoch, phase, epoch_nll_loss))
                writer.add_scalars('nll_loss', {'{}_nll_loss'.format(phase): epoch_nll_loss}, epoch)
                writer.add_scalars('accuracy', {'{}_accuracy'.format(phase): accuracy}, epoch)
                # writer.add_scalars('{}_APRF'.format(phase),
                #                    {'accuracy': accuracy, 'precision': precision,
                #                     'recall': recall, 'f1_score': f1_score},
                #                    epoch)
                if epoch_reg_loss != 0:
                    writer.add_scalars('reg_loss', {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)
                    # print(epoch_reg_loss)
                # writer.add_histogram('hist/{}_yhat_0'.format(phase), epoch_yhat_0, epoch)
                # writer.add_histogram('hist/{}_yhat_1'.format(phase), epoch_yhat_1, epoch)

                # Save Model & Early Stopping
                if phase == 'validation':
                    model_save_path = model_save_dir + '-{}-{}-{:.3f}-{:.3f}'.format(
                        model_name, epoch, accuracy, epoch_nll_loss)
                    # best score
                    if accuracy > best_map:
                        best_map = accuracy
                        model_save_path = model_save_path + '-best'

                    score = epoch_nll_loss
                    if score < best_score:
                        patience_counter = 0
                        best_score = score
                    else:
                        patience_counter += 1

                    # skip first 10 epoch
                    # best_score = best_score if epoch > 10 else -np.inf

                    if save_model:
                        for th, pfix in zip([0.8, 0.75, 0.7, 0.5, 0.0],
                                            ['-perfect', '-great', '-good', '-bad', '-miss']):
                            if accuracy >= th:
                                model_save_path += pfix
                                break
                        torch.save(model.state_dict(), model_save_path)

                    writer.add_scalars('best_val_accuracy',
                                       {'{}_accuracy'.format(phase): best_map}, epoch)
                    writer.add_scalars('best_nll_loss',
                                       {'{}_nll_loss'.format(phase): best_score}, epoch)
                    writer.add_scalars('learning_rate',
                                       {'learning_rate': scheduler.optimizer.param_groups[0]['lr']},
                                       epoch)

                    if patience_counter >= patience:
                        print("Stopped at epoch {}".format(epoch))
                        return

                if live_loss:
                    prefix = ''
                    if phase == 'validation':
                        prefix = 'val_'
                    logs[prefix + 'log loss'] = epoch_nll_loss
                    logs[prefix + 'accuracy'] = accuracy

            if live_loss:
                liveloss.update(logs)
                liveloss.draw()

    print("Done !")
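# Illustrative call for the cross-validation loop above; `MyBrainNet` and `my_dataset`
# are hypothetical placeholders for a model class taking a SummaryWriter and a dataset
# exposing .data.y / .data.site_id / .data.subject_id as used in the function. All
# hyper-parameter values here are arbitrary examples.
def _example_cross_validation(MyBrainNet, my_dataset):
    train_cross_validation(MyBrainNet, my_dataset, comment='baseline',
                           num_epochs=100, n_splits=10, batch_size=8,
                           use_gpu=True, dp=False, fold_no=1,
                           live_loss=False, domain_cls=True, final_cls=True)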
def main(args):
    batch_size = args.batch_size
    model_fname = args.mod_name
    if multi_gpu and batch_size < torch.cuda.device_count():
        exit('Batch size too small')

    # make a folder for the graphs of this model
    Path(args.output_dir).mkdir(exist_ok=True)
    save_dir = osp.join(args.output_dir, model_fname)
    Path(save_dir).mkdir(exist_ok=True)

    # get dataset and split
    gdata = GraphDataset(root=args.input_dir, bb=args.box_num)
    # merge data from separate files into one contiguous array
    bag = []
    for g in gdata:
        bag += g
    random.Random(0).shuffle(bag)
    bag = bag[:args.num_data]
    # temporary patch to use px, py, pz
    for d in bag:
        d.x = d.x[:, :3]

    # 80:10:10 split datasets
    fulllen = len(bag)
    train_len = int(0.8 * fulllen)
    tv_len = int(0.10 * fulllen)
    train_dataset = bag[:train_len]
    valid_dataset = bag[train_len:train_len + tv_len]
    test_dataset = bag[train_len + tv_len:]
    train_samples = len(train_dataset)
    valid_samples = len(valid_dataset)
    test_samples = len(test_dataset)

    if multi_gpu:
        train_loader = DataListLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
        valid_loader = DataListLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
        test_loader = DataListLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)

    # specify loss function
    loss_ftn_obj = LossFunction(args.loss, emd_modname=args.emd_model_name, device=device)

    # create model
    input_dim = 3
    big_dim = 32
    hidden_dim = args.lat_dim
    lr = args.lr
    patience = args.patience
    if args.model == 'MetaLayerGAE':
        model = models.GNNAutoEncoder()
    else:
        if args.model[-3:] == 'EMD':
            model = getattr(models, args.model)(input_dim=input_dim, big_dim=big_dim,
                                                hidden_dim=hidden_dim,
                                                emd_modname=args.emd_model_name)
        else:
            model = getattr(models, args.model)(input_dim=input_dim, big_dim=big_dim,
                                                hidden_dim=hidden_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=4)

    valid_losses = []
    train_losses = []
    start_epoch = 0
    n_epochs = 200

    # load in model
    modpath = osp.join(save_dir, model_fname + '.best.pth')
    try:
        model.load_state_dict(torch.load(modpath))
        train_losses, valid_losses, start_epoch = torch.load(osp.join(save_dir, 'losses.pt'))
        print('Loaded model')
        best_valid_loss = test(model, valid_loader, valid_samples, batch_size, loss_ftn_obj)
        print(f'Saved model valid loss: {best_valid_loss}')
    except Exception:
        print('Creating new model')
        best_valid_loss = 9999999
    if multi_gpu:
        model = DataParallel(model)
    model.to(torch.device(device))

    # Training loop
    stale_epochs = 0
    loss = best_valid_loss
    for epoch in range(start_epoch, n_epochs):
        if multi_gpu:
            loss = train_parallel(model, optimizer, train_loader, train_samples, batch_size, loss_ftn_obj)
            valid_loss = test_parallel(model, valid_loader, valid_samples, batch_size, loss_ftn_obj)
        else:
            loss = train(model, optimizer, train_loader, train_samples, batch_size, loss_ftn_obj)
            valid_loss = test(model, valid_loader, valid_samples, batch_size, loss_ftn_obj)
        scheduler.step(valid_loss)
        train_losses.append(loss)
        valid_losses.append(valid_loss)
        print('Epoch: {:02d}, Training Loss:   {:.4f}'.format(epoch, loss))
        print('           Validation Loss: {:.4f}'.format(valid_loss))

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            print('New best model saved to:', modpath)
            if multi_gpu:
                torch.save(model.module.state_dict(), modpath)
            else:
                torch.save(model.state_dict(), modpath)
            torch.save((train_losses, valid_losses, epoch + 1), osp.join(save_dir, 'losses.pt'))
            stale_epochs = 0
        else:
            stale_epochs += 1
            print(f'Stale epoch: {stale_epochs}\nBest: {best_valid_loss}\nCurr: {valid_loss}')
        if stale_epochs >= patience:
            print('Early stopping after %i stale epochs' % patience)
            break

    # model training done
    train_epochs = list(range(epoch + 1))
    early_stop_epoch = epoch - stale_epochs
    loss_curves(train_epochs, early_stop_epoch, train_losses, valid_losses, save_dir)

    # compare input and reconstructions
    model.load_state_dict(torch.load(modpath))
    input_fts = []
    reco_fts = []
    for t in valid_loader:
        model.eval()
        if isinstance(t, list):
            for d in t:
                input_fts.append(d.x)
        else:
            input_fts.append(t.x)
            t.to(device)
        reco_out = model(t)
        if isinstance(reco_out, tuple):
            reco_out = reco_out[0]
        reco_fts.append(reco_out.cpu().detach())
    input_fts = torch.cat(input_fts)
    reco_fts = torch.cat(reco_fts)
    plot_reco_difference(input_fts, reco_fts, model_fname,
                         osp.join(save_dir, 'reconstruction_post_train', 'valid'))

    input_fts = []
    reco_fts = []
    for t in test_loader:
        model.eval()
        if isinstance(t, list):
            for d in t:
                input_fts.append(d.x)
        else:
            input_fts.append(t.x)
            t.to(device)
        reco_out = model(t)
        if isinstance(reco_out, tuple):
            reco_out = reco_out[0]
        reco_fts.append(reco_out.cpu().detach())
    input_fts = torch.cat(input_fts)
    reco_fts = torch.cat(reco_fts)
    plot_reco_difference(input_fts, reco_fts, model_fname,
                         osp.join(save_dir, 'reconstruction_post_train', 'test'))
    print('Completed')
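# A minimal way to drive main() without the project's CLI parser: build an
# argparse.Namespace carrying the attributes main() reads. Every value below is a
# placeholder (including the model name), and main() also expects the module-level
# globals `device` and `multi_gpu` to be defined.
def _example_main_call():
    from argparse import Namespace
    args = Namespace(batch_size=4, mod_name='gae_test', output_dir='outputs',
                     input_dir='data', box_num=0, num_data=1000,
                     loss='mse', emd_model_name=None, lat_dim=2,
                     lr=1e-3, patience=10, model='EdgeNet')
    main(args)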
args = parser.parse_args()

# Load dataset
dataset = PygNodePropPredDataset(name=args.dataset_name)
data = dataset[0]
dataset_test(data)

if args.multi_gpu:
    # Unit test: GPU number verification
    # Prepare model
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = parse_model_name(args.model, dataset)
    model = DataParallel(model)
    model = model.to(device)

# Split graph into subgraphs
if args.subgraph_scheme == 'cluster':
    # Split data into subgraphs using cluster methods
    data_list = list(ClusterData(data, num_parts=args.num_parts))
elif args.subgraph_scheme == 'neighbor':
    data_list = list(NeighborSubgraphLoader(data, batch_size=args.neighbor_batch_size))
    print(f'Using neighbor sampling | number of subgraphs: {len(data_list)}')

# Run the model for each batch size setup
batch_sizes = np.array(list(range(1, 65))) * 4
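# The snippet above stops after computing `batch_sizes`; a sweep over those values
# might look like the sketch below. Assumptions: a forward-only pass over the
# subgraphs with DataListLoader, and a hypothetical `timed_inference` hook that is
# not part of the snippet.
def _example_batch_size_sweep(model, data_list, batch_sizes, timed_inference):
    for bs in batch_sizes:
        loader = DataListLoader(data_list, batch_size=int(bs), shuffle=False)
        for batch in loader:
            timed_inference(model, batch)  # hypothetical measurement hook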
def train():
    # set the input channel dims based on featurization type
    if args.feature_type == "pybel":
        feature_size = 20
    else:
        feature_size = 75

    print("found {} datasets in input train-data".format(len(args.train_data)))

    train_dataset_list = []
    val_dataset_list = []
    for data in args.train_data:
        train_dataset_list.append(
            PDBBindDataset(
                data_file=data,
                dataset_name=args.dataset_name,
                feature_type=args.feature_type,
                preprocessing_type=args.preprocessing_type,
                output_info=True,
                use_docking=args.use_docking,
            ))
    for data in args.val_data:
        val_dataset_list.append(
            PDBBindDataset(
                data_file=data,
                dataset_name=args.dataset_name,
                feature_type=args.feature_type,
                preprocessing_type=args.preprocessing_type,
                output_info=True,
                use_docking=args.use_docking,
            ))

    train_dataset = ConcatDataset(train_dataset_list)
    val_dataset = ConcatDataset(val_dataset_list)

    # drop_last keeps batch sizes even across the data-parallel replicas
    train_dataloader = DataListLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )
    val_dataloader = DataListLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )

    tqdm.write("{} complexes in training dataset".format(len(train_dataset)))
    tqdm.write("{} complexes in validation dataset".format(len(val_dataset)))

    model = GeometricDataParallel(
        PotentialNetParallel(
            in_channels=feature_size,
            out_channels=1,
            covalent_gather_width=args.covalent_gather_width,
            non_covalent_gather_width=args.non_covalent_gather_width,
            covalent_k=args.covalent_k,
            non_covalent_k=args.non_covalent_k,
            covalent_neighbor_threshold=args.covalent_threshold,
            non_covalent_neighbor_threshold=args.non_covalent_threshold,
        )).float()
    model.train()
    model.to(0)
    tqdm.write(str(model))
    tqdm.write("{} trainable parameters.".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    tqdm.write("{} total parameters.".format(
        sum(p.numel() for p in model.parameters())))

    criterion = nn.MSELoss().float()
    optimizer = Adam(model.parameters(), lr=args.lr)

    best_checkpoint_dict = None
    best_checkpoint_epoch = 0
    best_checkpoint_step = 0
    best_checkpoint_r2 = -9e9
    step = 0
    for epoch in range(args.epochs):
        losses = []
        for batch in tqdm(train_dataloader):
            batch = [x for x in batch if x is not None]
            if len(batch) < 1:
                print("empty batch, skipping to next batch")
                continue
            optimizer.zero_grad()
            data = [x[2] for x in batch]
            y_ = model(data)
            y = torch.cat([x[2].y for x in batch])
            loss = criterion(y.float(), y_.cpu().float())
            losses.append(loss.cpu().data.item())
            loss.backward()

            y_true = y.cpu().data.numpy()
            y_pred = y_.cpu().data.numpy()
            r2 = r2_score(y_true=y_true, y_pred=y_pred)
            mae = mean_absolute_error(y_true=y_true, y_pred=y_pred)
            pearsonr = stats.pearsonr(y_true.reshape(-1), y_pred.reshape(-1))
            spearmanr = stats.spearmanr(y_true.reshape(-1), y_pred.reshape(-1))
            tqdm.write(
                "epoch: {}\tloss:{:0.4f}\tr2: {:0.4f}\t pearsonr: {:0.4f}\tspearmanr: {:0.4f}\tmae: {:0.4f}\tpred stdev: {:0.4f}"
                "\t pred mean: {:0.4f} \tcovalent_threshold: {:0.4f} \tnon covalent threshold: {:0.4f}".format(
                    epoch,
                    loss.cpu().data.numpy(),
                    r2,
                    float(pearsonr[0]),
                    float(spearmanr[0]),
                    float(mae),
                    np.std(y_pred),
                    np.mean(y_pred),
                    model.module.covalent_neighbor_threshold.t.cpu().data.item(),
                    model.module.non_covalent_neighbor_threshold.t.cpu().data.item(),
                ))

            if args.checkpoint:
                if step % args.checkpoint_iter == 0:
                    checkpoint_dict = checkpoint_model(
                        model, val_dataloader, epoch, step,
                        args.checkpoint_dir + "/model-epoch-{}-step-{}.pth".format(epoch, step),
                    )
                    if checkpoint_dict["validate_dict"]["r2"] > best_checkpoint_r2:
                        best_checkpoint_step = step
                        best_checkpoint_epoch = epoch
                        best_checkpoint_r2 = checkpoint_dict["validate_dict"]["r2"]
                        best_checkpoint_dict = checkpoint_dict

            optimizer.step()
            step += 1

        if args.checkpoint:
            checkpoint_dict = checkpoint_model(
                model, val_dataloader, epoch, step,
                args.checkpoint_dir + "/model-epoch-{}-step-{}.pth".format(epoch, step),
            )
            if checkpoint_dict["validate_dict"]["r2"] > best_checkpoint_r2:
                best_checkpoint_step = step
                best_checkpoint_epoch = epoch
                best_checkpoint_r2 = checkpoint_dict["validate_dict"]["r2"]
                best_checkpoint_dict = checkpoint_dict

    if args.checkpoint:
        # once out of the training loop, save the last model
        checkpoint_dict = checkpoint_model(
            model, val_dataloader, epoch, step,
            args.checkpoint_dir + "/model-epoch-{}-step-{}.pth".format(epoch, step),
        )
        if checkpoint_dict["validate_dict"]["r2"] > best_checkpoint_r2:
            best_checkpoint_step = step
            best_checkpoint_epoch = epoch
            best_checkpoint_r2 = checkpoint_dict["validate_dict"]["r2"]
            best_checkpoint_dict = checkpoint_dict

    if args.checkpoint:
        torch.save(best_checkpoint_dict, args.checkpoint_dir + "/best_checkpoint.pth")

    print("best training checkpoint epoch {}/step {} with r2: {}".format(
        best_checkpoint_epoch, best_checkpoint_step, best_checkpoint_r2))
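# train() above reads many fields from a module-level `args`. The sketch below lists
# the attributes it touches so the function can be exercised outside the original CLI;
# all paths and hyper-parameter values are illustrative placeholders, not defaults
# taken from the project.
def _example_potentialnet_args():
    from argparse import Namespace
    return Namespace(
        feature_type='pybel', train_data=['train.hdf'], val_data=['val.hdf'],
        dataset_name='pdbbind', preprocessing_type='raw', use_docking=False,
        batch_size=8, lr=1e-3, epochs=100,
        covalent_gather_width=128, non_covalent_gather_width=128,
        covalent_k=1, non_covalent_k=1,
        covalent_threshold=1.5, non_covalent_threshold=4.5,
        checkpoint=True, checkpoint_iter=100, checkpoint_dir='./checkpoints')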
def train_cross_validation(model_cls, dataset, num_clusters, dropout=0.0, lr=1e-4,
                           weight_decay=1e-2, num_epochs=200, n_splits=10,
                           use_gpu=True, dp=False, ddp=True, comment='',
                           tb_service_loc='192.168.192.57:6006', batch_size=1,
                           num_workers=0, pin_memory=False, cuda_device=None,
                           fold_no=None, saved_model_path=None, device_ids=None,
                           patience=50, seed=None, save_model=True, c_reg=0,
                           base_log_dir='runs', base_model_save_dir='saved_models'):
    """
    :param c_reg:
    :param save_model: bool
    :param seed:
    :param patience: for early stopping
    :param device_ids: for ddp
    :param saved_model_path:
    :param fold_no:
    :param ddp: DDP
    :param cuda_device:
    :param pin_memory: DataLoader args
        https://devblogs.nvidia.com/how-optimize-data-transfers-cuda-cc/
    :param num_workers: DataLoader args
    :param model_cls: pytorch Module cls
    :param dataset: pytorch Dataset cls
    :param dropout:
    :param lr:
    :param weight_decay:
    :param num_epochs:
    :param n_splits: number of kFolds
    :param use_gpu: bool
    :param dp: bool
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location
    :param batch_size: Dataset args not DataLoader
    :return:
    """
    saved_args = locals()
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    if ddp and not torch.distributed.is_initialized():
        # initialize ddp
        dist.init_process_group('nccl',
                                init_method='tcp://localhost:{}'.format(find_open_port()),
                                world_size=1, rank=0)

    model_name = model_cls.__name__

    if not cuda_device:
        if device_ids and (ddp or dp):
            device = device_ids[0]
        else:
            device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
    else:
        device = cuda_device

    device_count = torch.cuda.device_count() if dp else 1
    device_count = len(device_ids) if (device_ids is not None and (dp or ddp)) else device_count
    if device_count > 1:
        print("Let's use", device_count, "GPUs!")
    # batch_size = batch_size * device_count

    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".format(
            log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    criterion = nn.CrossEntropyLoss()

    # get test set
    folds = StratifiedKFold(n_splits=n_splits, shuffle=False)
    train_val_idx, test_idx = list(folds.split(np.zeros(len(dataset)),
                                               dataset.data.y.numpy()))[0]
    test_dataset = dataset.__indexing__(test_idx)
    train_val_dataset = dataset.__indexing__(train_val_idx)

    print("Training {0} {1} models for cross validation...".format(n_splits, model_name))
    # folds, fold = KFold(n_splits=n_splits, shuffle=False, random_state=seed), 0
    folds = StratifiedKFold(n_splits=n_splits, shuffle=False)
    cv_iter = folds.split(np.zeros(len(train_val_dataset)), train_val_dataset.data.y.numpy())
    fold = 0

    for train_idx, val_idx in tqdm_notebook(cv_iter, desc='CV', leave=False):
        fold += 1
        if fold_no is not None:
            if fold != fold_no:
                continue

        writer = SummaryWriter(log_dir=osp.join(base_log_dir, log_dir_base + str(fold)))
        model_save_dir = osp.join(base_model_save_dir, log_dir_base + str(fold))

        print("creating dataloader for fold {}".format(fold))

        model = model_cls(writer, num_clusters=num_clusters,
                          in_dim=dataset.data.x.shape[1],
                          out_dim=int(dataset.data.y.max() + 1),
                          dropout=dropout)

        # My Batch
        train_dataset = train_val_dataset.__indexing__(train_idx)
        val_dataset = train_val_dataset.__indexing__(val_idx)
        train_dataset = dataset_gather(train_dataset, seed=0, n_repeat=1,
                                       n_splits=int(len(train_dataset) / batch_size) + 1)
        val_dataset = dataset_gather(val_dataset, seed=0, n_repeat=1,
                                     n_splits=int(len(val_dataset) / batch_size) + 1)

        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=device_count,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers, pin_memory=pin_memory)
        val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=device_count,
                                    collate_fn=lambda data_list: data_list,
                                    num_workers=num_workers, pin_memory=pin_memory)

        # if fold == 1 or fold_no is not None:
        print(model)
        writer.add_text('model_summary', model.__repr__())
        writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999),
                                     eps=1e-08, weight_decay=weight_decay, amsgrad=False)
        # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

        if ddp:
            model = model.cuda() if device_ids is None else model.to(device_ids[0])
            model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
        elif dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        if saved_model_path is not None:
            model.load_state_dict(torch.load(saved_model_path))

        best_map, patience_counter, best_score = 0.0, 0, -np.inf
        for epoch in tqdm_notebook(range(1, num_epochs + 1), desc='Epoch', leave=False):
            for phase in ['train', 'validation']:
                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = val_dataloader

                # Logging
                running_total_loss = 0.0
                running_corrects = 0
                running_reg_loss = 0.0
                running_nll_loss = 0.0
                epoch_yhat_0, epoch_yhat_1 = torch.tensor([]), torch.tensor([])
                epoch_label, epoch_predicted = torch.tensor([]), torch.tensor([])

                for data_list in tqdm_notebook(dataloader, desc=phase, leave=False):
                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids is not None else 'cuda'))

                    y_hat, reg = model(data_list)
                    # y_hat = y_hat.reshape(batch_size, -1)
                    y = torch.tensor([], dtype=dataset.data.y.dtype, device=device)
                    for data in data_list:
                        y = torch.cat([y, data.y.view(-1).to(device)])

                    loss = criterion(y_hat, y)
                    reg_loss = -reg
                    total_loss = (loss + reg_loss * c_reg).sum()

                    if phase == 'train':
                        # print(torch.autograd.grad(y_hat.sum(), model.saved_x, retain_graph=True))
                        optimizer.zero_grad()
                        total_loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    _, predicted = torch.max(y_hat, 1)
                    label = y

                    running_nll_loss += loss.item()
                    running_total_loss += total_loss.item()
                    running_reg_loss += reg.sum().item()
                    running_corrects += (predicted == label).sum().item()

                    epoch_yhat_0 = torch.cat([epoch_yhat_0, y_hat[:, 0].detach().view(-1).cpu()])
                    epoch_yhat_1 = torch.cat([epoch_yhat_1, y_hat[:, 1].detach().view(-1).cpu()])
                    epoch_label = torch.cat([epoch_label, label.detach().cpu().float()])
                    epoch_predicted = torch.cat([epoch_predicted, predicted.detach().cpu().float()])

                # precision = sklearn.metrics.precision_score(epoch_label, epoch_predicted, average='micro')
                # recall = sklearn.metrics.recall_score(epoch_label, epoch_predicted, average='micro')
                # f1_score = sklearn.metrics.f1_score(epoch_label, epoch_predicted, average='micro')
                accuracy = sklearn.metrics.accuracy_score(epoch_label, epoch_predicted)

                epoch_total_loss = running_total_loss / len(dataloader)
                epoch_nll_loss = running_nll_loss / len(dataloader)
                epoch_reg_loss = running_reg_loss / len(dataloader.dataset)

                writer.add_scalars('nll_loss', {'{}_nll_loss'.format(phase): epoch_nll_loss}, epoch)
                writer.add_scalars('accuracy', {'{}_accuracy'.format(phase): accuracy}, epoch)
                # writer.add_scalars('{}_APRF'.format(phase),
                #                    {'accuracy': accuracy, 'precision': precision,
                #                     'recall': recall, 'f1_score': f1_score},
                #                    epoch)
                if epoch_reg_loss != 0:
                    writer.add_scalars('reg_loss', {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)
                # writer.add_histogram('hist/{}_yhat_0'.format(phase), epoch_yhat_0, epoch)
                # writer.add_histogram('hist/{}_yhat_1'.format(phase), epoch_yhat_1, epoch)

                # Save Model & Early Stopping
                if phase == 'validation':
                    model_save_path = model_save_dir + '-{}-{}-{:.3f}-{:.3f}'.format(
                        model_name, epoch, accuracy, epoch_nll_loss)
                    if accuracy > best_map:
                        best_map = accuracy
                        model_save_path = model_save_path + '-best'

                    score = -epoch_nll_loss
                    if score > best_score:
                        patience_counter = 0
                        best_score = score
                    else:
                        patience_counter += 1

                    # skip 10 epoch
                    # best_score = best_score if epoch > 10 else -np.inf

                    if save_model:
                        for th, pfix in zip([0.8, 0.75, 0.7, 0.5, 0.0],
                                            ['-perfect', '-great', '-good', '-bad', '-miss']):
                            if accuracy >= th:
                                model_save_path += pfix
                                break
                        if epoch > 10:
                            torch.save(model.state_dict(), model_save_path)

                    writer.add_scalars('best_val_accuracy',
                                       {'{}_accuracy'.format(phase): best_map}, epoch)
                    writer.add_scalars('best_nll_loss',
                                       {'{}_nll_loss'.format(phase): -best_score}, epoch)

                    if patience_counter >= patience:
                        print("Stopped at epoch {}".format(epoch))
                        return

    print("Done !")
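# Illustrative invocation of the clustering variant above; `PoolingGCN` and
# `graph_dataset` are hypothetical placeholders for a model class (accepting a writer,
# num_clusters, in_dim, out_dim, dropout) and a dataset exposing .data.x / .data.y as
# used in the function. Hyper-parameter values are arbitrary examples.
def _example_cluster_cv(PoolingGCN, graph_dataset):
    train_cross_validation(PoolingGCN, graph_dataset, num_clusters=8,
                           dropout=0.1, lr=1e-4, num_epochs=200, n_splits=10,
                           batch_size=16, ddp=False, dp=True, use_gpu=True,
                           comment='cluster_cv', c_reg=1e-3, fold_no=1)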
def train_cummunity_detection(model_cls, dataset, dropout=0.0, lr=1e-3,
                              weight_decay=1e-2, num_epochs=200, n_splits=10,
                              use_gpu=True, dp=False, ddp=False, comment='',
                              tb_service_loc='192.168.192.57:6006', batch_size=1,
                              num_workers=0, pin_memory=False, cuda_device=None,
                              ddp_port='23456', fold_no=None, device_ids=None,
                              patience=20, seed=None, save_model=False, supervised=False):
    """
    :param save_model: bool
    :param seed:
    :param patience: for early stopping
    :param device_ids: for ddp
    :param fold_no:
    :param ddp_port:
    :param ddp: DDP
    :param cuda_device:
    :param pin_memory: DataLoader args
        https://devblogs.nvidia.com/how-optimize-data-transfers-cuda-cc/
    :param num_workers: DataLoader args
    :param model_cls: pytorch Module cls
    :param dataset: pytorch Dataset cls
    :param dropout:
    :param lr:
    :param weight_decay:
    :param num_epochs:
    :param n_splits: number of kFolds
    :param use_gpu: bool
    :param dp: bool
    :param comment: comment in the logs, to filter runs in tensorboard
    :param tb_service_loc: tensorboard service location
    :param batch_size: Dataset args not DataLoader
    :return:
    """
    saved_args = locals()
    seed = int(time.time() % 1e4 * 1e5) if seed is None else seed
    saved_args['random_seed'] = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    if use_gpu:
        torch.cuda.manual_seed_all(seed)

    if ddp and not torch.distributed.is_initialized():
        # initialize ddp
        dist.init_process_group('nccl',
                                init_method='tcp://localhost:{}'.format(ddp_port),
                                world_size=1, rank=0)

    model_name = model_cls.__name__

    if not cuda_device:
        if device_ids and (ddp or dp):
            device = device_ids[0]
        else:
            device = torch.device('cuda' if torch.cuda.is_available() and use_gpu else 'cpu')
    else:
        device = cuda_device

    device_count = torch.cuda.device_count() if dp else 1
    device_count = len(device_ids) if (device_ids is not None and (dp or ddp)) else device_count
    if device_count > 1:
        print("Let's use", device_count, "GPUs!")
    # batch_size = batch_size * device_count

    log_dir_base = get_model_log_dir(comment, model_name)
    if tb_service_loc is not None:
        print("TensorBoard available at http://{1}/#scalars&regexInput={0}".format(
            log_dir_base, tb_service_loc))
    else:
        print("Please set up TensorBoard")

    print("Training {0} {1} models for cross validation...".format(n_splits, model_name))
    folds, fold = KFold(n_splits=n_splits, shuffle=False), 0
    print(dataset.__len__())

    for train_idx, test_idx in tqdm_notebook(folds.split(list(range(dataset.__len__())),
                                                         list(range(dataset.__len__()))),
                                             desc='models', leave=False):
        fold += 1
        if fold_no is not None:
            if fold != fold_no:
                continue

        writer = SummaryWriter(log_dir=osp.join('runs', log_dir_base + str(fold)))
        model_save_dir = osp.join('saved_models', log_dir_base + str(fold))

        print("creating dataloader for fold {}".format(fold))
        model = model_cls(writer, dropout=dropout)

        # My Batch
        train_dataset = dataset.__indexing__(train_idx)
        test_dataset = dataset.__indexing__(test_idx)
        train_dataset = dataset_gather(train_dataset, n_repeat=1,
                                       n_splits=int(len(train_dataset) / batch_size))

        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=device_count,
                                      collate_fn=lambda data_list: data_list,
                                      num_workers=num_workers, pin_memory=pin_memory)
        test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=device_count,
                                     collate_fn=lambda data_list: data_list,
                                     num_workers=num_workers, pin_memory=pin_memory)

        print(model)
        writer.add_text('model_summary', model.__repr__())
        writer.add_text('training_args', str(saved_args))

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999),
                                     eps=1e-08, weight_decay=weight_decay, amsgrad=False)
        # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

        if ddp:
            model = model.cuda() if device_ids is None else model.to(device_ids[0])
            model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
        elif dp and use_gpu:
            model = model.cuda() if device_ids is None else model.to(device_ids[0])
            model = DataParallel(model, device_ids=device_ids)
        elif use_gpu:
            model = model.to(device)

        for epoch in tqdm_notebook(range(1, num_epochs + 1), desc='Epoch', leave=False):
            for phase in ['train', 'validation']:
                if phase == 'train':
                    model.train()
                    dataloader = train_dataloader
                else:
                    model.eval()
                    dataloader = test_dataloader

                # Logging
                running_total_loss = 0.0
                running_reg_loss = 0.0
                running_overlap = 0.0

                for data_list in tqdm_notebook(dataloader, desc=phase, leave=False):
                    # TODO: check devices
                    if dp:
                        data_list = to_cuda(data_list,
                                            (device_ids[0] if device_ids is not None else 'cuda'))

                    y_hat, reg = model(data_list)
                    y = torch.tensor([], dtype=dataset.data.y.dtype, device=device)
                    for data in data_list:
                        y = torch.cat([y, data.y.view(-1).to(device)])

                    if supervised:
                        loss = permutation_invariant_loss(y_hat, y)
                        # criterion = nn.NLLLoss()
                        # loss = criterion(y_hat, y)
                    else:
                        loss = -reg
                    total_loss = loss

                    if phase == 'train':
                        # print(torch.autograd.grad(y_hat.sum(), model.saved_x, retain_graph=True))
                        optimizer.zero_grad()
                        total_loss.backward(retain_graph=True)
                        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
                        optimizer.step()

                    _, predicted = torch.max(y_hat, 1)
                    label = y
                    if supervised:
                        overlap_score = normalized_overlap(label.int().cpu().numpy(),
                                                           predicted.int().cpu().numpy(), 0.25)
                        # overlap_score = overlap(label.int().cpu().numpy(), predicted.int().cpu().numpy())
                        running_overlap += overlap_score
                        print(reg, overlap_score, loss)

                    running_total_loss += total_loss.item()
                    running_reg_loss += reg.sum().item()

                epoch_total_loss = running_total_loss / len(dataloader)
                epoch_reg_loss = running_reg_loss / len(dataloader.dataset)
                if supervised:
                    epoch_overlap = running_overlap / len(dataloader)
                    writer.add_scalars('overlap', {'{}_overlap'.format(phase): epoch_overlap}, epoch)
                writer.add_scalars('reg_loss', {'{}_reg_loss'.format(phase): epoch_reg_loss}, epoch)

    print("Done !")
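# Illustrative call for the community-detection trainer above; `CommunityNet` and
# `synthetic_graphs` are hypothetical placeholders for a model class (taking a writer
# and dropout) and a graph dataset. With supervised=False the loop maximises `reg`
# (a regularity/modularity-style score returned by the model) instead of the
# permutation-invariant loss.
def _example_community_detection(CommunityNet, synthetic_graphs):
    train_cummunity_detection(CommunityNet, synthetic_graphs, dropout=0.0,
                              lr=1e-3, num_epochs=100, n_splits=5,
                              batch_size=4, use_gpu=True, dp=True,
                              comment='community', supervised=False)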