def train(save_path, model, lr=0.1, batch_size=128, callbacks=[]):
    # Create dataset generators dynamically
    train, valid, test, meta_data = get_dataset(batch_size=batch_size)

    # Create the model dynamically
    model = models.__dict__[model]()
    summary(model)

    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create callbacks dynamically
    callbacks_constructed = []
    for name in callbacks:
        clbk = get_callback(name, verbose=0)
        if clbk is not None:
            callbacks_constructed.append(clbk)

    # Pass everything to the training loop
    steps_per_epoch = (len(meta_data['x_train']) - 1) // batch_size + 1
    training_loop(model=model,
                  optimizer=optimizer,
                  loss_function=loss_function,
                  metrics=[acc],
                  train=train,
                  valid=test,  # NOTE: the test split is passed as the validation set here
                  meta_data=meta_data,
                  steps_per_epoch=steps_per_epoch,
                  save_path=save_path,
                  config=_CONFIG,
                  use_tb=True,
                  custom_callbacks=callbacks_constructed)
def train(save_path, model, batch_size=128, seed=777, callbacks=[],
          resume=True, evaluate=True):
    # Create dataset generators dynamically
    train, valid, test, meta_data = get_dataset(batch_size=batch_size, seed=seed)

    # Create the model dynamically
    model = models.__dict__[model]()
    summary(model)

    # Create callbacks dynamically
    callbacks_constructed = []
    for name in callbacks:
        clbk = get_callback(name, verbose=0)
        if clbk is not None:
            callbacks_constructed.append(clbk)

    if not resume and os.path.exists(os.path.join(save_path, "last.ckpt")):
        raise IOError(
            "Please clear folder before running or pass train.resume=True")

    # Create the Lightning module and pass it to training
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(save_path, "weights"),
        verbose=True,
        save_last=True,  # For resumability
        monitor='valid_acc',
        mode='max')
    pl_module = supervised_training.SupervisedLearning(model, meta_data=meta_data)
    trainer = training_loop(train, valid,
                            pl_module=pl_module,
                            checkpoint_callback=checkpoint_callback,
                            callbacks=callbacks_constructed,
                            save_path=save_path)

    # Evaluate on the test set; trainer.test returns a list of per-dataloader
    # result dicts, so unpack the single entry
    if evaluate:
        results, = trainer.test(test_dataloaders=test)
        logger.info(results)
        with open(os.path.join(save_path, "eval_results.json"), "w") as f:
            json.dump(results, f)
def train(save_path, data_class, label_mode='multiclass_cancer_sides',
          batch_size=128, callbacks=['BreastDataLoader']):
    '''
    data_class: 'data_with_segmentations_gin' or 'data_gin'
    '''
    # Create the data loader dynamically
    data_loader = data.__dict__[data_class](
        logger_breast_ori(save_path, 'output_log.log'),
        minibatch_size=batch_size)

    # Create callbacks dynamically
    callbacks_constructed = []
    for name in callbacks:
        clbk = get_callback(name, verbose=0)
        if clbk is not None:
            callbacks_constructed.append(clbk)

    # Oversample the training list unless plain sampling is requested
    if data_loader.parameters['train_sampling_mode'] == 'normal':
        training_oversampled_indices = data_loader.data_list_training
    else:
        training_oversampled_indices = data_loader.train_sampler.sample_indices(
            data_loader.get_train_labels_cancer('multiclass_cancer_sides'),
            random_seed=0)

    steps_per_epoch = (len(training_oversampled_indices) - 1) // batch_size + 1
    validation_steps = (len(data_loader.data_list_validation) - 1) // batch_size + 1
    logger.info('samples_per_training_epoch=%d; steps_per_epoch=%d'
                % (len(training_oversampled_indices), steps_per_epoch))
    logger.info('samples_per_evaluation_epoch=%d; validation_steps=%d'
                % (len(data_loader.data_list_validation), validation_steps))

    # Pass everything to the training loop
    training_loop(meta_data=None,
                  label_mode=label_mode,
                  steps_per_epoch=steps_per_epoch,
                  validation_steps=validation_steps,
                  data_loader=data_loader,
                  save_path=save_path,
                  config=_CONFIG,
                  custom_callbacks=callbacks_constructed)
def train(save_path, model, lr_splitting_by=None, lrs=None, wd=0, lr=0.1,
          batch_size=128, n_epochs=100, weights=None, fb_method=False,
          callbacks=[], optimizer='sgd', scheduler=None,
          freeze_all_but_this_layer=None, mode='train'):
    # Create dataset generators dynamically
    train, valid, test, meta_data = get_chexnet_covid(batch_size=batch_size)

    # Create the model dynamically
    model = models.__dict__[model]()
    summary(model)

    loss_function = torch.nn.BCELoss()

    if freeze_all_but_this_layer is not None:
        # First freeze all layers
        logger.info("Freezing all layers")
        for i, parameter in enumerate(model.parameters()):
            parameter.requires_grad = False

        # Unfreeze layers whose names match the given prefix
        for i, (name, parameter) in enumerate(model.named_parameters()):
            if name.startswith(freeze_all_but_this_layer):
                parameter.requires_grad = True
                logger.info("Unfreezing {}: {}".format(name, parameter.shape))

    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd)
    elif optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    if scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, n_epochs)

    if lr_splitting_by is not None:
        optimizer, _ = create_optimizer(optimizer, model, lr_splitting_by, lrs)

    # Create callbacks dynamically
    callbacks_constructed = []
    for name in callbacks:
        clbk = get_callback(name, verbose=0)
        if clbk is not None:
            print(name)
            callbacks_constructed.append(clbk)

    # Pass everything to the training loop
    if train is not None:
        steps_per_epoch = len(train)
    else:
        steps_per_epoch = None

    target_indice = None
    if fb_method:
        target_indice = weights.index(1) if 1 in weights else 0
    elif weights is not None:
        target_indice = 0

    if mode == 'train':
        assert train is not None, "please provide train data"
        assert valid is not None, "please provide validation data"
        training_loop(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            loss_function=loss_function,
            metrics=[acc_chexnet_covid],
            train=train,
            valid=valid,
            test=test,
            meta_data=meta_data,
            steps_per_epoch=steps_per_epoch,
            n_epochs=n_epochs,
            save_path=save_path,
            config=_CONFIG,
            use_tb=True,
            custom_callbacks=callbacks_constructed,
            fb_method=fb_method,
            target_indice=target_indice,
        )
    else:
        assert test is not None, "please provide test data for evaluation"
        evaluation_loop(
            model=model,
            optimizer=optimizer,
            loss_function=loss_function,
            metrics=[acc_chexnet_covid],
            test=test,
            meta_data=meta_data,
            save_path=save_path,
            config=_CONFIG,
            custom_callbacks=callbacks_constructed,
            target_indice=target_indice,
        )
def train(save_path, model, datasets=['cifar10'], optimizer="SGD", data_seed=777,
          seed=777, batch_size=128, lr=0.0, wd=0.0, nesterov=False,
          checkpoint_monitor='val_categorical_accuracy:0', loss='ce',
          steps_per_epoch=-1, momentum=0.9, testing=False,
          testing_reload_best_val=True, callbacks=[]):
    np.random.seed(seed)

    # Create dataset generators (seeded)
    datasets = [
        get_dataset(d, seed=data_seed, batch_size=batch_size) for d in datasets
    ]

    # Create model
    model = models.__dict__[model](input_shape=datasets[0][-1]['input_shape'],
                                   n_classes=datasets[0][-1]['num_classes'])
    logger.info("# of parameters " +
                str(sum([np.prod(p.shape) for p in model.trainable_weights])))
    model.summary()

    if loss == 'ce':
        loss_function = tf.keras.losses.categorical_crossentropy
    else:
        raise NotImplementedError()

    if optimizer == "SGD":
        optimizer = SGD(learning_rate=lr, momentum=momentum, nesterov=nesterov)
    elif optimizer == "Adam":
        optimizer = Adam(learning_rate=lr)
    else:
        raise NotImplementedError()

    # Create callbacks
    callbacks_constructed = []
    for name in callbacks:
        clbk = get_callback(name, verbose=0)
        if clbk is not None:
            callbacks_constructed.append(clbk)
        else:
            raise NotImplementedError(f"Did not find callback {name}")

    # Pass everything to the training loop
    metrics = [categorical_accuracy]
    if steps_per_epoch == -1:
        steps_per_epoch = (datasets[0][-1]['n_examples_train']
                           + batch_size - 1) // batch_size
    training_loop(model=model,
                  optimizer=optimizer,
                  loss_function=loss_function,
                  metrics=metrics,
                  datasets=datasets,
                  weight_decay=wd,
                  save_path=save_path,
                  config=_CONFIG,
                  steps_per_epoch=steps_per_epoch,
                  use_tb=True,
                  checkpoint_monitor=checkpoint_monitor,
                  custom_callbacks=callbacks_constructed,
                  seed=seed)

    if testing:
        if testing_reload_best_val:
            model = restore_model(model,
                                  os.path.join(save_path, "model_best_val.h5"))
        m_val = evaluate(model, [datasets[0][1]], loss_function, metrics)
        m_test = evaluate(model, [datasets[0][2]], loss_function, metrics)
        logger.info("Saving")
        eval_results = {}
        for k in m_test:
            eval_results['test_' + k] = float(m_test[k])
        for k in m_val:
            eval_results['val_' + k] = float(m_val[k])
        logger.info(eval_results)
        json.dump(eval_results,
                  open(os.path.join(save_path, "eval_results.json"), "w"))
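# A minimal usage sketch: none of these excerpts show how the train() entry
# points are actually dispatched (they may be wired through gin, argh, or a
# project-specific launcher), so the argparse wiring and flag names below are
# assumptions for illustration only. It maps a few CLI flags onto the
# CIFAR-10 variant of train() defined directly above.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Launch a training run")
    parser.add_argument("save_path",
                        help="Directory for checkpoints and eval_results.json")
    parser.add_argument("model",
                        help="Key into models.__dict__ naming the model constructor")
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--callbacks", nargs="*", default=[],
                        help="Callback names resolved by get_callback")
    args = parser.parse_args()

    # Forward the parsed flags to the train() function defined above.
    train(save_path=args.save_path,
          model=args.model,
          lr=args.lr,
          batch_size=args.batch_size,
          callbacks=args.callbacks)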