def train(args):
    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = CIFARLoader('train', args.batch_size, args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [CIFARLoader('val', args.batch_size, args.num_workers)]
    evaluator = ModelEvaluator(eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()
        for inputs, targets in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                logits = model(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
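# The `optim` module used above is project code that isn't shown here. Below
# is a minimal sketch (an assumption, not the repo's actual implementation) of
# a `step_scheduler` helper consistent with the call sites: ReduceLROnPlateau
# consumes a validation metric, while epoch-based schedulers are stepped
# unconditionally. The `metric_name` parameter is hypothetical.
import torch.optim.lr_scheduler as lr_sched

def step_scheduler(lr_scheduler, metrics=None, epoch=None, metric_name='val_loss'):
    """Step an LR scheduler, passing a metric only when the scheduler needs one."""
    if isinstance(lr_scheduler, lr_sched.ReduceLROnPlateau):
        if metrics is not None and metric_name in metrics:
            lr_scheduler.step(metrics[metric_name])
    else:
        lr_scheduler.step()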
def train(args):
    train_loader = get_loader(args=args)
    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        args.D_in = train_loader.D_in
        model = model_fn(**vars(args))
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = optim.get_loss_fn(args.loss_fn, args)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [get_loader(args, phase='train', is_training=False),
                    get_loader(args, phase='valid', is_training=False)]
    evaluator = ModelEvaluator(args, eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()
        for src, tgt in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                pred_params = model(src.to(args.device))
                # Ages are taken from the second column of the source features.
                ages = src[:, 1]
                loss = loss_fn(pred_params, tgt.to(args.device),
                               ages.to(args.device), args.use_intvl)

                logger.log_iter(src, pred_params, tgt, loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics=metrics)
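# `optim.get_loss_fn` is project code not shown here; a plausible sketch
# (the registry contents and names are assumptions) maps the --loss_fn flag
# to a callable so the training loop stays loss-agnostic. The loss used above
# takes extra arguments (ages, use_intvl), so real entries may be custom
# callables rather than plain nn.Module losses.
import torch.nn as nn

def get_loss_fn(name, args=None):
    loss_fns = {
        'cross_entropy': nn.CrossEntropyLoss(),
        'mse': nn.MSELoss(),
    }
    if name not in loss_fns:
        raise ValueError('Unsupported loss function: {}'.format(name))
    return loss_fns[name]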
def train(args):
    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Set up population-based training client
    pbt_client = PBTClient(args.pbt_server_url, args.pbt_server_port,
                           args.pbt_server_key, args.pbt_config_path)

    # Get optimizer and scheduler
    parameters = model.module.parameters()
    optimizer = optim.get_optimizer(parameters, args, pbt_client)
    # Only restore optimizer state when resuming from a checkpoint.
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, args.gpu_ids, optimizer)

    # Get logger, evaluator, saver
    train_loader = DataLoader(args, 'train', is_training_set=True)
    eval_loaders = [DataLoader(args, 'valid', is_training_set=False)]
    evaluator = ModelEvaluator(eval_loaders, args.epochs_per_eval, args.max_eval,
                               args.num_visuals, use_ten_crop=args.use_ten_crop)
    saver = ModelSaver(**vars(args))

    for _ in range(args.num_epochs):
        optim.update_hyperparameters(model.module, optimizer,
                                     pbt_client.hyperparameters())

        for inputs, targets in train_loader:
            with torch.set_grad_enabled(True):
                logits = model(inputs.to(args.device))
                loss = F.binary_cross_entropy_with_logits(
                    logits, targets.to(args.device))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        metrics = evaluator.evaluate(model, args.device)
        metric_val = metrics.get(args.metric_name, None)
        ckpt_path = saver.save(model, args.model, optimizer, args.device, metric_val)

        pbt_client.save(ckpt_path, metric_val)
        if pbt_client.should_exploit():
            # Exploit: load model and optimizer parameters from the exploited network
            pbt_client.exploit()
            model, ckpt_info = ModelSaver.load_model(pbt_client.parameters_path(),
                                                     args.gpu_ids)
            model = model.to(args.device)
            model.train()
            ModelSaver.load_optimizer(pbt_client.parameters_path(), args.gpu_ids,
                                      optimizer)

            # Explore: perturb the copied hyperparameters
            pbt_client.explore()
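# PBTClient is an external service client; as a rough illustration of the
# exploit/explore steps it coordinates (truncation selection plus random
# perturbation), under assumed names and thresholds:
import random

def should_exploit(my_score, population_scores, bottom_frac=0.2):
    """Truncation selection: exploit when this member ranks in the bottom fraction."""
    ranked = sorted(population_scores)
    cutoff = ranked[int(len(ranked) * bottom_frac)]
    return my_score <= cutoff

def explore(hyperparameters, perturb_factors=(0.8, 1.2)):
    """Perturb each hyperparameter up or down to keep searching locally."""
    return {name: value * random.choice(perturb_factors)
            for name, value in hyperparameters.items()}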
def train(args):
    if args.ckpt_path and not args.use_pretrained:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        if args.use_pretrained:
            model.load_pretrained(args.ckpt_path, args.gpu_ids)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    if args.use_pretrained or args.fine_tune:
        parameters = model.module.fine_tuning_parameters(args.fine_tuning_boundary,
                                                         args.fine_tuning_lr)
    else:
        parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    lr_scheduler = util.get_scheduler(optimizer, args)
    if args.ckpt_path and not args.use_pretrained and not args.fine_tune:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    cls_loss_fn = util.get_loss_fn(is_classification=True, dataset=args.dataset,
                                   size_average=False)
    data_loader_fn = data_loader.__dict__[args.data_loader]
    train_loader = data_loader_fn(args, phase='train', is_training=True)
    logger = TrainLogger(args, len(train_loader.dataset),
                         train_loader.dataset.pixel_dict)
    eval_loaders = [data_loader_fn(args, phase='val', is_training=False)]
    evaluator = ModelEvaluator(args.do_classify, args.dataset, eval_loaders, logger,
                               args.agg_method, args.num_visuals, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(args.save_dir, args.epochs_per_save, args.max_ckpts,
                       args.best_ckpt_metric, args.maximize_metric)

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()
        for inputs, target_dict in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                # Tensor.to() is not in-place: keep the returned tensor.
                inputs = inputs.to(args.device)
                cls_logits = model(inputs)
                cls_targets = target_dict['is_abnormal']
                cls_loss = cls_loss_fn(cls_logits, cls_targets.to(args.device))
                loss = cls_loss.mean()

                logger.log_iter(inputs, cls_logits, target_dict, loss, optimizer)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            logger.end_iter()
            util.step_scheduler(lr_scheduler, global_step=logger.global_step)

        metrics, curves = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.best_ckpt_metric, None))
        logger.end_epoch(metrics, curves)
        util.step_scheduler(lr_scheduler, metrics, epoch=logger.epoch,
                            best_ckpt_metric=args.best_ckpt_metric)
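# `fine_tuning_parameters` lives on the model; a plausible sketch of the
# technique (an assumption about its behavior, not the actual method) splits
# parameters at a named layer boundary and gives the earlier, pretrained
# layers a smaller learning rate than the newly added head:
def fine_tuning_parameters(model, boundary_name, fine_tuning_lr):
    pretrained, head = [], []
    seen_boundary = False
    for name, param in model.named_parameters():
        if boundary_name in name:
            seen_boundary = True
        (head if seen_boundary else pretrained).append(param)
    # Two parameter groups: pretrained layers at a reduced LR,
    # the head at the optimizer's default LR.
    return [{'params': pretrained, 'lr': fine_tuning_lr},
            {'params': head}]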
def train(args): """Run training loop with the given args. The function consists of the following steps: 1. Load model: gets the model from a checkpoint or from models/models.py. 2. Load optimizer and learning rate scheduler. 3. Get data loaders and class weights. 4. Get loss functions: cross entropy loss and weighted loss functions. 5. Get logger, evaluator, and saver. 6. Run training loop, evaluate and save model periodically. """ model_args = args.model_args logger_args = args.logger_args optim_args = args.optim_args data_args = args.data_args transform_args = args.transform_args task_sequence = TASK_SEQUENCES[data_args.task_sequence] # Get model if model_args.ckpt_path: model_args.pretrained = False model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path, args.gpu_ids, model_args, data_args) args.start_epoch = ckpt_info['epoch'] + 1 else: model_fn = models.__dict__[model_args.model] model = model_fn(task_sequence, model_args) if model_args.hierarchy: model = models.HierarchyWrapper(model, task_sequence) model = nn.DataParallel(model, args.gpu_ids) model = model.to(args.device) model.train() # Get optimizer and scheduler optimizer = util.get_optimizer(model.parameters(), optim_args) lr_scheduler = util.get_scheduler(optimizer, optim_args) if model_args.ckpt_path: ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids, optimizer, lr_scheduler) # Get loaders and class weights train_csv_name = 'train' if data_args.uncertain_map_path is not None: train_csv_name = data_args.uncertain_map_path #TODO: Remove this when we decide which transformation to use in the end #transforms_imgaug = ImgAugTransform() train_loader = get_loader(data_args, transform_args, train_csv_name, task_sequence, data_args.su_train_frac, data_args.nih_train_frac, data_args.pocus_train_frac, data_args.tcga_train_frac, 0, 0, args.batch_size, frontal_lateral=model_args.frontal_lateral, is_training=True, shuffle=True, transform=model_args.transform, normalize=model_args.normalize) eval_loaders = get_eval_loaders(data_args, transform_args, task_sequence, args.batch_size, frontal_lateral=model_args.frontal_lateral, normalize=model_args.normalize) class_weights = train_loader.dataset.class_weights print(" class weights:") print(class_weights) # Get loss functions uw_loss_fn = get_loss_fn('cross_entropy', args.device, model_args.model_uncertainty, args.has_tasks_missing, class_weights=class_weights) w_loss_fn = get_loss_fn('weighted_loss', args.device, model_args.model_uncertainty, args.has_tasks_missing, mask_uncertain=False, class_weights=class_weights) # Get logger, evaluator and saver logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs, args.batch_size, len(train_loader.dataset), args.device) eval_args = {} eval_args['num_visuals'] = logger_args.num_visuals eval_args['iters_per_eval'] = logger_args.iters_per_eval eval_args['has_missing_tasks'] = args.has_tasks_missing eval_args['model_uncertainty'] = model_args.model_uncertainty eval_args['class_weights'] = class_weights eval_args['max_eval'] = logger_args.max_eval eval_args['device'] = args.device eval_args['optimizer'] = args.optimizer evaluator = get_evaluator('classification', eval_loaders, logger, eval_args) print("Eval Loaders: %d" % len(eval_loaders)) saver = ModelSaver(**vars(logger_args)) metrics = None lr_step = 0 # Train model while not logger.is_finished_training(): logger.start_epoch() for inputs, targets, info_dict in train_loader: logger.start_iter() # Evaluate and save periodically metrics, curves = evaluator.evaluate(model, 
args.device, logger.global_step) logger.plot_metrics(metrics) metric_val = metrics.get(logger_args.metric_name, None) assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None saver.save(logger.global_step, logger.epoch, model, optimizer, lr_scheduler, args.device, metric_val=metric_val) lr_step = util.step_scheduler( lr_scheduler, metrics, lr_step, best_ckpt_metric=logger_args.metric_name) # Input: [batch_size, channels, width, height] with torch.set_grad_enabled(True): logits = model.forward(inputs.to(args.device)) unweighted_loss = uw_loss_fn(logits, targets.to(args.device)) weighted_loss = w_loss_fn(logits, targets.to( args.device)) if w_loss_fn else None logger.log_iter(inputs, logits, targets, unweighted_loss, weighted_loss, optimizer) optimizer.zero_grad() if args.loss_fn == 'weighted_loss': weighted_loss.backward() else: unweighted_loss.backward() optimizer.step() logger.end_iter() logger.end_epoch(metrics, optimizer)
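# The cross-entropy and weighted losses come from the project's `get_loss_fn`;
# below is a minimal standalone sketch of one masking strategy for missing
# tasks (the convention that a missing target is stored as -1 is an assumption):
import torch
import torch.nn.functional as F

def masked_bce_with_logits(logits, targets, missing_value=-1.0):
    """Average BCE over labeled tasks only, ignoring missing targets."""
    mask = (targets != missing_value).float()
    per_task = F.binary_cross_entropy_with_logits(
        logits, targets.clamp(min=0.0), reduction='none')
    return (per_task * mask).sum() / mask.sum().clamp(min=1.0)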
def train(args):
    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(pretrained=args.pretrained)
        if args.pretrained:
            # Replace the final layer to match the number of target classes.
            model.fc = nn.Linear(model.fc.in_features, args.num_classes)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    parameters = optim.get_parameters(model.module, args)
    optimizer = optim.get_optimizer(parameters, args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = WhiteboardLoader(args.data_dir, 'train', args.batch_size,
                                    shuffle=True, do_augment=True,
                                    num_workers=args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [WhiteboardLoader(args.data_dir, 'val', args.batch_size,
                                     shuffle=False, do_augment=False,
                                     num_workers=args.num_workers)]
    evaluator = ModelEvaluator(eval_loaders, logger, args.epochs_per_eval,
                               args.max_eval, args.num_visuals)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()
        for inputs, targets, paths in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                logits = model(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(inputs, logits, targets, paths, loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                optim.step_scheduler(lr_scheduler, global_step=logger.global_step)
            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch, model, args.model, optimizer, lr_scheduler,
                   args.device, metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
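# WhiteboardLoader is project-specific (it also yields file paths per batch,
# omitted here); a rough sketch of a loader with the same knobs, built on
# torchvision's ImageFolder. The directory layout and transform choices are
# assumptions.
import os
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as T

def get_whiteboard_loader(data_dir, phase, batch_size, shuffle, do_augment, num_workers):
    if do_augment:
        transforms = [T.RandomResizedCrop(224), T.RandomHorizontalFlip()]
    else:
        transforms = [T.Resize((224, 224))]
    transforms.append(T.ToTensor())
    dataset = datasets.ImageFolder(os.path.join(data_dir, phase),
                                   T.Compose(transforms))
    return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                           num_workers=num_workers)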
def train(args): """Run training loop with the given args. The function consists of the following steps: 1. Load model: gets the model from a checkpoint or from models/models.py. 2. Load optimizer and learning rate scheduler. 3. Get data loaders and class weights. 4. Get loss functions: cross entropy loss and weighted loss functions. 5. Get logger, evaluator, and saver. 6. Run training loop, evaluate and save model periodically. """ model_args = args.model_args logger_args = args.logger_args optim_args = args.optim_args data_args = args.data_args transform_args = args.transform_args task_sequence = TASK_SEQUENCES[data_args.task_sequence] print('gpus: ', args.gpu_ids) # Get model if model_args.ckpt_path: model_args.pretrained = False model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path, args.gpu_ids, model_args, data_args) if not logger_args.restart_epoch_count: args.start_epoch = ckpt_info['epoch'] + 1 else: model_fn = models.__dict__[model_args.model] model = model_fn(task_sequence, model_args) num_covars = len(model_args.covar_list.split(';')) model.transform_model_shape(len(task_sequence), num_covars) if model_args.hierarchy: model = models.HierarchyWrapper(model, task_sequence) model = nn.DataParallel(model, args.gpu_ids) model = model.to(args.device) model.train() # Get optimizer and scheduler optimizer = util.get_optimizer(model.parameters(), optim_args) lr_scheduler = util.get_scheduler(optimizer, optim_args) # The optimizer is loaded from the ckpt if one exists and the new model # architecture is the same as the old one (classifier is not transformed). if model_args.ckpt_path and not model_args.transform_classifier: ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids, optimizer, lr_scheduler) # Get loaders and class weights train_csv_name = 'train' if data_args.uncertain_map_path is not None: train_csv_name = data_args.uncertain_map_path # Put all CXR training fractions into one dictionary and pass it to the loader cxr_frac = {'pocus': data_args.pocus_train_frac, 'hocus': data_args.hocus_train_frac, 'pulm': data_args.pulm_train_frac} train_loader = get_loader(data_args, transform_args, train_csv_name, task_sequence, data_args.su_train_frac, data_args.nih_train_frac, cxr_frac, data_args.tcga_train_frac, args.batch_size, frontal_lateral=model_args.frontal_lateral, is_training=True, shuffle=True, covar_list=model_args.covar_list, fold_num=data_args.fold_num) eval_loaders = get_eval_loaders(data_args, transform_args, task_sequence, args.batch_size, frontal_lateral=model_args.frontal_lateral, covar_list=model_args.covar_list, fold_num=data_args.fold_num) class_weights = train_loader.dataset.class_weights # Get loss functions uw_loss_fn = get_loss_fn(args.loss_fn, args.device, model_args.model_uncertainty, args.has_tasks_missing, class_weights=class_weights) w_loss_fn = get_loss_fn('weighted_loss', args.device, model_args.model_uncertainty, args.has_tasks_missing, class_weights=class_weights) # Get logger, evaluator and saver logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs, args.batch_size, len(train_loader.dataset), args.device, normalization=transform_args.normalization) eval_args = {} eval_args['num_visuals'] = logger_args.num_visuals eval_args['iters_per_eval'] = logger_args.iters_per_eval eval_args['has_missing_tasks'] = args.has_tasks_missing eval_args['model_uncertainty'] = model_args.model_uncertainty eval_args['class_weights'] = class_weights eval_args['max_eval'] = logger_args.max_eval eval_args['device'] = args.device 
eval_args['optimizer'] = optimizer evaluator = get_evaluator('classification', eval_loaders, logger, eval_args) print("Eval Loaders: %d" % len(eval_loaders)) saver = ModelSaver(**vars(logger_args)) metrics = None lr_step = 0 # Train model while not logger.is_finished_training(): logger.start_epoch() for inputs, targets, info_dict, covars in train_loader: logger.start_iter() # Evaluate and save periodically metrics, curves = evaluator.evaluate(model, args.device, logger.global_step) logger.plot_metrics(metrics) metric_val = metrics.get(logger_args.metric_name, None) assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None saver.save(logger.global_step, logger.epoch, model, optimizer, lr_scheduler, args.device, metric_val=metric_val, covar_list=model_args.covar_list) lr_step = util.step_scheduler(lr_scheduler, metrics, lr_step, best_ckpt_metric=logger_args.metric_name) # Input: [batch_size, channels, width, height] with torch.set_grad_enabled(True): # with torch.autograd.set_detect_anomaly(True): logits = model.forward([inputs.to(args.device), covars]) # Scale up TB so that it's loss is counted for more if upweight_tb is True. if model_args.upweight_tb is True: tb_targets = targets.narrow(1, 0, 1) findings_targets = targets.narrow(1, 1, targets.shape[1] - 1) tb_targets = tb_targets.repeat(1, targets.shape[1] - 1) new_targets = torch.cat((tb_targets, findings_targets), 1) tb_logits = logits.narrow(1, 0, 1) findings_logits = logits.narrow(1, 1, logits.shape[1] - 1) tb_logits = tb_logits.repeat(1, logits.shape[1] - 1) new_logits = torch.cat((tb_logits, findings_logits), 1) else: new_logits = logits new_targets = targets unweighted_loss = uw_loss_fn(new_logits, new_targets.to(args.device)) weighted_loss = w_loss_fn(logits, targets.to(args.device)) if w_loss_fn else None logger.log_iter(inputs, logits, targets, unweighted_loss, weighted_loss, optimizer) optimizer.zero_grad() if args.loss_fn == 'weighted_loss': weighted_loss.backward() else: unweighted_loss.backward() optimizer.step() logger.end_iter() logger.end_epoch(metrics, optimizer)
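# A standalone shape check of the TB upweighting above: the TB column is
# repeated once per remaining finding before concatenation, so it dominates
# the per-task average in the loss. (Toy tensors for illustration only.)
import torch

targets = torch.tensor([[1., 0., 1., 0.]])             # [batch=1, tasks=4]; task 0 is TB
tb = targets.narrow(1, 0, 1)                           # [1, 1]
findings = targets.narrow(1, 1, targets.shape[1] - 1)  # [1, 3]
tb = tb.repeat(1, targets.shape[1] - 1)                # [1, 3]: TB repeated 3 times
new_targets = torch.cat((tb, findings), 1)             # [1, 6]: TB is 3 of 6 columns
assert new_targets.shape == (1, 6)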