def fit(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader,
        prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None,
        verbose: bool = True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = {}):
    """Function to abstract away the training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows for easy
    re-use of common training functionality, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv,
            model checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on
            labelled batches. For more complex training procedures (meta-learning etc...) you will need to write
            your own fit_function
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            # Loops through all metrics; note they are computed against the final step of the target, y[:, -1:, :]
            batch_logs = batch_metrics(model, y_pred, y[:, -1:, :], metrics, batch_logs)

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
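
# ---------------------------------------------------------------------------
# Hedged example (not part of the original module). The loop above calls
# fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs) and
# expects a (loss, y_pred) pair back. The default `gradient_step` referenced in
# the signature presumably follows this contract; the sketch below only
# illustrates it for plain supervised training, and the name
# `example_gradient_step` is made up here.
# ---------------------------------------------------------------------------
def example_gradient_step(model: Module, optimiser: Optimizer, loss_fn: Callable,
                          x: torch.Tensor, y: torch.Tensor, **kwargs):
    """Perform a single optimisation step on one labelled batch."""
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)          # forward pass
    loss = loss_fn(y_pred, y)  # compare predictions against targets
    loss.backward()            # backpropagate
    optimiser.step()           # update parameters
    return loss, y_pred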
def fit(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader,
        writer: SummaryWriter, prepare_batch: Callable, metrics: List[Union[str, Callable]] = None,
        callbacks: List[Callback] = None, verbose: bool = True, fit_function: Callable = gradient_step,
        stnmodel=None, stnoptim=None, args=None, fit_function_kwargs: dict = {}):
    """Function to abstract away the training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows for easy
    re-use of common training functionality, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        writer: `tensorboard.SummaryWriter` instance to write plots to tensorboard
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv,
            model checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on
            labelled batches. For more complex training procedures (meta-learning etc...) you will need to write
            your own fit_function
        stnmodel, stnoptim, args: Passed through to `fit_function` via `fit_function_kwargs`
        fit_function_kwargs: Keyword arguments to pass to `fit_function`; must contain 'n_shot' and 'k_way'
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    # Copy before injecting the extra arguments so the caller's dict (and the shared mutable default)
    # is not modified in place.
    fit_function_kwargs = dict(fit_function_kwargs)
    fit_function_kwargs['stnmodel'] = stnmodel
    fit_function_kwargs['stnoptim'] = stnoptim
    fit_function_kwargs['args'] = args

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            n_shot = fit_function_kwargs['n_shot']
            k_way = fit_function_kwargs['k_way']

            loss, y_pred, aug_imgs = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            # Useful for viewing images for debugging (doesn't work for MAML)
            # TODO (kamal): customize for maml
            # if batch_index % 100 == 99:
            #     writer.add_figure('episode', plot_classes_preds(aug_imgs, n_shot, k_way),
            #                       global_step=len(dataloader) * (epoch - 1) + batch_index)

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)

            # Log training loss and categorical accuracy
            writer.add_scalar('Train_loss', batch_logs['loss'],
                              len(dataloader) * (epoch - 1) + batch_index)
            writer.add_scalar('categorical_accuracy', batch_logs['categorical_accuracy'],
                              len(dataloader) * (epoch - 1) + batch_index)

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
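
# ---------------------------------------------------------------------------
# Hedged note (not part of the original module): the variant above requires
# 'n_shot' and 'k_way' keys in `fit_function_kwargs` and injects 'stnmodel',
# 'stnoptim' and 'args' before calling `fit_function`. The helper below merely
# spells out the TensorBoard global-step convention used by the
# writer.add_scalar calls; its name is an illustrative assumption.
# ---------------------------------------------------------------------------
def example_global_step(num_batches_per_epoch: int, epoch: int, batch_index: int) -> int:
    """Global step used for the scalars above: epochs are 1-indexed, batch indices 0-indexed."""
    return num_batches_per_epoch * (epoch - 1) + batch_index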
def fit(model: Union[Module, List[Module]], optimiser: Optimizer, loss_fn: Callable, epochs: int,
        dataloader: DataLoader, prepare_batch: Callable, metrics: List[Union[str, Callable]] = None,
        callbacks: List[Callback] = None, verbose: bool = True, fit_function: Callable = gradient_step,
        n_models: int = 1, fit_function_kwargs: dict = {}):
    """Function to abstract away the training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows for easy
    re-use of common training functionality, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model (or list of models, when fitting an ensemble) to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv,
            model checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on
            labelled batches. For more complex training procedures (meta-learning etc...) you will need to write
            your own fit_function
        n_models: Number of models in the ensemble; passed through to the callbacks
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    fit_function_kwargs_logs = dict(fit_function_kwargs)
    fit_function_kwargs_logs['train'] = False
    fit_function_kwargs_logs['pred_fn'] = logmeanexp_preds

    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser,
        'n_models': n_models
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            # result = {
            #     "meta_batch_loss": meta_batch_loss,
            #     "task_predictions": task_predictions,
            #     "models_losses": models_losses,
            #     "models_predictions": models_predictions,
            #     "mean_support_loss": mean_support_loss
            # }
            result = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
            loss = result['meta_batch_loss']
            y_pred = result['task_predictions']
            models_losses = result['models_losses']
            models_preds = result['models_predictions']
            support_loss = result['mean_support_loss']

            batch_logs['loss'] = loss.item()
            batch_logs['support_loss'] = support_loss

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)

            # Group the per-model predictions by task:
            # task_preds: {task_idx: [model_1_pred, model_2_pred, ...]}
            task_preds = defaultdict(list)
            for model_pred in models_preds:
                for i, task in enumerate(model_pred):
                    task_preds[i].append(task)

            # Ensemble the per-model predictions for each task via log-mean-exp
            logprobs_pred = []
            logprobs_loss = []
            for task_idx, task_pred in task_preds.items():
                y_pred_ = logmeanexp_preds(task_pred)
                logprobs_pred.append(y_pred_)

            y_pred_logprobs = torch.cat(logprobs_pred)

            # TODO: make it work with MixturePredLoss
            # with torch.no_grad():
            #     loss_logprobs = loss_fn(y_pred_logprobs, y).item()
            #
            # batch_logs['logprobs_loss'] = loss_logprobs
            batch_logs['logprobs_nll'] = nll_loss(y_pred_logprobs, y, reduction="mean").item()
            batch_logs = batch_metrics(model, y_pred_logprobs, y, metrics, batch_logs, 'logprobs')

            # Per-model loss and accuracy
            for i, (loss, y_pred) in enumerate(zip(models_losses, models_preds)):
                batch_logs[f'loss_{i}'] = nmean(loss)
                batch_logs[f'categorical_accuracy_{i}'] = NAMED_METRICS['categorical_accuracy'](y, torch.cat(y_pred))

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
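
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original module). `logmeanexp_preds` is used
# above to ensemble per-model predictions that are later scored with `nll_loss`,
# so they are presumably per-query log-probabilities. A log-mean-exp combination
# over models would look like the sketch below; treat this as an assumption
# about the helper's behaviour, not its actual implementation.
# ---------------------------------------------------------------------------
import math


def example_logmeanexp_preds(preds: List[torch.Tensor]) -> torch.Tensor:
    """Combine per-model log-probabilities of shape [num_queries, num_classes] into ensemble log-probabilities."""
    stacked = torch.stack(preds, dim=0)  # [n_models, num_queries, num_classes]
    # log(mean_m exp(log p_m)) == logsumexp over the model dimension minus log(n_models)
    return torch.logsumexp(stacked, dim=0) - math.log(stacked.shape[0])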
def fit(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader,
        prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None,
        verbose: bool = True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = {}):
    """Function to abstract away the training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows for easy
    re-use of common training functionality, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv,
            model checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on
            labelled batches. For more complex training procedures (meta-learning etc...) you will need to write
            your own fit_function
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    # The default callback averages the batch accuracy and loss
    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    # The model and all other information are passed to the callbacks here; nothing else needs to be done
    # during the later callback calls
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    # Creates the CSV logger file
    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            # For each new batch create a fresh batch_logs dict
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            # This does nothing for ProtoNets except update the progress bar
            callbacks.on_batch_begin(batch_index, batch_logs)

            # y here has length queries * k-way and its values are in [0, k)
            x, y = prepare_batch(batch)

            # What we expect here is a loss for the above batch and the probabilities of the predicted classes;
            # accuracy is determined on queries only
            loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            # Loops through all metrics
            # For each episode in the epoch the accuracy is the number of correct predictions / total number of queries
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)

            # This does nothing for ProtoNets except update the progress bar;
            # the categorical accuracy and loss seen during training refer to the query samples
            callbacks.on_batch_end(batch_index, batch_logs)

        # The few-shot evaluation callback runs only here, after one complete epoch
        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
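
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original module). The comments above state that
# the episode labels y have length queries * k-way with values in [0, k); a label
# tensor with that layout can be built as below. This only illustrates the
# convention and is not the repository's actual `prepare_batch` helper.
# ---------------------------------------------------------------------------
def example_create_query_labels(k_way: int, q_queries: int) -> torch.Tensor:
    """Return labels [0, 0, ..., 1, 1, ..., k_way - 1] with q_queries entries per class."""
    return torch.arange(k_way).repeat_interleave(q_queries)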
def fit_gan_few_shot(encoder: Module, generator: Module, classifier: Module, discriminator: Module,
                     dataloader: DataLoader, params_str: str, k: int, n: int, epochs: int,
                     prepare_batch: Callable, latent_sizeC: int, latent_sizeB: int, device,
                     e_optimizer: Optimizer, g_optimizer: Optimizer, c_optimizer: Optimizer,
                     d_optimizer: Optimizer, metrics: List[Union[str, Callable]] = None,
                     callbacks: List[Callback] = None, verbose: bool = True, is_complete=True):
    """Function to abstract away the GAN few-shot training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows for easy
    re-use of common training functionality, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        encoder, generator, classifier, discriminator: Component networks to be fitted.
        dataloader: `torch.DataLoader` instance to fit the models to
        params_str: String used to name the checkpoint files
        k: Number of classes (k-way) in the monitored few-shot validation metric
        n: Number of support samples per class (n-shot) in the monitored few-shot validation metric
        epochs: Number of epochs of fitting to be performed
        prepare_batch: Callable to perform any desired preprocessing
        latent_sizeC, latent_sizeB: Latent code sizes passed to the gradient step
        device: Device on which to place the combined encoder-classifier
        e_optimizer, g_optimizer, c_optimizer, d_optimizer: Optimisers for the encoder, generator, classifier
            and discriminator respectively
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv,
            model checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        is_complete: If `True`, the dae_loss is also logged
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])

    class EncoderClassifier(nn.Module):
        def forward(self, x):
            c, v = encoder(x)
            return classifier(c)

    encoderclassifier = EncoderClassifier().to(device, dtype=torch.double)

    callbacks.set_model(encoderclassifier)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': criterionClass
    })

    checkpoint_e = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan', str(params_str) + '_encoder.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')
    checkpoint_g = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan', str(params_str) + '_generator.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')
    checkpoint_c = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan', str(params_str) + '_classifier.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')
    checkpoint_d = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan', str(params_str) + '_discriminator.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')

    checkpoint_e.set_model(encoder)
    checkpoint_g.set_model(generator)
    checkpoint_c.set_model(classifier)
    checkpoint_d.set_model(discriminator)

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            cl_loss, cl_score, dae_loss, ae_loss, cae_loss = gradient_step_gan_few_shot(
                e_optimizer, g_optimizer, c_optimizer, d_optimizer, x, y, device,
                encoder, generator, classifier, discriminator, latent_sizeB, latent_sizeC, epoch,
                is_complete=is_complete)

            batch_logs['cl_loss'] = cl_loss.item()
            # batch_logs['d_loss'] = d_loss.item()
            batch_logs['ae_loss'] = ae_loss.item()
            batch_logs['cae_loss'] = cae_loss.item()
            batch_logs['cl_score'] = cl_score.item()
            # batch_logs['real_score'] = real_score.item()
            # batch_logs['fake_score'] = fake_score.item()
            if is_complete:
                batch_logs['dae_loss'] = dae_loss.item()

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)
        checkpoint_e.on_epoch_end(epoch, epoch_logs)
        checkpoint_c.on_epoch_end(epoch, epoch_logs)
        checkpoint_g.on_epoch_end(epoch, epoch_logs)
        checkpoint_d.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
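
# ---------------------------------------------------------------------------
# Hedged note (not part of the original module). The four ModelCheckpoint
# callbacks above monitor a validation metric whose key is built from n and k;
# presumably a few-shot evaluation callback writes that key into epoch_logs so
# the checkpoints have something to compare against. The helper name below is
# made up for illustration.
# ---------------------------------------------------------------------------
def example_monitor_key(n: int, k: int) -> str:
    """Return the monitored metric key, e.g. example_monitor_key(1, 5) == 'val_1-shot_5-way_acc'."""
    return 'val_' + str(n) + '-shot_' + str(k) + '-way_acc'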
def fit(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader,
        prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None,
        verbose: bool = True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = {}):
    """Function to abstract away the training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows for easy
    re-use of common training functionality, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv,
            model checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on
            labelled batches. For more complex training procedures (meta-learning etc...) you will need to write
            your own fit_function
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    print('num_batches: ', num_batches)
    batch_size = dataloader.batch_size

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}

        # train_iter = iter(dataloader)
        # first_batch = next(train_iter)
        # for obj in gc.get_objects():
        #     if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        #         del obj
        torch.cuda.empty_cache()

        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            input_ids, attention_mask, label = prepare_batch(batch)
            input_ids = torch.squeeze(input_ids, dim=1)
            attention_mask = torch.squeeze(attention_mask, dim=1)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            label = label.to(device)

            # print('input_ids shape: ', input_ids.size())
            # print('attention_mask shape: ', attention_mask.size())
            # print('label shape: ', label.size())
            # input_ids = input_ids[:8, :]
            # attention_mask = attention_mask[:8, :]
            # label = label[:8]
            # print('input_ids shape: ', input_ids.size())
            # print('attention_mask shape: ', attention_mask.size())
            # print('label shape: ', label.size())

            try:
                print('Before Loss/Pred')
                gpu_dict = get_gpu_info()
                print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
                    gpu_dict['mem_total'], gpu_dict['mem_used'], gpu_dict['mem_used_percent']))
            except Exception:
                pass

            loss, y_pred = fit_function(model, optimiser, loss_fn, input_ids, attention_mask, label,
                                        **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            try:
                print('After Loss/Pred')
                gpu_dict = get_gpu_info()
                print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.format(
                    gpu_dict['mem_total'], gpu_dict['mem_used'], gpu_dict['mem_used_percent']))
            except Exception:
                pass

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, label, metrics, batch_logs)

            callbacks.on_batch_end(batch_index, batch_logs)

        # for obj in gc.get_objects():
        #     if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        #         del obj
        torch.cuda.empty_cache()

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
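
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original module). The variant above calls
# fit_function(model, optimiser, loss_fn, input_ids, attention_mask, label, ...)
# and expects (loss, y_pred) back. A minimal supervised step matching that call,
# assuming the wrapped model takes (input_ids, attention_mask) positionally and
# returns class logits, might look like this; the name is made up here.
# ---------------------------------------------------------------------------
def example_transformer_gradient_step(model: Module, optimiser: Optimizer, loss_fn: Callable,
                                      input_ids: torch.Tensor, attention_mask: torch.Tensor,
                                      label: torch.Tensor, **kwargs):
    """Single optimisation step for a tokenised-text classifier."""
    model.train()
    optimiser.zero_grad()
    logits = model(input_ids, attention_mask)  # assumed forward signature
    loss = loss_fn(logits, label)
    loss.backward()
    optimiser.step()
    return loss, logits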