def train(args):
    """Train a classifier, evaluating and checkpointing after every epoch.

    Args:
        args: Command-line arguments (model, optimizer, data, and logging
            settings; see the argument parser for details).
    """
    # Load from a checkpoint unless we start from pretrained weights instead.
    if args.ckpt_path and not args.use_pretrained:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        if args.use_pretrained:
            model.load_pretrained(args.ckpt_path, args.gpu_ids)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler. Fine-tuning restricts optimization to a
    # parameter subset with its own learning rate past the boundary layer.
    if args.use_pretrained or args.fine_tune:
        parameters = model.module.fine_tuning_parameters(
            args.fine_tuning_boundary, args.fine_tuning_lr)
    else:
        parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    lr_scheduler = util.get_scheduler(optimizer, args)
    if args.ckpt_path and not args.use_pretrained and not args.fine_tune:
        # Resuming the same run: restore optimizer/scheduler state as well.
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver.
    cls_loss_fn = util.get_loss_fn(is_classification=True,
                                   dataset=args.dataset,
                                   size_average=False)
    data_loader_fn = data_loader.__dict__[args.data_loader]
    train_loader = data_loader_fn(args, phase='train', is_training=True)
    logger = TrainLogger(args, len(train_loader.dataset),
                         train_loader.dataset.pixel_dict)
    eval_loaders = [data_loader_fn(args, phase='val', is_training=False)]
    evaluator = ModelEvaluator(args.do_classify, args.dataset, eval_loaders,
                               logger, args.agg_method, args.num_visuals,
                               args.max_eval, args.epochs_per_eval)
    saver = ModelSaver(args.save_dir, args.epochs_per_save, args.max_ckpts,
                       args.best_ckpt_metric, args.maximize_metric)

    # Train model.
    while not logger.is_finished_training():
        logger.start_epoch()
        for inputs, target_dict in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                # BUG FIX: Tensor.to() is not in-place; the result must be
                # re-assigned, otherwise the forward pass sees CPU inputs.
                inputs = inputs.to(args.device)
                cls_logits = model.forward(inputs)
                cls_targets = target_dict['is_abnormal']
                cls_loss = cls_loss_fn(cls_logits, cls_targets.to(args.device))
                # Loss is elementwise (size_average=False): reduce once, reuse.
                loss = cls_loss.mean()
                logger.log_iter(inputs, cls_logits, target_dict, loss,
                                optimizer)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logger.end_iter()
            util.step_scheduler(lr_scheduler, global_step=logger.global_step)

        metrics, curves = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.best_ckpt_metric, None))
        logger.end_epoch(metrics, curves)
        util.step_scheduler(lr_scheduler, metrics, epoch=logger.epoch,
                            best_ckpt_metric=args.best_ckpt_metric)
def train(args):
    """Deep-image-prior-style training of a generator against one target image.

    Per iteration:
      1. With backprop, fit the model output to the target on the masked
         region (masked loss).
      2. In eval mode without backprop, compute the masked, full, and
         obscured-region losses for monitoring.
      3. With backprop on only the latent input, run one z-test optimization
         step and log its loss.

    Args:
        args: Command-line arguments (model, loss, optimizer, logging).
    """
    # Get loader for outer loop training.
    loader = get_loader(args)
    target_image_shape = loader.dataset.target_image_shape
    setattr(args, 'target_image_shape', target_image_shape)

    # Load model.
    model_fn = models.__dict__[args.model]
    model = model_fn(**vars(args))
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Print model parameters.
    print('Model parameters: name, size, mean, std')
    for name, param in model.named_parameters():
        print(name, param.size(), torch.mean(param), torch.std(param))

    # Get optimizer and losses (same loss type for image and z objectives).
    parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    loss_fn = util.get_loss_fn(args.loss_fn, args)
    z_loss_fn = util.get_loss_fn(args.loss_fn, args)

    # Get logger, saver.
    logger = TrainLogger(args)
    saver = ModelSaver(args)
    print(f'Logs: {logger.log_dir}')
    print(f'Ckpts: {args.save_dir}')

    # Train model.
    logger.log_hparams(args)
    while not logger.is_finished_training():
        logger.start_epoch()
        for input_noise, target_image, mask, z_test_target, z_test in loader:
            logger.start_iter()
            if torch.cuda.is_available():
                input_noise = input_noise.to(args.device)
                target_image = target_image.cuda()
                mask = mask.cuda()
                z_test = z_test.cuda()
                z_test_target = z_test_target.cuda()

            masked_target_image = target_image * mask
            obscured_target_image = target_image * (1.0 - mask)

            # (1) Masked loss with backprop: input is noise, target is image.
            model.train()
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    logits = model.forward(input_noise).float()
                    # torch.sigmoid replaces the deprecated F.sigmoid.
                    probs = torch.sigmoid(logits)
                    # Debug logits and diffs.
                    logger.debug_visualize(
                        [logits, logits * mask, logits * (1.0 - mask)],
                        unique_suffix='logits-train')
                else:
                    probs = model.forward(input_noise).float()

                # Loss is computed elementwise without reduction (easier to
                # debug), so take the mean here. The previous dead
                # `torch.zeros(1, requires_grad=True)` pre-assignments were
                # removed: they were overwritten immediately.
                masked_probs = probs * mask
                masked_loss = loss_fn(masked_probs, masked_target_image).mean()
                masked_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Without backprop: (2) full loss on the entire image and
            # (3) obscured loss on the region hidden by the mask.
            model.eval()
            with torch.no_grad():
                if args.use_intermediate_logits:
                    logits_eval = model.forward(input_noise).float()
                    probs_eval = torch.sigmoid(logits_eval)
                    # Debug logits and diffs.
                    logger.debug_visualize(
                        [logits_eval, logits_eval * mask,
                         logits_eval * (1.0 - mask)],
                        unique_suffix='logits-eval')
                else:
                    probs_eval = model.forward(input_noise).float()

                masked_probs_eval = probs_eval * mask
                masked_loss_eval = loss_fn(masked_probs_eval,
                                           masked_target_image).mean()
                full_loss_eval = loss_fn(probs_eval, target_image).mean()
                obscured_probs_eval = probs_eval * (1.0 - mask)
                obscured_loss_eval = loss_fn(obscured_probs_eval,
                                             obscured_target_image).mean()

            # (4) With backprop on only the input z, run one z-test step.
            z_optimizer = util.get_optimizer([z_test.requires_grad_()], args)
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    z_logits = model.forward(z_test).float()
                    z_probs = torch.sigmoid(z_logits)
                else:
                    z_probs = model.forward(z_test).float()
                z_loss = z_loss_fn(z_probs, z_test_target).mean()
                z_loss.backward()
                z_optimizer.step()
                z_optimizer.zero_grad()

            if z_loss.item() < args.max_z_test_loss:
                # TODO: include this part into the metrics/saver stuff below.
                # Save MSE on obscured region.
                final_metrics = {'final/score': obscured_loss_eval.item()}
                logger._log_scalars(final_metrics)
                print('z loss', z_loss)
                print('Final MSE value', obscured_loss_eval)

            # TODO: Make a function for metrics - or at least make sure the
            # dict includes all possible best ckpt metrics.
            metrics = {'masked_loss': masked_loss.item()}
            saver.save(logger.global_step, model, optimizer, args.device,
                       metric_val=metrics.get(args.best_ckpt_metric, None))

            # Log both train and eval model settings, and visualize outputs.
            logger.log_status(
                inputs=input_noise,
                targets=target_image,
                probs=probs,
                masked_probs=masked_probs,
                masked_loss=masked_loss,
                probs_eval=probs_eval,
                masked_probs_eval=masked_probs_eval,
                obscured_probs_eval=obscured_probs_eval,
                masked_loss_eval=masked_loss_eval,
                obscured_loss_eval=obscured_loss_eval,
                full_loss_eval=full_loss_eval,
                z_target=z_test_target,
                z_probs=z_probs,
                z_loss=z_loss,
                save_preds=args.save_preds,
            )
            logger.end_iter()
        logger.end_epoch()

    # Last log after everything completes.
    logger.log_status(
        inputs=input_noise,
        targets=target_image,
        probs=probs,
        masked_probs=masked_probs,
        masked_loss=masked_loss,
        probs_eval=probs_eval,
        masked_probs_eval=masked_probs_eval,
        obscured_probs_eval=obscured_probs_eval,
        masked_loss_eval=masked_loss_eval,
        obscured_loss_eval=obscured_loss_eval,
        full_loss_eval=full_loss_eval,
        z_target=z_test_target,
        z_probs=z_probs,
        z_loss=z_loss,
        save_preds=args.save_preds,
        force_visualize=True,
    )
def train(args):
    """Build the model and run training with periodic mid-epoch evaluation.

    A checkpoint (if given) is restored before training; evaluation and
    checkpoint saving happen roughly once per `iters_per_eval` global steps.
    """
    # Model setup: construct, optionally restore, move to device.
    net = models.__dict__[args.model](args)
    if args.ckpt_path:
        net = ModelSaver.load_model(net, args.ckpt_path, args.gpu_ids,
                                    is_training=True)
    net = net.to(args.device)
    net.train()

    # Data, logging, and checkpointing helpers.
    train_loader, val_loader = get_data_loaders(args)
    logger = TrainLogger(args, net, dataset_len=len(train_loader.dataset))
    saver = ModelSaver(args.save_dir,
                       args.max_ckpts,
                       metric_name=args.metric_name,
                       maximize_metric=args.maximize_metric,
                       keep_topk=True)

    while not logger.is_finished_training():
        logger.start_epoch()
        for batch in train_loader:
            logger.start_iter()

            # One optimization step over this batch (model owns the step).
            net.set_inputs(batch['src'], batch['tgt'])
            net.train_iter()
            logger.end_iter()

            # Evaluate once per `iters_per_eval` window of global steps.
            if logger.global_step % args.iters_per_eval < args.batch_size:
                criteria = {'MSE_src2tgt': mse, 'MSE_tgt2src': mse}
                stats = evaluate(net, val_loader, criteria)
                logger.log_scalars({'val_' + k: v for k, v in stats.items()})
                saver.save(logger.global_step, net, stats[args.metric_name],
                           args.device)
        logger.end_epoch()
def train(args): """Train model. Args: args: Command line arguments. model: Classifier model to train. """ # Set up model model = models.__dict__[args.model](**vars(args)) model = nn.DataParallel(model, args.gpu_ids) model = model.to(args.device) # Set up data loader train_loader, test_loader, classes = get_cifar_loaders( args.batch_size, args.num_workers) # Set up optimizer optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.sgd_momentum, weight_decay=args.weight_decay) scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, args.lr_decay_gamma) loss_fn = nn.CrossEntropyLoss().to(args.device) # Set up checkpoint saver saver = ModelSaver(model, optimizer, scheduler, args.save_dir, {'model': args.model}, max_to_keep=args.max_ckpts, device=args.device) # Train logger = TrainLogger(args, len(train_loader.dataset)) while not logger.is_finished_training(): logger.start_epoch() # Train for one epoch model.train() for inputs, labels in train_loader: logger.start_iter() with torch.set_grad_enabled(True): # Forward outputs = model.forward(inputs.to(args.device)) loss = loss_fn(outputs, labels.to(args.device)) loss_item = loss.item() # Backward optimizer.zero_grad() loss.backward() optimizer.step() logger.end_iter({'loss': loss_item}) # Evaluate on validation set val_loss = evaluate(model, test_loader, loss_fn, device=args.device) logger.write('[epoch {}]: val_loss: {:.3g}'.format( logger.epoch, val_loss)) logger.write_summaries({'loss': val_loss}, phase='val') if logger.epoch in args.save_epochs: saver.save(logger.epoch, val_loss) logger.end_epoch() scheduler.step()
def main():
    """Train a GAN (RMSprop generator/discriminator), optionally with
    differential privacy on the discriminator via a PrivacyEngine."""
    parser = ArgParser()
    args = parser.parse_args()

    gen = Generator(args.latent_dim).to(args.device)
    disc = Discriminator().to(args.device)
    if args.device != 'cpu':
        gen = nn.DataParallel(gen, args.gpu_ids)
        disc = nn.DataParallel(disc, args.gpu_ids)

    gen_opt = torch.optim.RMSprop(gen.parameters(), lr=args.lr)
    disc_opt = torch.optim.RMSprop(disc.parameters(), lr=args.lr)
    gen_scheduler = torch.optim.lr_scheduler.LambdaLR(
        gen_opt, lr_lambda=lr_lambda(args.num_epochs))
    disc_scheduler = torch.optim.lr_scheduler.LambdaLR(
        disc_opt, lr_lambda=lr_lambda(args.num_epochs))

    disc_loss_fn = DiscriminatorLoss().to(args.device)
    gen_loss_fn = GeneratorLoss().to(args.device)

    dataset = MNISTDataset()
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        num_workers=args.num_workers)
    logger = TrainLogger(args, len(loader), phase=None)
    logger.log_hparams(args)

    if args.privacy_noise_multiplier != 0:
        privacy_engine = PrivacyEngine(
            disc,
            batch_size=args.batch_size,
            sample_size=len(dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            # BUG FIX: use the configured multiplier instead of a hard-coded
            # 0.8, so the args.privacy_noise_multiplier flag takes effect.
            noise_multiplier=args.privacy_noise_multiplier,
            max_grad_norm=0.02,
            batch_first=True,
        )
        privacy_engine.attach(disc_opt)
        privacy_engine.to(args.device)

    for epoch in range(args.num_epochs):
        logger.start_epoch()
        for cur_step, img in enumerate(tqdm(loader, dynamic_ncols=True)):
            logger.start_iter()
            img = img.to(args.device)

            # Train the discriminator several steps per generator step.
            fake, disc_loss = None, None
            for _ in range(args.step_train_discriminator):
                disc_opt.zero_grad()
                fake_noise = get_noise(args.batch_size, args.latent_dim,
                                       device=args.device)
                fake = gen(fake_noise)
                disc_loss = disc_loss_fn(img, fake, disc)
                disc_loss.backward()
                disc_opt.step()

            # One generator step on a fresh noise batch.
            gen_opt.zero_grad()
            fake_noise_2 = get_noise(args.batch_size, args.latent_dim,
                                     device=args.device)
            fake_2 = gen(fake_noise_2)
            gen_loss = gen_loss_fn(img, fake_2, disc)
            gen_loss.backward()
            gen_opt.step()

            # Privacy accounting (epsilon) only when DP is enabled; the
            # conditional expression below never reads epsilon otherwise.
            if args.privacy_noise_multiplier != 0:
                epsilon, best_alpha = privacy_engine.get_privacy_spent(
                    args.privacy_delta)
            logger.log_iter_gan_from_latent_vector(
                img, fake, gen_loss, disc_loss,
                epsilon if args.privacy_noise_multiplier != 0 else 0)
            logger.end_iter()
        logger.end_epoch()
        gen_scheduler.step()
        disc_scheduler.step()
def main():
    """Train a semi-supervised GAN (feature matching) on a labeled/unlabeled
    split, delegating the training loop to run.NNRun."""
    args = easydict.EasyDict({
        "dataroot": "/mnt/gold/users/s18150/mywork/pytorch/data",
        "save_dir": "./",
        "prefix": "test",
        "workers": 8,
        "batch_size": 128,
        "image_size": 32,
        "nc": 1,       # number of image channels
        "nz": 100,     # size of the latent vector
        "ngf": 32,     # generator feature-map size
        "ndf": 32,     # discriminator feature-map size
        "epochs": 1,
        "lr": 0.0002,
        "beta1": 0.5,  # Adam beta1
        "gpu": 7,
        "use_cuda": True,
        "feature_matching": True,
        "mini_batch": True,
        "iters": 50000,
        "label_batch_size": 100,
        "unlabel_batch_size": 100,
        "test_batch_size": 10,
        "out_dir": './result',
        "log_interval": 500,
        "label_num": 20
    })

    # Fix all random seeds for reproducibility.
    manualSeed = 999
    np.random.seed(manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    device = torch.device(
        'cuda:{}'.format(args.gpu) if args.use_cuda else 'cpu')

    # Labeled / unlabeled / test iterators for semi-supervised training.
    data_iterators = dataset.get_iters(root_path=args.dataroot,
                                       l_batch_size=args.label_batch_size,
                                       ul_batch_size=args.unlabel_batch_size,
                                       test_batch_size=args.test_batch_size,
                                       workers=args.workers,
                                       n_labeled=args.label_num)
    trainloader_label = data_iterators['labeled']
    trainloader_unlabel = data_iterators['unlabeled']
    testloader = data_iterators['test']

    # Create the generator and initialize its weights.
    netG = net.Generator(args.nz, args.ngf, args.nc).to(device)
    netG.apply(net.weights_init)
    # Create the discriminator and initialize its weights.
    netD = net.Discriminator(args.nc, args.ndf, device, args.batch_size,
                             args.mini_batch).to(device)
    netD.apply(net.weights_init)

    # Discriminator loss: cross-entropy over class logits.
    criterionD = nn.CrossEntropyLoss()
    if args.feature_matching is True:
        # BUG FIX: reduction='elementwise_mean' was deprecated in PyTorch
        # 0.4.1 and later removed; 'mean' is the equivalent reduction.
        criterionG = nn.MSELoss(reduction='mean')
    else:
        criterionG = nn.BCELoss()

    # Fixed noise batch (64 samples) used only to visualize generator output.
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # Optimizers.
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr,
                            betas=(args.beta1, 0.999))

    logger = TrainLogger(args)
    r = run.NNRun(netD, netG, optimizerD, optimizerG, criterionD, criterionG,
                  device, fixed_noise, logger, args)

    # Train.
    r.train(trainloader_label, trainloader_unlabel, testloader)
def train(args):
    """Train a CIFAR model, resuming from a checkpoint when one is given."""
    if args.ckpt_path:
        # Resume: restore weights and continue from the saved epoch.
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model = models.__dict__[args.model](**vars(args))
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Optimize only parameters that require gradients.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.get_optimizer(trainable, args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Loss, data, logging, evaluation, and checkpointing.
    criterion = nn.CrossEntropyLoss()
    train_loader = CIFARLoader('train', args.batch_size, args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [CIFARLoader('val', args.batch_size, args.num_workers)]
    evaluator = ModelEvaluator(eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(**vars(args))

    while not logger.is_finished_training():
        logger.start_epoch()
        for batch_inputs, batch_targets in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                scores = model.forward(batch_inputs.to(args.device))
                batch_loss = criterion(scores, batch_targets.to(args.device))
                logger.log_iter(batch_loss)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
def main(args):
    """Train a reconstruction model with periodic sampling and validation.

    Supports resuming mid-epoch from `current.pth.tar`; every 250 training
    batches it draws samples and runs a full validation pass.

    Args:
        args: Command-line arguments (model, data, optimizer, logging).
    """
    write_args(args)

    # Set up main device.
    device = 'cuda' if torch.cuda.is_available() and args.gpu_ids else 'cpu'
    print(device)

    # Set random seeds for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Build Model.
    print("Building model...")
    model_fn = models.__dict__[args.model]
    model = model_fn(args, device)
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(device)

    # Loss fn.
    loss_fn = get_loss(args.model)

    # Data loaders.
    train_loader = get_dataloader(args, "train")
    val_loader = get_dataloader(args, "val")

    # Optimizer.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Logger and resume state.
    if args.resume:
        resume_path = os.path.join(args.save_dir, "current.pth.tar")
        print("Resuming from checkpoint at {}".format(resume_path))
        checkpoint = torch.load(resume_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        start_iter = checkpoint['iter']
        global_step = start_epoch * len(train_loader)
        logger = TrainLogger(args, start_epoch, global_step)
        logger.best_val_loss = checkpoint['val_loss']
        print(start_epoch)
        print(start_iter)
    else:
        start_epoch = 0
        global_step = 0
        start_iter = 0
        logger = TrainLogger(args, start_epoch, global_step)

    # Sampler.
    sampler = get_sampler(args.model, 0, 16, args.size, args.input_c,
                          args.save_dir, device)

    for i in range(start_epoch, args.num_epochs):
        # Train.
        model.train()
        logger.start_epoch()
        for j, image in enumerate(train_loader):
            # Skip batches already consumed before a mid-epoch resume.
            if j < start_iter:
                logger.end_iter()
                continue

            # Sample and evaluate every 250 batches (including batch 0).
            if j % 250 == 0:
                print("Sampling...")
                sampler.sample(model, i, j)
                with torch.no_grad():
                    logger.val_loss_meter.reset()
                    model.eval()
                    # BUG FIX: use distinct names for the validation batch.
                    # Previously `for image in tqdm(val_loader)` shadowed the
                    # current training batch, so the training step below ran
                    # on the last validation image instead.
                    for val_image in tqdm(val_loader):
                        val_image = val_image.to(device)
                        val_output = model(val_image)
                        val_loss = loss_fn(val_output, val_image)
                        logger.val_loss_meter.update(val_loss)
                    logger.has_improved(model, optimizer, j)
                    logger._log_scalars({'val-loss': logger.val_loss_meter.avg})
                model.train()

            logger.start_iter()
            image = image.to(device)
            optimizer.zero_grad()
            output = model(image)
            loss = loss_fn(output, image)
            loss.backward()
            # Per-parameter-group gradient clipping with L2 norm.
            for group in optimizer.param_groups:
                utils.clip_grad_norm_(group['params'], args.max_grad_norm, 2)
            optimizer.step()
            logger.log_iter(loss)
            logger.end_iter()
        logger.end_epoch(None, optimizer)
        # Only the first (resumed) epoch skips leading iterations.
        start_iter = 0
def train(args):
    """Train a regression model, evaluating and checkpointing every epoch."""
    train_loader = get_loader(args=args)

    # Restore from checkpoint if available; otherwise build a fresh model
    # whose input dimension comes from the training data.
    if args.ckpt_path:
        net, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        args.D_in = train_loader.D_in
        net = models.__dict__[args.model](**vars(args))
    net = net.to(args.device)
    net.train()

    # Optimizer and scheduler over trainable parameters only.
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, net.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Loss, logging, evaluation (on both train and valid splits), saving.
    loss_fn = optim.get_loss_fn(args.loss_fn, args)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [
        get_loader(args, phase='train', is_training=False),
        get_loader(args, phase='valid', is_training=False),
    ]
    evaluator = ModelEvaluator(args, eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(**vars(args))

    while not logger.is_finished_training():
        logger.start_epoch()
        for src, tgt in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                pred_params = net.forward(src.to(args.device))
                # NOTE(review): column 1 of src is treated as the age
                # covariate here — assumed from the indexing; confirm upstream.
                ages = src[:, 1]
                loss = loss_fn(pred_params, tgt.to(args.device),
                               ages.to(args.device), args.use_intvl)
                logger.log_iter(src, pred_params, tgt, loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logger.end_iter()

        metrics = evaluator.evaluate(net, args.device, logger.epoch)
        saver.save(logger.epoch, net, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics=metrics)
def train(args): """Run training loop with the given args. The function consists of the following steps: 1. Load model: gets the model from a checkpoint or from models/models.py. 2. Load optimizer and learning rate scheduler. 3. Get data loaders and class weights. 4. Get loss functions: cross entropy loss and weighted loss functions. 5. Get logger, evaluator, and saver. 6. Run training loop, evaluate and save model periodically. """ model_args = args.model_args logger_args = args.logger_args optim_args = args.optim_args data_args = args.data_args transform_args = args.transform_args task_sequence = TASK_SEQUENCES[data_args.task_sequence] # Get model if model_args.ckpt_path: model_args.pretrained = False model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path, args.gpu_ids, model_args, data_args) args.start_epoch = ckpt_info['epoch'] + 1 else: model_fn = models.__dict__[model_args.model] model = model_fn(task_sequence, model_args) if model_args.hierarchy: model = models.HierarchyWrapper(model, task_sequence) model = nn.DataParallel(model, args.gpu_ids) model = model.to(args.device) model.train() # Get optimizer and scheduler optimizer = util.get_optimizer(model.parameters(), optim_args) lr_scheduler = util.get_scheduler(optimizer, optim_args) if model_args.ckpt_path: ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids, optimizer, lr_scheduler) # Get loaders and class weights train_csv_name = 'train' if data_args.uncertain_map_path is not None: train_csv_name = data_args.uncertain_map_path #TODO: Remove this when we decide which transformation to use in the end #transforms_imgaug = ImgAugTransform() train_loader = get_loader(data_args, transform_args, train_csv_name, task_sequence, data_args.su_train_frac, data_args.nih_train_frac, data_args.pocus_train_frac, data_args.tcga_train_frac, 0, 0, args.batch_size, frontal_lateral=model_args.frontal_lateral, is_training=True, shuffle=True, transform=model_args.transform, normalize=model_args.normalize) 
eval_loaders = get_eval_loaders(data_args, transform_args, task_sequence, args.batch_size, frontal_lateral=model_args.frontal_lateral, normalize=model_args.normalize) class_weights = train_loader.dataset.class_weights print(" class weights:") print(class_weights) # Get loss functions uw_loss_fn = get_loss_fn('cross_entropy', args.device, model_args.model_uncertainty, args.has_tasks_missing, class_weights=class_weights) w_loss_fn = get_loss_fn('weighted_loss', args.device, model_args.model_uncertainty, args.has_tasks_missing, mask_uncertain=False, class_weights=class_weights) # Get logger, evaluator and saver logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs, args.batch_size, len(train_loader.dataset), args.device) eval_args = {} eval_args['num_visuals'] = logger_args.num_visuals eval_args['iters_per_eval'] = logger_args.iters_per_eval eval_args['has_missing_tasks'] = args.has_tasks_missing eval_args['model_uncertainty'] = model_args.model_uncertainty eval_args['class_weights'] = class_weights eval_args['max_eval'] = logger_args.max_eval eval_args['device'] = args.device eval_args['optimizer'] = args.optimizer evaluator = get_evaluator('classification', eval_loaders, logger, eval_args) print("Eval Loaders: %d" % len(eval_loaders)) saver = ModelSaver(**vars(logger_args)) metrics = None lr_step = 0 # Train model while not logger.is_finished_training(): logger.start_epoch() for inputs, targets, info_dict in train_loader: logger.start_iter() # Evaluate and save periodically metrics, curves = evaluator.evaluate(model, args.device, logger.global_step) logger.plot_metrics(metrics) metric_val = metrics.get(logger_args.metric_name, None) assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None saver.save(logger.global_step, logger.epoch, model, optimizer, lr_scheduler, args.device, metric_val=metric_val) lr_step = util.step_scheduler( lr_scheduler, metrics, lr_step, best_ckpt_metric=logger_args.metric_name) # Input: [batch_size, 
channels, width, height] with torch.set_grad_enabled(True): logits = model.forward(inputs.to(args.device)) unweighted_loss = uw_loss_fn(logits, targets.to(args.device)) weighted_loss = w_loss_fn(logits, targets.to( args.device)) if w_loss_fn else None logger.log_iter(inputs, logits, targets, unweighted_loss, weighted_loss, optimizer) optimizer.zero_grad() if args.loss_fn == 'weighted_loss': weighted_loss.backward() else: unweighted_loss.backward() optimizer.step() logger.end_iter() logger.end_epoch(metrics, optimizer)
def train(args):
    """Alternately train the encoder and decoder of a channel autoencoder.

    Runs until the input generator is exhausted (StopIteration) or the
    logger reports no improvement for seven checks.
    """
    write_args(args)
    model_args = args.model_args
    data_args = args.data_args
    logger_args = args.logger_args
    print(f"Training {logger_args.name}")

    # Build the autoencoder with its power constraint and channel model.
    power_constraint = PowerConstraint()
    possible_inputs = get_md_set(model_args.md_len)
    channel = get_channel(data_args.channel, model_args.modelfree, data_args)
    model = AutoEncoder(model_args, data_args, power_constraint, channel,
                        possible_inputs)

    # One scheduler per trainable half of the model.
    enc_scheduler = get_scheduler(model_args.scheduler, model_args.decay,
                                  model_args.patience)
    dec_scheduler = get_scheduler(model_args.scheduler, model_args.decay,
                                  model_args.patience)
    enc_scheduler.set_model(model.trainable_encoder)
    dec_scheduler.set_model(model.trainable_decoder)

    # Finite message stream; its generator raises StopIteration once
    # batch_size * batches_per_epoch * num_epochs examples are consumed.
    dataset_size = (data_args.batch_size * data_args.batches_per_epoch
                    * data_args.num_epochs)
    loader = InputDataloader(data_args.batch_size, data_args.block_length,
                             dataset_size)
    loader = loader.example_generator()

    logger = TrainLogger(logger_args.save_dir, logger_args.name,
                         data_args.num_epochs, logger_args.iters_per_print)
    saver = ModelSaver(logger_args.save_dir, logger)

    enc_scheduler.on_train_begin()
    dec_scheduler.on_train_begin()

    def run_iter(train_step):
        # One logged training iteration on the next message batch.
        logger.start_iter()
        message = next(loader)
        step_metrics = train_step(message)
        logger.log_iter(step_metrics)
        logger.end_iter()
        return step_metrics

    while True:  # Loop until StopIteration
        try:
            metrics = None
            logger.start_epoch()
            # Each cycle trains the encoder once, then the decoder
            # train_ratio times.
            cycles = data_args.batches_per_epoch // (model_args.train_ratio + 1)
            for _ in range(cycles):
                metrics = run_iter(model.train_encoder)
                for _ in range(model_args.train_ratio):
                    metrics = run_iter(model.train_decoder)
            logger.end_epoch(None)

            if model_args.modelfree:
                model.Pi.std *= model_args.sigma_decay
            enc_scheduler.on_epoch_end(logger.epoch, logs=metrics)
            dec_scheduler.on_epoch_end(logger.epoch, logs=metrics)
            if logger.has_improved():
                saver.save(model)
            if logger.notImprovedCounter >= 7:
                break
        except StopIteration:
            break
def train_classifier(args, model):
    """Train a classifier and save its first-layer weights.

    Args:
        args: Command line arguments.
        model: Classifier model to train.
    """
    # Set up data loader
    train_loader, test_loader, classes = get_data_loaders(
        args.dataset, args.batch_size, args.num_workers)

    # Set up model
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    fd = None
    if args.use_fd:
        # Optional filter discriminator scoring first-layer filters.
        fd = models.filter_discriminator()
        fd = nn.DataParallel(fd, args.gpu_ids)
        fd = fd.to(args.device)

    # Set up optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.sgd_momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step,
                                          args.lr_decay_gamma)
    if args.model == 'fd':
        post_process = nn.Sigmoid()
        loss_fn = nn.MSELoss().to(args.device)
    else:
        post_process = nn.Sequential()  # Identity
        loss_fn = nn.CrossEntropyLoss().to(args.device)

    # Set up checkpoint saver
    saver = ModelSaver(model, optimizer, scheduler, args.save_dir,
                       {'model': args.model},
                       max_to_keep=args.max_ckpts, device=args.device)

    # Train
    logger = TrainLogger(args, len(train_loader.dataset))
    if args.save_all:
        # Save initialized model weights with validation loss as random
        saver.save(0, math.log(args.num_classes))
    while not logger.is_finished_training():
        logger.start_epoch()

        # Train for one epoch
        model.train()
        fd_lambda = get_fd_lambda(args, logger.epoch)
        for inputs, labels in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                # Forward
                outputs = model.forward(inputs.to(args.device))
                outputs = post_process(outputs)
                loss = loss_fn(outputs, labels.to(args.device))
                loss_item = loss.item()

                fd_loss = torch.zeros(
                    [], dtype=torch.float32,
                    device='cuda' if args.gpu_ids else 'cpu')
                tp_total = torch.zeros(
                    [], dtype=torch.float32,
                    device='cuda' if args.gpu_ids else 'cpu')
                # BUG FIX: initialize fd_loss_item unconditionally. It was
                # previously assigned only inside the `fd is not None` branch,
                # so logger.end_iter below raised NameError whenever
                # args.use_fd was False.
                fd_loss_item = 0.
                if fd is not None:
                    # Forward FD over first-layer filters in mini-batches.
                    filters = get_layer_weights(model, filter_dict[args.model])
                    for i in range(0, filters.size(0), args.fd_batch_size):
                        fd_batch = filters[i:i + args.fd_batch_size]
                        # torch.sigmoid replaces the deprecated F.sigmoid.
                        tp_scores = torch.sigmoid(fd.forward(fd_batch))
                        tp_total += tp_scores.sum()
                    fd_loss = 1. - tp_total / filters.size(0)
                    fd_loss_item = fd_loss.item()
                    loss += fd_lambda * fd_loss

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            logger.end_iter({
                'std_loss': loss_item,
                'fd_loss': fd_loss_item,
                'loss': loss_item + fd_loss_item
            })

        # Evaluate on validation set
        val_loss = evaluate(model, post_process, test_loader, loss_fn,
                            device=args.device)
        logger.write('[epoch {}]: val_loss: {:.3g}'.format(
            logger.epoch, val_loss))
        logger.write_summaries({'loss': val_loss}, phase='val')
        if args.save_all or logger.epoch in args.save_epochs:
            saver.save(logger.epoch, val_loss)
        logger.end_epoch()
        scheduler.step()
def train(args):
    """Train a whiteboard-image classifier with per-step and per-epoch
    learning-rate scheduling."""
    if args.ckpt_path:
        # Resume training from a saved checkpoint.
        net, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        net = models.__dict__[args.model](pretrained=args.pretrained)
        if args.pretrained:
            # Swap the final layer to match the number of target classes.
            net.fc = nn.Linear(net.fc.in_features, args.num_classes)
    net = nn.DataParallel(net, args.gpu_ids)
    net = net.to(args.device)
    net.train()

    # Optimizer and scheduler.
    parameters = optim.get_parameters(net.module, args)
    optimizer = optim.get_optimizer(parameters, args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Loss, data, logging, evaluation, checkpointing.
    criterion = nn.CrossEntropyLoss()
    train_loader = WhiteboardLoader(args.data_dir, 'train', args.batch_size,
                                    shuffle=True, do_augment=True,
                                    num_workers=args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [WhiteboardLoader(args.data_dir, 'val', args.batch_size,
                                     shuffle=False, do_augment=False,
                                     num_workers=args.num_workers)]
    evaluator = ModelEvaluator(eval_loaders, logger, args.epochs_per_eval,
                               args.max_eval, args.num_visuals)
    saver = ModelSaver(**vars(args))

    while not logger.is_finished_training():
        logger.start_epoch()
        for inputs, targets, paths in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                logits = net.forward(inputs.to(args.device))
                loss = criterion(logits, targets.to(args.device))
                logger.log_iter(inputs, logits, targets, paths, loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Per-iteration scheduler step (driven by global step).
            optim.step_scheduler(lr_scheduler, global_step=logger.global_step)
            logger.end_iter()

        metrics = evaluator.evaluate(net, args.device, logger.epoch)
        saver.save(logger.epoch, net, args.model, optimizer, lr_scheduler,
                   args.device, metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        # Per-epoch scheduler step (driven by epoch metrics).
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.

    Args:
        args: Top-level namespace holding `model_args`, `logger_args`,
            `optim_args`, `data_args`, and `transform_args` sub-namespaces.
    """
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]
    print('gpus: ', args.gpu_ids)

    # Get model
    if model_args.ckpt_path:
        # Checkpoint weights override any ImageNet-style pretraining.
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path,
                                                 args.gpu_ids,
                                                 model_args, data_args)
        if not logger_args.restart_epoch_count:
            args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        # Resize the classifier head for the task count and covariates.
        num_covars = len(model_args.covar_list.split(';'))
        model.transform_model_shape(len(task_sequence), num_covars)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)
    # The optimizer is loaded from the ckpt if one exists and the new model
    # architecture is the same as the old one (classifier is not transformed).
    if model_args.ckpt_path and not model_args.transform_classifier:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids,
                                  optimizer, lr_scheduler)

    # Get loaders and class weights. An uncertainty-mapped CSV replaces the
    # default 'train' split when one is provided.
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path

    # Put all CXR training fractions into one dictionary for the loader.
    cxr_frac = {'pocus': data_args.pocus_train_frac,
                'hocus': data_args.hocus_train_frac,
                'pulm': data_args.pulm_train_frac}
    train_loader = get_loader(data_args, transform_args, train_csv_name,
                              task_sequence, data_args.su_train_frac,
                              data_args.nih_train_frac, cxr_frac,
                              data_args.tcga_train_frac, args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True, shuffle=True,
                              covar_list=model_args.covar_list,
                              fold_num=data_args.fold_num)
    eval_loaders = get_eval_loaders(data_args, transform_args, task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    covar_list=model_args.covar_list,
                                    fold_num=data_args.fold_num)
    class_weights = train_loader.dataset.class_weights

    # Get loss functions: the configured (unweighted) loss and a
    # class-weighted variant logged alongside it.
    uw_loss_fn = get_loss_fn(args.loss_fn, args.device,
                             model_args.model_uncertainty,
                             args.has_tasks_missing,
                             class_weights=class_weights)
    w_loss_fn = get_loss_fn('weighted_loss', args.device,
                            model_args.model_uncertainty,
                            args.has_tasks_missing,
                            class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs,
                         args.batch_size, len(train_loader.dataset),
                         args.device,
                         normalization=transform_args.normalization)
    eval_args = {
        'num_visuals': logger_args.num_visuals,
        'iters_per_eval': logger_args.iters_per_eval,
        'has_missing_tasks': args.has_tasks_missing,
        'model_uncertainty': model_args.model_uncertainty,
        'class_weights': class_weights,
        'max_eval': logger_args.max_eval,
        'device': args.device,
        'optimizer': optimizer,
    }
    evaluator = get_evaluator('classification', eval_loaders, logger,
                              eval_args)
    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()
        for inputs, targets, info_dict, covars in train_loader:
            logger.start_iter()

            # Evaluate and save periodically (the evaluator is assumed to
            # no-op between eval steps — TODO confirm). `curves` is unused.
            metrics, curves = evaluator.evaluate(model, args.device,
                                                 logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)
            # Explicit check rather than `assert`, which is stripped under -O.
            if (logger.global_step % logger_args.iters_per_eval == 0
                    and metric_val is None):
                raise RuntimeError(
                    'Evaluation at step {} did not produce metric {!r}.'
                    .format(logger.global_step, logger_args.metric_name))
            saver.save(logger.global_step, logger.epoch, model, optimizer,
                       lr_scheduler, args.device, metric_val=metric_val,
                       covar_list=model_args.covar_list)
            lr_step = util.step_scheduler(
                lr_scheduler, metrics, lr_step,
                best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]
            with torch.set_grad_enabled(True):
                # Call the module (not .forward) so hooks and DataParallel
                # dispatch run as intended; the model takes [inputs, covars].
                logits = model([inputs.to(args.device), covars])

                # Upweight TB by repeating its column (num_tasks - 1) times
                # in both logits and targets, so its loss counts for more.
                if model_args.upweight_tb:
                    tb_targets = targets.narrow(1, 0, 1)
                    findings_targets = targets.narrow(1, 1,
                                                      targets.shape[1] - 1)
                    tb_targets = tb_targets.repeat(1, targets.shape[1] - 1)
                    new_targets = torch.cat((tb_targets, findings_targets), 1)

                    tb_logits = logits.narrow(1, 0, 1)
                    findings_logits = logits.narrow(1, 1,
                                                    logits.shape[1] - 1)
                    tb_logits = tb_logits.repeat(1, logits.shape[1] - 1)
                    new_logits = torch.cat((tb_logits, findings_logits), 1)
                else:
                    new_logits = logits
                    new_targets = targets

                unweighted_loss = uw_loss_fn(new_logits,
                                             new_targets.to(args.device))
                weighted_loss = (w_loss_fn(logits, targets.to(args.device))
                                 if w_loss_fn else None)
                logger.log_iter(inputs, logits, targets, unweighted_loss,
                                weighted_loss, optimizer)

                optimizer.zero_grad()
                # Backprop through whichever loss was configured for training.
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)