def __load_metrics(self):
    """Assemble the evaluation metrics and store them in ``self.metrics``.

    Builds accuracy, macro-averaged precision/recall/F1, a recall-normalized
    confusion matrix and the training loss as ignite metric objects.
    """
    # Per-class (unaveraged) precision/recall so that F1 can first be
    # computed class-wise and only then averaged.
    per_class_precision = Precision(average=False)
    per_class_recall = Recall(average=False)
    # Class-wise F1; the 1e-20 term guards against division by zero for
    # classes with no predictions and no targets.
    per_class_f1 = per_class_precision * per_class_recall * 2 / (
        per_class_precision + per_class_recall + 1e-20)
    # Reduce the per-class F1 vector to a single macro-averaged scalar.
    macro_f1 = MetricsLambda(lambda scores: torch.mean(scores).item(),
                             per_class_f1)
    # TODO: Add metric by patient
    self.metrics = {
        'accuracy': Accuracy(),
        "f1": macro_f1,
        "confusion_matrix": ConfusionMatrix(self.n_class, average="recall"),
        "precision": per_class_precision.mean(),
        "recall": per_class_recall.mean(),
        'loss': Loss(self.loss),
    }
def train(self, config, **kwargs):
    """Trains a given model specified in the config file or passed as the
    --model parameter. All options in the config file can be overwritten
    as needed by passing --PARAM. Options with variable lengths (e.g.,
    kwargs) can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'.

    :param config: yaml config file
    :param **kwargs: parameters to overwrite yaml config
    :return: the output directory containing logs and checkpoints
    """
    config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
    # Unique run directory: timestamp + short uuid.
    # Fixed: format string used %m (month) twice; trailing field is seconds.
    outputdir = Path(
        config_parameters['outputpath'], config_parameters['model'],
        "{}_{}".format(
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
            uuid.uuid1().hex[:8]))
    # Early init because of creating dir
    checkpoint_handler = ModelCheckpoint(
        outputdir,
        'run',
        n_saved=1,
        require_empty=False,
        create_dir=True,
        # Higher score is better for ModelCheckpoint, so negate the loss.
        score_function=lambda engine: -engine.state.metrics['Loss'],
        save_as_state_dict=False,
        score_name='loss')
    logger = utils.getfile_outlogger(Path(outputdir, 'train.log'))
    logger.info("Storing files in {}".format(outputdir))
    # utils.pprint_dict
    utils.pprint_dict(config_parameters, logger.info)
    logger.info("Running on device {}".format(DEVICE))
    labels_df = pd.read_csv(config_parameters['trainlabel'], sep=' ')
    labels_df['encoded'], encoder = utils.encode_labels(
        labels=labels_df['bintype'])
    train_df, cv_df = utils.split_train_cv(labels_df)
    transform = utils.parse_transforms(config_parameters['transforms'])
    utils.pprint_dict({'Classes': encoder.classes_},
                      logger.info,
                      formatter='pretty')
    utils.pprint_dict(transform, logger.info, formatter='pretty')
    if 'sampler' in config_parameters and config_parameters[
            'sampler'] == 'MinimumOccupancySampler':
        # Asserts that each "batch" contains at least one instance
        train_sampler = dataset.MinimumOccupancySampler(
            np.stack(train_df['encoded'].values))
        sampling_kwargs = {"sampler": train_sampler, "shuffle": False}
    elif 'shuffle' in config_parameters and config_parameters['shuffle']:
        sampling_kwargs = {"shuffle": True}
    else:
        sampling_kwargs = {"shuffle": False}
    logger.info("Using Sampler {}".format(sampling_kwargs))
    colname = config_parameters.get('colname', ('filename', 'encoded'))
    trainloader = dataset.getdataloader(
        train_df,
        config_parameters['traindata'],
        transform=transform,
        batch_size=config_parameters['batch_size'],
        colname=colname,  # For other datasets with different key names
        num_workers=config_parameters['num_workers'],
        **sampling_kwargs)
    cvdataloader = dataset.getdataloader(
        cv_df,
        config_parameters['traindata'],
        transform=None,
        shuffle=False,
        colname=colname,  # For other datasets with different key names
        batch_size=config_parameters['batch_size'],
        num_workers=config_parameters['num_workers'])
    if 'pretrained' in config_parameters and config_parameters[
            'pretrained'] is not None:
        model = models.load_pretrained(config_parameters['pretrained'],
                                       outputdim=len(encoder.classes_))
    else:
        # Fixed: getattr default must be the class, not the string
        # 'LightCNN' (a string is not callable).
        model = getattr(models, config_parameters['model'],
                        models.LightCNN)(
                            inputdim=trainloader.dataset.datadim,
                            outputdim=len(encoder.classes_),
                            **config_parameters['model_args'])
    optimizer = None
    if config_parameters['optimizer'] == 'AdaBound':
        try:
            import adabound
            optimizer = adabound.AdaBound(
                model.parameters(), **config_parameters['optimizer_args'])
        except ImportError:
            logger.info(
                "Adabound package not found, install via pip install adabound. Using Adam instead"
            )
            config_parameters['optimizer'] = 'Adam'
            # Default Adam args if adabound not found
            config_parameters['optimizer_args'] = {}
    # Fixed: previously the generic branch was the `else` of the AdaBound
    # check, so a failed adabound import left `optimizer` unbound
    # (NameError). Now the Adam fallback flows through here.
    if optimizer is None:
        optimizer = getattr(
            torch.optim,
            config_parameters['optimizer'],
        )(model.parameters(), **config_parameters['optimizer_args'])
    utils.pprint_dict(optimizer, logger.info, formatter='pretty')
    utils.pprint_dict(model, logger.info, formatter='pretty')
    if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
        logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    model = model.to(DEVICE)
    precision = Precision()
    recall = Recall()
    f1_score = (precision * recall * 2 / (precision + recall)).mean()
    metrics = {
        'Loss': Loss(criterion),
        'Precision': precision.mean(),
        'Recall': recall.mean(),
        'Accuracy': Accuracy(),
        'F1': f1_score,
    }

    # batch contains 3 elements, X, Y and filename. Filename is only used
    # during evaluation
    def _prep_batch(batch, device=DEVICE, non_blocking=False):
        x, y, _ = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    train_engine = create_supervised_trainer(model,
                                             optimizer=optimizer,
                                             loss_fn=criterion,
                                             prepare_batch=_prep_batch,
                                             device=DEVICE)
    inference_engine = create_supervised_evaluator(
        model, metrics=metrics, prepare_batch=_prep_batch, device=DEVICE)
    RunningAverage(output_transform=lambda x: x).attach(
        train_engine, 'run_loss')
    # Showing progressbar during training
    pbar = ProgressBar(persist=False)
    pbar.attach(train_engine, ['run_loss'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           factor=0.1)

    @inference_engine.on(Events.COMPLETED)
    def update_reduce_on_plateau(engine):
        # ReduceLROnPlateau needs the monitored value; other schedulers
        # step unconditionally.
        val_loss = engine.state.metrics['Loss']
        if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
            scheduler.step(val_loss)
        else:
            scheduler.step()

    # Stop training when validation loss has not improved for 5 validations.
    early_stop_handler = EarlyStopping(
        patience=5,
        score_function=lambda engine: -engine.state.metrics['Loss'],
        trainer=train_engine)
    inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                       early_stop_handler)
    inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                       checkpoint_handler, {
                                           'model': model,
                                           'encoder': encoder,
                                           'config': config_parameters,
                                       })

    @train_engine.on(Events.EPOCH_COMPLETED)
    def compute_validation_metrics(engine):
        # Run evaluation on the CV split after every training epoch and
        # log all metrics on one line.
        inference_engine.run(cvdataloader)
        results = inference_engine.state.metrics
        output_str_list = [
            "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
        ]
        for metric in metrics:
            output_str_list.append("{} {:<5.3f}".format(
                metric, results[metric]))
        logger.info(" ".join(output_str_list))
        pbar.n = pbar.last_print_n = 0  # reset the progressbar counter

    train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
    return outputdir