def pdts(args: TrainArgs, model_idx):
    """Run one PDTS (approximate batched Bayesian optimisation) experiment.

    Preliminary set-up: a pool of 50k molecules, 1000 initial training
    points, batches of 50 acquired per iteration, run until 15k points have
    been trained on.

    Outline:
      1. load the data set and record the SMILES of its global top 1%
         (by the first target column) — used only to score progress;
      2. train a MAP model, optionally converting it to a BBP / GP(DKL) /
         SWAG / SGLD posterior;
      3. repeatedly score the remaining pool, move the top-scoring batch of
         50 molecules from the pool into the training set (Thompson sampling
         across posterior samples, or greedy mean prediction), retrain;
      4. after every batch log the fraction of the global top 1% acquired,
         and finally save these scores to ``args.results_dir``.

    :param args: A :class:`~chemprop.args.TrainArgs` object carrying the
                 training and PDTS settings (seeds, batch counts, posterior
                 switches, wandb names, ...). Mutated in place (seed,
                 num_tasks, train_data_size, ...).
    :param model_idx: Index of this run; selects the data/pytorch seeds and
                      names the save/checkpoint/result files.
    """
    ######## set up all logging ########
    logger = None
    # make save_dir
    save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
    makedirs(save_dir)
    # make results_dir
    results_dir = args.results_dir
    makedirs(results_dir)
    # initialise wandb (one run per model_idx)
    wandb.init(name=args.wandb_name + '_' + str(model_idx),
               project=args.wandb_proj,
               reinit=True)
    ####################################

    ########## get data
    args.task_names = args.target_columns or get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()

    ########## SMILES of top 1% (ground truth, used only to measure progress)
    top1p = np.array(MoleculeDataset(data).targets())
    top1p_idx = np.argsort(-top1p[:, 0])[:int(args.max_data_size * 0.01)]
    SMILES = np.array(MoleculeDataset(data).smiles())[top1p_idx]

    ########## initial data splits
    args.seed = args.data_seeds[model_idx]
    data.shuffle(seed=args.seed)
    sizes = args.split_sizes
    train_size = int(sizes[0] * len(data))
    train_orig = data[:train_size]
    test_orig = data[train_size:]
    train_data, test_data = copy.deepcopy(
        MoleculeDataset(train_orig)), copy.deepcopy(MoleculeDataset(test_orig))
    args.train_data_size = len(train_data)

    ########## standardising
    # features (train and test share the scaler fitted on train)
    features_scaler = train_data.normalize_features(replace_nan_token=0)
    test_data.normalize_features(features_scaler)
    # targets (train only; test/pool targets stay in the original scale)
    train_targets = train_data.targets()
    scaler = StandardScaler().fit(train_targets)
    scaled_targets = scaler.transform(train_targets).tolist()
    train_data.set_targets(scaled_targets)

    ########## loss, metric functions
    loss_func = neg_log_like
    metric_func = get_metric_func(metric=args.metric)  # unused here, but validates args.metric

    ########## data loaders
    # small data sets are cached in memory and served by the main process
    if len(data) <= args.cache_cutoff:
        cache = True
        num_workers = 0
    else:
        cache = False
        num_workers = args.num_workers
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           cache=cache,
                                           class_balance=args.class_balance,
                                           shuffle=True,
                                           seed=args.seed)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers,
                                          cache=cache)

    ########## instantiating model, optimiser, scheduler (MAP)
    # set pytorch seed for random initial weights
    torch.manual_seed(args.pytorch_seeds[model_idx])
    # build model
    print(f'Building model {model_idx}')
    model = MoleculeModel(args)
    print(model)
    print(f'Number of parameters = {param_count(model):,}')
    if args.cuda:
        print('Moving model to cuda')
    model = model.to(args.device)
    # optimizer: no weight decay on the learned observation noise
    optimizer = Adam([{
        'params': model.encoder.parameters()
    }, {
        'params': model.ffn.parameters()
    }, {
        'params': model.log_noise,
        'weight_decay': 0
    }],
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    # constant learning rate scheduler
    scheduler = scheduler_const([args.lr])

    ####################################################################
    # FIRST THOMPSON ITERATION
    ####################################################################

    ### scores array: proportion of the global top 1% acquired per batch
    ptds_scores = np.ones(args.pdts_batches + 1)
    batch_no = 0

    ### fill for batch 0 (overlap between train SMILES and top-1% SMILES)
    SMILES_train = np.array(train_data.smiles())
    SMILES_stack = np.hstack((SMILES, SMILES_train))
    overlap = len(SMILES_stack) - len(np.unique(SMILES_stack))
    prop = overlap / len(SMILES)
    ptds_scores[batch_no] = prop
    wandb.log({
        "Proportion of top 1%": prop,
        "batch_no": batch_no
    }, commit=False)

    ### train MAP posterior
    gp_switch = False
    likelihood = None
    bbp_switch = None
    n_iter = 0
    for epoch in range(args.epochs_init_map):
        n_iter = train(model=model,
                       data_loader=train_data_loader,
                       loss_func=loss_func,
                       optimizer=optimizer,
                       scheduler=scheduler,
                       args=args,
                       n_iter=n_iter,
                       bbp_switch=bbp_switch)

    # for Bayesian methods, restart from a pre-trained MAP checkpoint
    if args.bbp or args.gp or args.swag or args.sgld:
        model = load_checkpoint(args.checkpoint_path +
                                f'/model_{model_idx}/model_{batch_no}.pt',
                                device=args.device,
                                logger=None)

    ########## BBP (Bayes by Backprop)
    if args.bbp:
        model_bbp = MoleculeModelBBP(
            args)  # instantiate with bayesian linear layers
        # copy MAP weights over (transposed to match the BayesLinear layout)
        for (_, param_bbp), (_, param_pre) in zip(model_bbp.named_parameters(),
                                                  model.named_parameters()):
            param_bbp.data = copy.deepcopy(param_pre.data.T)
        # instantiate rhos
        for layer in model_bbp.children():
            if isinstance(layer, BayesLinear):
                layer.init_rho(args.rho_min_bbp, args.rho_max_bbp)
        for layer in model_bbp.encoder.encoder.children():
            if isinstance(layer, BayesLinear):
                layer.init_rho(args.rho_min_bbp, args.rho_max_bbp)
        model = model_bbp  # name back
        # move to cuda
        if args.cuda:
            print('Moving bbp model to cuda')
            model = model.to(args.device)
        # optimiser and scheduler
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        scheduler = scheduler_const([args.lr])

        bbp_switch = 2
        n_iter = 0
        for epoch in range(args.epochs_init):
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           bbp_switch=bbp_switch)

    ########## GP (deep kernel learning)
    if args.gp:
        # use the trained network as a feature extractor
        model.featurizer = True
        feature_extractor = model
        # inducing points
        inducing_points = initial_inducing_points(train_data_loader,
                                                  feature_extractor, args)
        # GP layer
        gp_layer = GPLayer(inducing_points, args.num_tasks)
        # full DKL model
        model = copy.deepcopy(DKLMoleculeModel(feature_extractor, gp_layer))
        # likelihood (rank 0 restricts to diagonal noise covariance)
        # FIX: was hard-coded num_tasks=12 (leftover from a 12-task set);
        # must match the task count computed from the data above
        likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
            num_tasks=args.num_tasks, rank=0)
        # model and likelihood to CUDA
        if args.cuda:
            model.cuda()
            likelihood.cuda()
        # the variational ELBO replaces the MAP loss
        loss_func = gpytorch.mlls.VariationalELBO(
            likelihood, model.gp_layer, num_data=args.train_data_size)
        # optimiser and scheduler (weight decay on the feature extractor only)
        params_list = [
            {
                'params': model.feature_extractor.parameters(),
                'weight_decay': args.weight_decay_gp
            },
            {
                'params': model.gp_layer.hyperparameters()
            },
            {
                'params': model.gp_layer.variational_parameters()
            },
            {
                'params': likelihood.parameters()
            },
        ]
        optimizer = torch.optim.Adam(params_list, lr=args.lr)
        scheduler = scheduler_const([args.lr])

        gp_switch = True
        n_iter = 0
        for epoch in range(args.epochs_init):
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           gp_switch=gp_switch,
                           likelihood=likelihood)

    ########## SWAG
    if args.swag:
        model_core = copy.deepcopy(model)
        model = train_swag_pdts(model_core, train_data_loader, loss_func,
                                scaler, features_scaler, args, save_dir,
                                batch_no)

    ########## SGLD
    if args.sgld:
        model = train_sgld_pdts(model, train_data_loader, loss_func, scaler,
                                features_scaler, args, save_dir, batch_no)

    ### find top_idx: pool indices of the next batch of 50 to acquire
    top_idx = []  # need for thom
    sum_test_preds = np.zeros(
        (len(test_orig), args.num_tasks))  # need for greedy
    for sample in range(args.samples):
        # draw model from SWAG posterior
        if args.swag:
            model.sample(scale=1.0, cov=args.cov_mat, block=args.block)
        # retrieve sgld sample
        if args.sgld:
            model = load_checkpoint(
                args.save_dir +
                f'/model_{model_idx}/model_{batch_no}/model_{sample}.pt',
                device=args.device,
                logger=logger)

        test_preds = predict(model=model,
                             data_loader=test_data_loader,
                             args=args,
                             scaler=scaler,
                             test_data=True,
                             gp_sample=args.thompson,
                             bbp_sample=True)
        test_preds = np.array(test_preds)

        # Thompson sampling: each posterior sample contributes unseen
        # molecules until the list reaches its quota
        rank = 0
        # quota: 5 molecules per SGLD sample, otherwise 1 per sample
        if args.sgld:
            base_length = 5 * sample + 4
        else:
            base_length = sample
        while args.thompson and (len(top_idx) <= base_length):
            top_unique_molecule = np.argsort(-test_preds[:, 0])[rank]
            rank += 1
            if top_unique_molecule not in top_idx:
                top_idx.append(top_unique_molecule)

        # accumulate for greedy (mean-prediction) acquisition
        sum_test_preds += test_preds
        print('done sample ' + str(sample))

    # final top_idx
    if args.thompson:
        top_idx = np.array(top_idx)
    else:
        # greedy: the 50 molecules with the highest mean prediction
        sum_test_preds /= args.samples
        top_idx = np.argsort(-sum_test_preds[:, 0])[:50]

    ### transfer acquired molecules from pool (test) to train
    # pop in descending index order so earlier pops don't shift later indices
    top_idx = -np.sort(-top_idx)
    for idx in top_idx:
        train_orig.append(test_orig.pop(idx))
    train_data, test_data = copy.deepcopy(
        MoleculeDataset(train_orig)), copy.deepcopy(MoleculeDataset(test_orig))
    args.train_data_size = len(train_data)
    if args.gp:
        # ELBO normalisation must track the new training set size
        loss_func = gpytorch.mlls.VariationalELBO(
            likelihood, model.gp_layer, num_data=args.train_data_size)
    print(args.train_data_size)

    ### standardise features (train and test; using original features_scaler)
    train_data.normalize_features(features_scaler)
    test_data.normalize_features(features_scaler)

    ### standardise targets (train only; using original scaler)
    train_targets = train_data.targets()
    scaled_targets_tr = scaler.transform(train_targets).tolist()
    train_data.set_targets(scaled_targets_tr)

    ### create data loaders
    train_data_loader = MoleculeDataLoader(dataset=train_data,
                                           batch_size=args.batch_size,
                                           num_workers=num_workers,
                                           cache=cache,
                                           class_balance=args.class_balance,
                                           shuffle=True,
                                           seed=args.seed)
    test_data_loader = MoleculeDataLoader(dataset=test_data,
                                          batch_size=args.batch_size,
                                          num_workers=num_workers,
                                          cache=cache)

    ##################################
    ########## thompson sampling loop
    ##################################
    for batch_no in range(1, args.pdts_batches + 1):

        ### fill in ptds_scores
        SMILES_train = np.array(train_data.smiles())
        SMILES_stack = np.hstack((SMILES, SMILES_train))
        overlap = len(SMILES_stack) - len(np.unique(SMILES_stack))
        prop = overlap / len(SMILES)
        ptds_scores[batch_no] = prop
        wandb.log({
            "Proportion of top 1%": prop,
            "batch_no": batch_no
        }, commit=False)

        ### train posterior
        n_iter = 0
        for epoch in range(args.epochs):
            n_iter = train(model=model,
                           data_loader=train_data_loader,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           gp_switch=gp_switch,
                           likelihood=likelihood,
                           bbp_switch=bbp_switch)

        # if swag, restart from the checkpoint written for this batch
        if args.swag:
            model_core = load_checkpoint(
                args.checkpoint_path +
                f'/model_{model_idx}/model_{batch_no}.pt',
                device=args.device,
                logger=None)

        ########## SWAG
        if args.swag:
            model = train_swag_pdts(model_core, train_data_loader, loss_func,
                                    scaler, features_scaler, args, save_dir,
                                    batch_no)

        ########## SGLD
        if args.sgld:
            model = train_sgld_pdts(model, train_data_loader, loss_func,
                                    scaler, features_scaler, args, save_dir,
                                    batch_no)

        ### find top_idx (same acquisition procedure as for batch 0)
        top_idx = []  # need for thom
        sum_test_preds = np.zeros(
            (len(test_orig), args.num_tasks))  # need for greedy
        for sample in range(args.samples):
            # draw model from SWAG posterior
            if args.swag:
                model.sample(scale=1.0, cov=args.cov_mat, block=args.block)
            # retrieve sgld sample
            if args.sgld:
                model = load_checkpoint(
                    args.save_dir +
                    f'/model_{model_idx}/model_{batch_no}/model_{sample}.pt',
                    device=args.device,
                    logger=logger)

            test_preds = predict(model=model,
                                 data_loader=test_data_loader,
                                 args=args,
                                 scaler=scaler,
                                 test_data=True,
                                 gp_sample=args.thompson,
                                 bbp_sample=True)
            test_preds = np.array(test_preds)

            # Thompson sampling (quota as above)
            rank = 0
            if args.sgld:
                base_length = 5 * sample + 4
            else:
                base_length = sample
            while args.thompson and (len(top_idx) <= base_length):
                top_unique_molecule = np.argsort(-test_preds[:, 0])[rank]
                rank += 1
                if top_unique_molecule not in top_idx:
                    top_idx.append(top_unique_molecule)

            # add to sum_test_preds
            sum_test_preds += test_preds
            print('done sample ' + str(sample))

        # final top_idx
        if args.thompson:
            top_idx = np.array(top_idx)
        else:
            sum_test_preds /= args.samples
            top_idx = np.argsort(-sum_test_preds[:, 0])[:50]

        ### transfer from test to train (descending order, as above)
        top_idx = -np.sort(-top_idx)
        for idx in top_idx:
            train_orig.append(test_orig.pop(idx))
        train_data, test_data = copy.deepcopy(
            MoleculeDataset(train_orig)), copy.deepcopy(
                MoleculeDataset(test_orig))
        args.train_data_size = len(train_data)
        if args.gp:
            loss_func = gpytorch.mlls.VariationalELBO(
                likelihood, model.gp_layer, num_data=args.train_data_size)
        print(args.train_data_size)

        ### standardise features (train and test; using original features_scaler)
        train_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)

        ### standardise targets (train only; using original scaler)
        train_targets = train_data.targets()
        scaled_targets_tr = scaler.transform(train_targets).tolist()
        train_data.set_targets(scaled_targets_tr)

        ### create data loaders
        train_data_loader = MoleculeDataLoader(
            dataset=train_data,
            batch_size=args.batch_size,
            num_workers=num_workers,
            cache=cache,
            class_balance=args.class_balance,
            shuffle=True,
            seed=args.seed)
        test_data_loader = MoleculeDataLoader(dataset=test_data,
                                              batch_size=args.batch_size,
                                              num_workers=num_workers,
                                              cache=cache)

    # save scores
    np.savez(os.path.join(results_dir, f'ptds_{model_idx}'), ptds_scores)
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for one epoch over ``data_loader``.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing
                 arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    # fall back to print when no logger was supplied
    debug = logger.debug if logger is not None else print

    model.train()
    running_loss = 0
    examples_since_log = 0

    for batch in tqdm(data_loader, total=len(data_loader)):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch = batch.batch_graph()
        features_batch = batch.features()
        target_batch = batch.targets()

        # mask out missing targets; replace them with 0 so tensors are dense
        valid_mask = torch.Tensor([[tgt is not None for tgt in row]
                                   for row in target_batch])
        targets = torch.Tensor([[0 if tgt is None else tgt for tgt in row]
                                for row in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to the device the model produced predictions on
        device = preds.device
        valid_mask = valid_mask.to(device)
        targets = targets.to(device)
        class_weights = torch.ones(targets.shape, device=device)

        if args.dataset_type == 'multiclass':
            # per-task cross-entropy, stacked along the task dimension
            targets = targets.long()
            per_task_losses = [
                loss_func(preds[:, task, :], targets[:, task]).unsqueeze(1)
                for task in range(preds.size(1))
            ]
            loss = torch.cat(per_task_losses,
                             dim=1) * class_weights * valid_mask
        else:
            loss = loss_func(preds, targets) * class_weights * valid_mask
        # average over the observed (non-missing) targets only
        loss = loss.sum() / valid_mask.sum()

        running_loss += loss.item()
        examples_since_log += len(batch)

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        # NoamLR steps per batch; other schedulers are stepped elsewhere
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = running_loss / examples_since_log
            running_loss, examples_since_log = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
# NOTE(review): this redefines `train` and therefore shadows the earlier
# definition of the same name in this file (a later-chemprop variant with
# extra featurization inputs and loss functions) — confirm which one callers
# are meant to get.
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    Iterates once over ``data_loader``, computing the loss selected by
    ``args.loss_function`` / ``args.dataset_type`` (bounded MSE, MCC,
    multiclass CE, spectra, evidential, dirichlet, or the default), applying
    target/data weights and the missing-target mask, then backpropagating
    with optional gradient clipping. Periodically logs loss and norms and,
    if given, writes them to tensorboard.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    # fall back to print when no logger was supplied
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum = iter_count = 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch, mask_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch, data_weights_batch = \
            batch.batch_graph(), batch.features(), batch.targets(), batch.mask(), batch.atom_descriptors(), \
            batch.atom_features(), batch.bond_features(), batch.data_weights()

        # True where a target is observed
        mask = torch.tensor(mask_batch, dtype=torch.bool)  # shape(batch, tasks)
        # missing targets replaced with 0 so the tensor is dense; they are
        # masked out of the loss below
        targets = torch.tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])  # shape(batch, tasks)

        # per-task weighting (defaults to uniform weights of 1)
        if args.target_weights is not None:
            target_weights = torch.tensor(
                args.target_weights).unsqueeze(0)  # shape(1,tasks)
        else:
            target_weights = torch.ones(targets.shape[1]).unsqueeze(0)
        # per-datapoint weighting
        data_weights = torch.tensor(data_weights_batch).unsqueeze(
            1)  # shape(batch,1)

        # bounded MSE additionally needs the inequality-target indicators
        if args.loss_function == 'bounded_mse':
            lt_target_batch = batch.lt_targets()  # shape(batch, tasks)
            gt_target_batch = batch.gt_targets()  # shape(batch, tasks)
            lt_target_batch = torch.tensor(lt_target_batch)
            gt_target_batch = torch.tensor(gt_target_batch)

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch, atom_descriptors_batch,
                      atom_features_batch, bond_features_batch)

        # Move tensors to correct device (the one preds landed on)
        torch_device = preds.device
        mask = mask.to(torch_device)
        targets = targets.to(torch_device)
        target_weights = target_weights.to(torch_device)
        data_weights = data_weights.to(torch_device)
        if args.loss_function == 'bounded_mse':
            lt_target_batch = lt_target_batch.to(torch_device)
            gt_target_batch = gt_target_batch.to(torch_device)

        # Calculate losses — branch on loss function / dataset type, since
        # each loss takes different auxiliary arguments
        if args.loss_function == 'mcc' and args.dataset_type == 'classification':
            # MCC handles weighting/masking internally; only task weights apply
            loss = loss_func(preds, targets, data_weights,
                             mask) * target_weights.squeeze(0)
        elif args.loss_function == 'mcc':  # multiclass dataset type
            targets = targets.long()
            target_losses = []
            for target_index in range(preds.size(1)):
                target_loss = loss_func(preds[:, target_index, :],
                                        targets[:, target_index], data_weights,
                                        mask[:, target_index]).unsqueeze(0)
                target_losses.append(target_loss)
            loss = torch.cat(target_losses).to(
                torch_device) * target_weights.squeeze(0)
        elif args.dataset_type == 'multiclass':
            targets = targets.long()
            if args.loss_function == 'dirichlet':
                loss = loss_func(
                    preds, targets, args.evidential_regularization
                ) * target_weights * data_weights * mask
            else:
                # per-task cross-entropy, concatenated along the task axis
                target_losses = []
                for target_index in range(preds.size(1)):
                    target_loss = loss_func(
                        preds[:, target_index, :],
                        targets[:, target_index]).unsqueeze(1)
                    target_losses.append(target_loss)
                loss = torch.cat(target_losses, dim=1).to(
                    torch_device) * target_weights * data_weights * mask
        elif args.dataset_type == 'spectra':
            loss = loss_func(preds, targets,
                             mask) * target_weights * data_weights * mask
        elif args.loss_function == 'bounded_mse':
            loss = loss_func(preds, targets, lt_target_batch, gt_target_batch
                             ) * target_weights * data_weights * mask
        elif args.loss_function == 'evidential':
            loss = loss_func(preds, targets, args.evidential_regularization
                             ) * target_weights * data_weights * mask
        elif args.loss_function == 'dirichlet':  # classification
            loss = loss_func(preds, targets, args.evidential_regularization
                             ) * target_weights * data_weights * mask
        else:
            loss = loss_func(preds,
                             targets) * target_weights * data_weights * mask

        # average over the observed (non-missing) targets only
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        # NoamLR steps per batch; other schedulers are stepped elsewhere
        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum = iter_count = 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter