def main():
    # Init state params
    params = init_parms()
    device = params.get('device')

    # Load the model, optimizer & criterion
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    model = torch.nn.DataParallel(model)
    logger.info(f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(model, optimizer, params)
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()
    unsup_criterion = UDALoss()

    # Init tensorboard logger; currently errors on the 37 server.
    tb_logger = TensorboardLogger(log_dir=log_path)

    # Training and validation progress bars
    pbar = ProgressBar(persist=True, desc="Training")
    pbar_valid = ProgressBar(persist=True, desc="Validation Clean")
    pbar_valid_other = ProgressBar(persist=True, desc="Validation Other")
    pbar_valid_airtel = ProgressBar(persist=True, desc="Validation Airtel")
    pbar_valid_airtel_payments = ProgressBar(persist=True, desc="Validation Airtel Payments")
    pbar_valid_airtel_hinglish = ProgressBar(persist=True, desc="Validation Airtel Hinglish")

    # Load timer and best meter to keep track of state params
    timer = Timer(average=True)
    best_meter = params.get('best_stats', BestMeter())

    # Load all the train data
    logger.info('Beginning to load Datasets')
    trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled')
    trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled')
    trainCommonVoicePath = os.path.join(lmdb_commonvoice_root_path, 'train-labelled-en')
    trainAirtelPath = os.path.join(lmdb_airtel_root_path, 'train-labelled-en')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path, 'train-labelled-en')
    trainAirtelHinglishPath = os.path.join(lmdb_airtel_hinglish_root_path, 'train-labelled-en')

    # Test data
    testCleanPath = os.path.join(lmdb_root_path, 'test-clean')
    testOtherPath = os.path.join(lmdb_root_path, 'test-other')
    testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en')
    testAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path, 'test-labelled-en')
    testAirtelHinglishPath = os.path.join(lmdb_airtel_hinglish_root_path, 'test-labelled-en')

    # Ideally the unsupervised data goes here
    devOtherPath = os.path.join(lmdb_root_path, 'dev-other')

    # Form datasets
    train_clean = lmdbMultiDataset(
        roots=[trainCleanPath, trainOtherPath, trainCommonVoicePath,
               trainAirtelPath, trainAirtelPaymentsPath, trainAirtelHinglishPath],
        transform=image_train_transform)
    train_other = lmdbMultiDataset(roots=[devOtherPath], transform=image_train_transform)

    test_clean = lmdbMultiDataset(roots=[testCleanPath], transform=image_val_transform)
    test_other = lmdbMultiDataset(roots=[testOtherPath], transform=image_val_transform)
    test_airtel = lmdbMultiDataset(roots=[testAirtelPath], transform=image_val_transform)
    test_payments_airtel = lmdbMultiDataset(roots=[testAirtelPaymentsPath], transform=image_val_transform)
    test_hinglish_airtel = lmdbMultiDataset(roots=[testAirtelHinglishPath], transform=image_val_transform)

    logger.info(
        f'Loaded Train & Test Datasets, train_labelled={len(train_clean)}, '
        f'train_unlabelled={len(train_other)}, test_clean={len(test_clean)}, '
        f'test_other={len(test_other)}, test_airtel={len(test_airtel)}, '
        f'test_payments_airtel={len(test_payments_airtel)}, '
        f'test_hinglish_airtel={len(test_hinglish_airtel)} examples')

    def train_update_function(engine, _):
        optimizer.zero_grad()

        # Supervised gt, pred
        imgs_sup, labels_sup, label_lengths, input_lengths = next(
            engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.to(device)
        # with torch.autograd.detect_anomaly():
        probs_sup = model(imgs_sup)

        # Unsupervised gt, pred
        # imgs_unsup, augmented_imgs_unsup = next(engine.state.train_loader_unlabbeled)
        # with torch.no_grad():
        #     probs_unsup = model(imgs_unsup.to(device))
        # probs_aug_unsup = model(augmented_imgs_unsup.to(device))

        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths, input_lengths)
        # unsup_loss = unsup_criterion(probs_unsup, probs_aug_unsup)

        # Blend supervised and unsupervised losses till unsupervision_warmup_epoch
        # alpha = get_alpha(engine.state.epoch)
        # final_loss = ((1 - alpha) * sup_loss) + (alpha * unsup_loss)
        # final_loss = sup_loss

        sup_loss.backward()
        optimizer.step()
        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        img, labels, label_lengths, image_lengths = batch
        y_pred = model(img.to(device))
        if np.random.rand() > 0.99:
            # Occasionally print a random batch of predictions vs ground truth
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = sequence_to_string(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_loader_labbeled_loader = torch.utils.data.DataLoader(
        train_clean, batch_size=train_batch_size, shuffle=True,
        num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    train_loader_unlabbeled_loader = torch.utils.data.DataLoader(
        train_other, batch_size=train_batch_size * 4, shuffle=True,
        num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_clean = torch.utils.data.DataLoader(
        test_clean, batch_size=torch.cuda.device_count(), shuffle=False,
        num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_other = torch.utils.data.DataLoader(
        test_other, batch_size=torch.cuda.device_count(), shuffle=False,
        num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel = torch.utils.data.DataLoader(
        test_airtel, batch_size=torch.cuda.device_count(), shuffle=False,
        num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel_payments = torch.utils.data.DataLoader(
        test_payments_airtel, batch_size=torch.cuda.device_count(), shuffle=False,
        num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel_hinglish = torch.utils.data.DataLoader(
        test_hinglish_airtel, batch_size=torch.cuda.device_count(), shuffle=False,
        num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)

    trainer = Engine(train_update_function)
    evaluator_clean = Engine(validate_update_function)
    evaluator_other = Engine(validate_update_function)
    evaluator_airtel = Engine(validate_update_function)
    evaluator_airtel_payments = Engine(validate_update_function)
    evaluator_airtel_hinglish = Engine(validate_update_function)

    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    iteration_log_step = int(0.33 * len(train_loader_labbeled_loader))
    for name, metric in metrics.items():
        metric.attach(evaluator_clean, name)
        metric.attach(evaluator_other, name)
        metric.attach(evaluator_airtel, name)
        metric.attach(evaluator_airtel_payments, name)
        metric.attach(evaluator_airtel_hinglish, name)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config.lr_gamma,
        patience=int(config.epochs * 0.05), verbose=True,
        threshold_mode="abs", cooldown=int(config.epochs * 0.025), min_lr=1e-5)

    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training",
                                               output_transform=lambda loss: {'loss': loss}),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_clean,
                     log_handler=OutputHandler(tag="validation_clean",
                                               metric_names=["wer", "cer"],
                                               another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_other,
                     log_handler=OutputHandler(tag="validation_other",
                                               metric_names=["wer", "cer"],
                                               another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_airtel,
                     log_handler=OutputHandler(tag="validation_airtel",
                                               metric_names=["wer", "cer"],
                                               another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_airtel_payments,
                     log_handler=OutputHandler(tag="validation_airtel_payments",
                                               metric_names=["wer", "cer"],
                                               another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_airtel_hinglish,
                     log_handler=OutputHandler(tag="validation_airtel_hinglish",
                                               metric_names=["wer", "cer"],
                                               another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)

    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    pbar_valid.attach(evaluator_clean, ['wer', 'cer'],
                      event_name=Events.EPOCH_COMPLETED,
                      closing_event_name=Events.COMPLETED)
    pbar_valid_other.attach(evaluator_other, ['wer', 'cer'],
                            event_name=Events.EPOCH_COMPLETED,
                            closing_event_name=Events.COMPLETED)
    pbar_valid_airtel.attach(evaluator_airtel, ['wer', 'cer'],
                             event_name=Events.EPOCH_COMPLETED,
                             closing_event_name=Events.COMPLETED)
    pbar_valid_airtel_payments.attach(evaluator_airtel_payments, ['wer', 'cer'],
                                      event_name=Events.EPOCH_COMPLETED,
                                      closing_event_name=Events.COMPLETED)
    pbar_valid_airtel_hinglish.attach(evaluator_airtel_hinglish, ['wer', 'cer'],
                                      event_name=Events.EPOCH_COMPLETED,
                                      closing_event_name=Events.COMPLETED)
    timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader_labbeled_loader)
        # engine.state.train_loader_unlabbeled = iter(train_loader_unlabbeled_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        if (engine.state.iteration % iteration_log_step == 0) and (engine.state.iteration > 0):
            engine.state.epoch += 1
            train_clean.set_epochs(engine.state.epoch)
            train_other.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            evaluator_clean.run(test_loader_clean)
            evaluator_other.run(test_loader_other)
            evaluator_airtel.run(test_loader_airtel)
            evaluator_airtel_payments.run(test_loader_airtel_payments)
            evaluator_airtel_hinglish.run(test_loader_airtel_hinglish)
            model.train()
            logger.info('Model set back to train mode')

    @trainer.on(Events.EPOCH_COMPLETED)
    def after_complete(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @evaluator_other.on(Events.EPOCH_COMPLETED)
    def save_checkpoints(engine):
        metrics = engine.state.metrics
        wer = metrics['wer']
        cer = metrics['cer']
        epoch = trainer.state.epoch
        scheduler.step(wer)
        save_checkpoint(model, optimizer, best_meter, wer, cer, epoch)
        best_meter.update(wer, cer, epoch)

    trainer.run(train_loader_labbeled_loader, max_epochs=epochs)
    tb_logger.close()
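
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the scheduling pattern used above:
# ReduceLROnPlateau is stepped with the evaluator's WER, so the learning rate
# decays on validation quality rather than on a fixed schedule. `validate_fn`
# and `save_fn` are hypothetical placeholders, not part of the original code.
def _sketch_plateau_on_wer(model, optimizer, validate_fn, save_fn):
    import torch
    from ignite.engine import Engine, Events

    evaluator = Engine(validate_fn)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3)

    @evaluator.on(Events.COMPLETED)
    def step_scheduler(engine):
        wer = engine.state.metrics['wer']
        scheduler.step(wer)  # the plateau scheduler consumes the monitored value
        save_fn(model, optimizer, wer)

    return evaluator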
model = C2F3(n_classes=10)
dataset = MNISTDataset('train_small.csv')
train_loader = DataLoader(dataset, batch_size=1024)
# eval_loader = DataLoader(dataset, batch_size=1024, shuffle=False)
criterion = nn.MSELoss(reduction='sum')  # nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=1e-4)  # torch.optim.SGD(model.parameters(), lr=1e-4)

#%% Initialize trainer and handlers which print info after iterations and epochs
trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={
    # 'accuracy': CategoricalAccuracy(),
    'mse': Loss(criterion)
})

timer = Timer(average=True)
timer.attach(trainer,
             start=Events.EPOCH_STARTED, pause=Events.EPOCH_COMPLETED,
             resume=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED)

@trainer.on(Events.ITERATION_COMPLETED)
def log_on_iteration_completed(trainer):
    print('Epoch [{}] Loss: {:.2f}'.format(
        trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_on_epoch_completed(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    # print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
    #       .format(trainer.state.epoch, metrics['accuracy'], metrics['mse']))
    print('Epoch [{}] Time: {:.2f} sec'.format(trainer.state.epoch, timer.value()))
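
# As excerpted, nothing above actually starts the loop; a call along these
# lines (epoch count assumed) would normally follow:
# trainer.run(train_loader, max_epochs=10)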
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch, device_id, train_camstyle_loader):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    epochs = cfg.SOLVER.MAX_EPOCHS
    device = cfg.MODEL.DEVICE

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                        device=device, device_id=device_id)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP': R1_mAP(num_query, True, False, max_rank=50,
                             feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device, device_id=device_id)

    if device_id == 0:
        checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME,
                                       checkpoint_period, n_saved=10,
                                       require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        })

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
    RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'data_ratio')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    def cycle(iterable):
        while True:
            for i in iterable:
                yield i

    train_loader_iter = cycle(train_loader)
    train_camstyle_loader_iter = cycle(train_camstyle_loader)

    @trainer.on(Events.ITERATION_STARTED)
    def generate_batch(engine):
        current_iter = engine.state.iteration
        batch = next(train_loader_iter)
        camstyle_batch = next(train_camstyle_loader_iter)
        engine.state.batch = [batch, camstyle_batch]

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, "
                "ratio of data/cam_data: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, iter, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        engine.state.metrics['data_ratio'],
                        scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch, timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    num_iters = len(train_loader)
    data = list(range(num_iters))
    trainer.run(data, max_epochs=epochs)
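
# ---------------------------------------------------------------------------
# The function above decouples epoch length from the loaders: the trainer is
# driven by a dummy `range(num_iters)` while real batches are injected in an
# ITERATION_STARTED handler from endlessly cycled iterators. A minimal
# self-contained sketch of the same idea (`process_fn` is a placeholder):
def _sketch_injected_batches(loader_a, loader_b, process_fn, num_iters, epochs):
    from ignite.engine import Engine, Events

    def cycle(iterable):
        # Re-iterate on every pass so a shuffling DataLoader reshuffles;
        # itertools.cycle would cache and replay the first pass instead.
        while True:
            for item in iterable:
                yield item

    trainer = Engine(process_fn)
    iter_a, iter_b = cycle(loader_a), cycle(loader_b)

    @trainer.on(Events.ITERATION_STARTED)
    def inject(engine):
        # Overwrite the dummy item with a real pair of batches before the
        # process function runs on engine.state.batch.
        engine.state.batch = [next(iter_a), next(iter_b)]

    trainer.run(list(range(num_iters)), max_epochs=epochs)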
ProgressBar().attach(trainer, metric_names=metric_names)

# Model checkpointing
checkpoint_handler = ModelCheckpoint("./", "checkpoint", save_interval=1,
                                     n_saved=3, require_empty=False)
# trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
#                           to_save={'model': model, 'optimizer': optimizer,
#                                    'annealers': (sigma_scheme.data, mu_scheme.data)})

timer = Timer(average=True).attach(trainer, start=Events.EPOCH_STARTED,
                                   resume=Events.ITERATION_STARTED,
                                   pause=Events.ITERATION_COMPLETED,
                                   step=Events.ITERATION_COMPLETED)

# Tensorboard writer
writer = SummaryWriter(log_dir=args.log_dir)

@trainer.on(Events.ITERATION_COMPLETED)
def log_metrics(engine):
    for key, value in engine.state.metrics.items():
        writer.add_scalar("training/{}".format(key), value, engine.state.iteration)

@trainer.on(Events.EPOCH_COMPLETED)
def save_images(engine):
    print("Epoch Completed save_images")
def add_stdout_handler(trainer, validator=None):
    """
    This adds the following handler to the trainer engine, and also sets up
    Timers:

    - log_epoch_to_stdout: This logs the results of a model after it has
      trained for a single epoch on both the training and validation set.
      The output typically looks like this:

        .. code-block:: none

            EPOCH SUMMARY
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            - Epoch number: 0010 / 0010
            - Training loss: 0.583591
            - Validation loss: 0.137209
            - Epoch took: 00:00:03
            - Time since start: 00:00:32
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            Saving to test.
            Output @ tests/local/trainer

    Args:
        trainer (ignite.Engine): Engine for trainer
        validator (ignite.Engine, optional): Engine for validation.
            Defaults to None.
    """
    # Set up timers for overall time taken and each epoch
    overall_timer = Timer(average=False)
    overall_timer.attach(trainer, start=Events.STARTED, pause=Events.COMPLETED)

    epoch_timer = Timer(average=False)
    epoch_timer.attach(trainer, start=Events.EPOCH_STARTED,
                       pause=ValidationEvents.VALIDATION_COMPLETED)

    @trainer.on(ValidationEvents.VALIDATION_COMPLETED)
    def log_epoch_to_stdout(trainer):
        epoch_time = epoch_timer.value()
        epoch_time = time.strftime("%H:%M:%S", time.gmtime(epoch_time))

        overall_time = overall_timer.value()
        overall_time = time.strftime("%H:%M:%S", time.gmtime(overall_time))

        epoch_number = trainer.state.epoch
        total_epochs = trainer.state.max_epochs

        try:
            validation_loss = (
                f"{trainer.state.epoch_history['validation/loss'][-1]:04f}")
        except (KeyError, IndexError):
            validation_loss = 'N/A'

        train_loss = trainer.state.epoch_history['train/loss'][-1]
        saved_model_path = trainer.state.saved_model_path

        logging_str = (
            f"\n\n"
            f"EPOCH SUMMARY \n"
            f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
            f"- Epoch number: {epoch_number:04d} / {total_epochs:04d} \n"
            f"- Training loss: {train_loss:04f} \n"
            f"- Validation loss: {validation_loss} \n"
            f"- Epoch took: {epoch_time} \n"
            f"- Time since start: {overall_time} \n"
            f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
            f"Saving to {saved_model_path}. \n"
            f"Output @ {trainer.state.output_folder} \n"
        )
        logging.info(logging_str)
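
# ---------------------------------------------------------------------------
# ValidationEvents above is a custom event enum, which a trainer must
# register before handlers can listen on it. A hedged sketch of the assumed
# setup (EventEnum, register_events and fire_event exist in recent ignite
# versions; the enum values here are guesses at the project's definitions):
def _sketch_register_validation_events(trainer):
    from ignite.engine import EventEnum

    class ValidationEvents(EventEnum):
        VALIDATION_STARTED = 'validation_started'
        VALIDATION_COMPLETED = 'validation_completed'

    trainer.register_events(*ValidationEvents)
    # A validation routine would then fire the event itself, e.g.:
    # trainer.fire_event(ValidationEvents.VALIDATION_COMPLETED)
    return ValidationEvents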
def train(run_name, forward_func, model, train_set, val_set, n_epochs,
          batch_size, lr):
    # Make the run directory
    save_dir = os.path.join('training/simple/saved_runs', run_name)
    if run_name == 'debug':
        shutil.rmtree(save_dir, ignore_errors=True)
    os.mkdir(save_dir)

    model = model.to(device)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True,
                            drop_last=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training step
    def step(engine, batch):
        model.train()
        if isinstance(batch, list):
            batch = [tensor.to(device) for tensor in batch]
        else:
            batch = batch.to(device)
        x_gen, x_q, _ = forward_func(model, batch)
        loss = F.l1_loss(x_gen, x_q)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return {'L1': loss}

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ['L1']
    RunningAverage(output_transform=lambda x: x['L1']).attach(trainer, 'L1')
    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer, start=Events.EPOCH_STARTED,
                               resume=Events.ITERATION_STARTED,
                               pause=Events.ITERATION_COMPLETED,
                               step=Events.ITERATION_COMPLETED)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint(os.path.join(save_dir, 'checkpoints'),
                                         type(model).__name__,
                                         save_interval=1, n_saved=3,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={'model': model, 'optimizer': optimizer})

    # Tensorboard writer
    writer = SummaryWriter(log_dir=os.path.join(save_dir, 'logs'))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_metrics(engine):
        if engine.state.iteration % 100 == 0:
            for metric, value in engine.state.metrics.items():
                writer.add_scalar('training/{}'.format(metric), value,
                                  engine.state.iteration)

    def save_images(engine, batch):
        x_gen, x_q, r = forward_func(model, batch)
        r_dim = r.shape[1]
        if isinstance(model, SimpleVVGQN):
            r = (r + 1) / 2
        r = r.view(-1, 1, int(math.sqrt(r_dim)), int(math.sqrt(r_dim)))
        x_gen = x_gen.detach().cpu().float()
        r = r.detach().cpu().float()
        writer.add_image('representation', make_grid(r), engine.state.epoch)
        writer.add_image('generation', make_grid(x_gen), engine.state.epoch)
        writer.add_image('query', make_grid(x_q), engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()
        with torch.no_grad():
            batch = next(iter(val_loader))
            if isinstance(batch, list):
                batch = [tensor.to(device) for tensor in batch]
            else:
                batch = batch.to(device)
            x_gen, x_q, r = forward_func(model, batch)
            loss = F.l1_loss(x_gen, x_q)
            writer.add_scalar('validation/L1', loss.item(), engine.state.epoch)
            save_images(engine, batch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            import warnings
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    start_time = time.time()
    trainer.run(train_loader, n_epochs)
    writer.close()
    end_time = time.time()
    print('Total training time: {}'.format(
        timedelta(seconds=end_time - start_time)))
def main(
    dataset,
    dataroot,
    z_dim,
    g_filters,
    d_filters,
    batch_size,
    epochs,
    learning_rate,
    beta_1,
    saved_G,
    saved_D,
    seed,
    n_workers,
    device,
    alpha,
    output_dir,
):
    # seed
    check_manual_seed(seed)

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                             num_workers=n_workers, drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))
    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):
        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()
        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()
        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more
        # likely that the discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()
        errG.backward()

        # gradient update
        optimizerG.step()

        return {"errD": errD.item(), "errG": errG.item(), "D_x": D_x,
                "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10,
                                         require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(trainer, "errD")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(trainer, "errG")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(trainer, "D_G_z1")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(trainer, "D_G_z2")

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
    def print_logs(engine):
        fname = os.path.join(output_dir, LOGS_FNAME)
        columns = ["iteration"] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration)] + \
                 [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)

        message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration % len(loader)}/{len(loader)}]"
        for name, value in zip(columns, values):
            message += f" | {name}: {value}"
        pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
        to_save={"netG": netG, "netD": netD}
    )

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]")
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl
            mpl.use("agg")
            import matplotlib.pyplot as plt
            import pandas as pd
        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")
        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME),
                             delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)
            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")
            create_plots(engine)
            checkpoint_handler(engine, {"netG_exception": netG, "netD_exception": netD})
        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
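
# ---------------------------------------------------------------------------
# RunningAverage above keeps an exponential moving average of each scalar,
# roughly avg = alpha * avg + (1 - alpha) * new_value. A tiny illustration:
def _sketch_running_average():
    from ignite.engine import Engine
    from ignite.metrics import RunningAverage

    engine = Engine(lambda e, batch: {'errD': float(batch)})
    RunningAverage(alpha=0.98, output_transform=lambda x: x['errD']).attach(engine, 'errD')
    engine.run([1.0, 2.0, 3.0], max_epochs=1)
    # engine.state.metrics['errD'] now holds the smoothed value
    return engine.state.metrics['errD']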
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    if device == "cuda":
        torch.cuda.set_device(cfg.MODEL.CUDA)

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(
        model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device)
    checkpointer = ModelCheckpoint(dirname=output_dir,
                                   filename_prefix=cfg.MODEL.NAME,
                                   n_saved=None, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'model': model,
        'optimizer': optimizer
    })
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, iter, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch, timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
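
# Note on print_times above: with average=True and step= fired every
# iteration, timer.value() is the mean processing time per batch, so
# timer.value() * timer.step_count recovers the total time spent in the
# epoch, and train_loader.batch_size / timer.value() approximates samples/s.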
def main():
    args = parse_args()
    logger.info('Num GPU: {}'.format(num_gpus))
    logger.info('Load Dataset')
    data = get_dataset(args.dataset, args.data_root, args.batch_size)
    data1, _ = data['train'][0]
    dims = list(data1.shape)
    param = dict(
        zdim=args.zdim,
        hdim=args.hdim,
        quant=not args.no_quantization,
        layers=args.layers,
        sigma=args.sigma,
    )
    model, optimizer = get_model(args.model, args.learning_rate, param, *dims)
    model = torch.nn.DataParallel(model) if num_gpus > 1 else model
    model.to(device)
    logger.info(model)

    kwargs = {
        'pin_memory': True if use_gpu else False,
        'shuffle': True,
        'num_workers': num_gpus * 4
    }
    logdir = get_logdir_name(args, param)
    logger.info('Log Dir: {}'.format(logdir))
    writer = SummaryWriter(logdir)
    os.makedirs(logdir, exist_ok=True)
    train_loader = DataLoader(data['train'], args.batch_size * num_gpus, **kwargs)
    kwargs['shuffle'] = True
    test_loader = DataLoader(data['test'], args.batch_size * num_gpus, **kwargs)

    if not args.no_quantization:
        q = Quantization(device=device)
        # raise NotImplementedError('It is using sigmoid now')
    else:
        q = Range()

    sigma_default = args.sigma * torch.ones(1, args.layers, 1, 1, 1)
    if use_gpu:
        sigma_default = sigma_default.cuda()
    else:
        sigma_default = sigma_default.cpu()

    def get_recon_error(x, sigma, x_mu_k, log_ms_k, recon):
        batch, *xdims = x.shape
        n = Normal(x_mu_k, sigma)
        log_x_mu = n.log_prob(x.view(batch, 1, *xdims))
        log_mx = log_x_mu + log_ms_k
        ll = torch.log(log_mx.exp().sum(dim=1))
        return -ll.sum(dim=[1, 2, 3]).mean()

    def step(engine, batch):
        model.train()
        x, _ = batch
        x = x.to(device)
        x = q.preprocess(x)
        recon, recon_k, x_mu_k, log_ms_k, kl_m, kl_c = model(x)
        nll = get_recon_error(x, sigma_default, x_mu_k, log_ms_k, recon)
        kl_m = kl_m.sum(dim=[1, 2, 3, 4]).mean()
        kl_c = kl_c.sum(dim=[1, 2, 3, 4]).mean()
        optimizer.zero_grad()
        nll_ema = engine.global_info['nll_ema']
        kl_ema = engine.global_info['kl_ema']
        beta = engine.global_info['beta']
        nll_ema = get_ema(nll.detach(), nll_ema, args.geco_alpha)
        kl_ema = get_ema((kl_m + kl_c).detach(), kl_ema, args.geco_alpha)
        loss = nll + beta * (kl_c + kl_m)
        elbo = -loss
        loss.backward()
        optimizer.step()

        # GECO update
        n_pixels = x.shape[1] * x.shape[2] * x.shape[3]
        goal = args.geco_goal * n_pixels
        geco_lr = args.geco_lr
        beta = geco_beta_update(beta, nll_ema, goal, geco_lr,
                                speedup=args.geco_speedup)
        engine.global_info['nll_ema'] = nll_ema
        engine.global_info['kl_ema'] = kl_ema
        engine.global_info['beta'] = beta

        lr = optimizer.param_groups[0]['lr']
        ret = {
            'elbo': elbo.item(),
            'nll': nll.item(),
            'kl_m': kl_m.item(),
            'kl_c': kl_c.item(),
            'lr': lr,
            'sigma': args.sigma,
            'beta': beta
        }
        return ret

    trainer = Engine(step)
    trainer.global_info = {
        'nll_ema': None,
        'kl_ema': None,
        'beta': torch.tensor(args.geco_init).to(device)
    }

    metric_names = ['elbo', 'nll', 'kl_m', 'kl_c', 'lr', 'sigma', 'beta']
    RunningAverage(output_transform=lambda x: x['elbo']).attach(trainer, 'elbo')
    RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')
    RunningAverage(output_transform=lambda x: x['kl_m']).attach(trainer, 'kl_m')
    RunningAverage(output_transform=lambda x: x['kl_c']).attach(trainer, 'kl_c')
    RunningAverage(output_transform=lambda x: x['lr']).attach(trainer, 'lr')
    RunningAverage(output_transform=lambda x: x['sigma']).attach(trainer, 'sigma')
    RunningAverage(output_transform=lambda x: x['beta']).attach(trainer, 'beta')
    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer)

    add_events(trainer, model, writer, logdir, args.log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()
        val_elbo = 0
        val_kl_m = 0
        val_kl_c = 0
        val_nll = 0
        beta = engine.global_info['beta']
        with torch.no_grad():
            for i, (x, _) in enumerate(test_loader):
                x = x.to(device)
                x_processed = q.preprocess(x)
                recon_processed, recon_k_processed, x_mu_k, log_ms_k, kl_m, kl_c = model(
                    x_processed)
                nll = get_recon_error(x_processed, sigma_default, x_mu_k,
                                      log_ms_k, recon_processed)
                kl_m = kl_m.sum(dim=[1, 2, 3, 4]).mean()
                kl_c = kl_c.sum(dim=[1, 2, 3, 4]).mean()
                loss = nll + beta * (kl_m + kl_c)
                elbo = -loss
                val_elbo += elbo
                val_kl_m += kl_m
                val_kl_c += kl_c
                val_nll += nll
                if i == 0:
                    cat = []
                    max_col = (args.layers + 2)
                    for x1, mu1, mu1_k, l_k, c_k in zip(
                            x_processed, recon_processed, recon_k_processed,
                            log_ms_k, x_mu_k):
                        # What a lazy way..
                        cat.extend([x1, mu1])
                        # Recon per layer
                        cat.extend(mu1_k)
                        cat.extend(x1.new_zeros([2, 3, 64, 64]))
                        # Masks per layer
                        cat.extend(
                            q.preprocess(l_k.exp().expand(args.layers, 3, 64, 64)))
                        cat.extend(x1.new_zeros([2, 3, 64, 64]))
                        # Components per layer
                        cat.extend(c_k)
                        if len(cat) > (max_col * 7 * 3):
                            break
                    cat = torch.stack(cat)
                    # if cat.shape[0] > max_col * 3:
                    #     cat = cat[:max_col * 3]
                    cat = q.postprocess(cat)
                    writer.add_image(
                        '{}/layers'.format(args.dataset),
                        make_grid(cat.detach().cpu(), nrow=max_col),
                        engine.state.iteration)
        val_elbo /= len(test_loader)
        val_kl_m /= len(test_loader)
        val_kl_c /= len(test_loader)
        val_nll /= len(test_loader)
        writer.add_scalar('val/elbo', val_elbo.item(), engine.state.iteration)
        writer.add_scalar('val/beta', beta.item(), engine.state.iteration)
        writer.add_scalar('val/kl_m', val_kl_m.item(), engine.state.iteration)
        writer.add_scalar('val/kl_c', val_kl_c.item(), engine.state.iteration)
        writer.add_scalar('val/nll', val_nll.item(), engine.state.iteration)
        print('{:3d} /{:3d} : ELBO: {:.4f}, KL-M: {:.4f}, '
              'KL-C: {:.4f} NLL: {:.4f}'.format(engine.state.epoch,
                                                engine.state.max_epochs,
                                                val_elbo, val_kl_m, val_kl_c,
                                                val_nll))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handler_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            logger.warning('KeyboardInterrupt caught. Exiting gracefully.')
        else:
            raise e

    logger.info(
        'Start training. Max epoch = {}, Batch = {}, # Trainset = {}'.format(
            args.epoch, args.batch_size, len(data['train'])))
    trainer.run(train_loader, args.epoch)
    logger.info('Done training')
    writer.close()
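
# ---------------------------------------------------------------------------
# geco_beta_update is not shown in this excerpt. A common GECO rule
# (Rezende & Viola, "Taming VAEs", 2018) scales beta multiplicatively by the
# smoothed constraint violation; a hedged sketch of what it may look like:
def geco_beta_update_sketch(beta, error_ema, goal, step_size, speedup=None,
                            min_clamp=1e-10):
    import torch

    constraint = error_ema - goal  # > 0 while reconstruction misses the goal
    if speedup is not None and constraint.item() > 0:
        # optionally accelerate while the constraint is violated
        factor = torch.exp(speedup * step_size * constraint)
    else:
        factor = torch.exp(step_size * constraint)
    return (factor * beta).clamp(min=min_clamp)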
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, gpuid):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:' + str(gpuid)

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)
    model = model.to(device)

    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # LambdaLR multiplies the base lr by the lambda's return value, so the
    # lambda returns a bare warmup factor (multiplying by lr here would apply
    # the learning rate twice).
    lr_lambda = lambda epoch: min(1.0, epoch / warmup)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # set logging option
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)
    hdlr = logging.FileHandler(output_dir + 'losses.log')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    writer = SummaryWriter(output_dir + 'losses')

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1,
                                         n_saved=2, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {'model': model, 'optimizer': optimizer})

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x: (x['total_loss'],
                                     torch.empty(x['total_loss'].shape[0]))
         ).attach(evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x: (x['nll'],
                                         torch.empty(x['nll'].shape[0]))
             ).attach(evaluator, 'nll')

        monitoring_metrics.extend(['loss_classes'])
        RunningAverage(output_transform=lambda x: x['loss_classes']).attach(trainer, 'loss_classes')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x: (x['loss_classes'],
                                         torch.empty(x['loss_classes'].shape[0]))
             ).attach(evaluator, 'loss_classes')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)
            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            # data-dependent initialization forward pass (e.g. ActNorm)
            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        writer.add_scalars('loss_classes', {'train': metrics['loss_classes']}, engine.state.epoch)
        writer.add_scalars('loss_nll', {'train': metrics['nll']}, engine.state.epoch)
        writer.add_scalars('loss_total', {'train': metrics['total_loss']}, engine.state.epoch)
        epoch_str = ("Epoch %d. Train_Loss_Classes: %f, Train_NLL: %f, Train_Total: %f "
                     % (engine.state.epoch, metrics['loss_classes'],
                        metrics['nll'], metrics['total_loss']))
        logging.info(epoch_str)

        evaluator.run(test_loader)
        scheduler.step()
        metrics = evaluator.state.metrics
        losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])
        writer.add_scalars('loss_classes', {'eval': metrics['loss_classes']}, engine.state.epoch)
        writer.add_scalars('loss_nll', {'eval': metrics['nll']}, engine.state.epoch)
        writer.add_scalars('loss_total', {'eval': metrics['total_loss']}, engine.state.epoch)
        epoch_str = ("Epoch %d. Eval_Loss_Classes: %f, Eval_NLL: %f, Eval_Total: %f "
                     % (engine.state.epoch, metrics['loss_classes'],
                        metrics['nll'], metrics['total_loss']))
        logging.info(epoch_str)
        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]')
        timer.reset()

    trainer.run(train_loader, epochs)
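
# ---------------------------------------------------------------------------
# Note on the warmup schedule above: LambdaLR multiplies the optimizer's base
# lr by whatever the lambda returns, so the lambda must return a bare factor.
# A tiny illustration:
def _sketch_lambda_lr_warmup(warmup=10, base_lr=5e-4):
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.Adamax(params, lr=base_lr)
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: min(1.0, epoch / warmup))
    # Effective lr at epoch e is base_lr * min(1, e / warmup); returning
    # `lr * min(...)` from the lambda would apply base_lr twice.
    return sched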
def run_once(self, fold_idx):
    log_dir = self.log_dir
    check_manual_seed(self.seed)

    train_pairs, valid_pairs = getattr(
        dataset, ('prepare_%s_data' % self.dataset))()
    print(len(train_pairs))
    print(len(valid_pairs))

    train_augmentors = self.train_augmentors()
    train_dataset = dataset.DatasetSerial(
        train_pairs[:],
        shape_augs=iaa.Sequential(train_augmentors[0]),
        input_augs=iaa.Sequential(train_augmentors[1]))

    infer_augmentors = self.infer_augmentors()  # HACK at has_aux
    infer_dataset = dataset.DatasetSerial(
        valid_pairs[:], shape_augs=iaa.Sequential(infer_augmentors[0]))

    train_loader = data.DataLoader(train_dataset,
                                   num_workers=self.nr_procs_train,
                                   batch_size=self.train_batch_size,
                                   shuffle=True, drop_last=True)
    valid_loader = data.DataLoader(infer_dataset,
                                   num_workers=self.nr_procs_valid,
                                   batch_size=self.infer_batch_size,
                                   shuffle=True, drop_last=False)

    # --------------- Training Sequence
    if self.logging:
        check_log_dir(log_dir)

    device = 'cuda'

    # networks
    input_chs = 3  # TODO: dynamic config
    net = EfficientNet.from_pretrained("efficientnet-b2", num_classes=2)

    # load pre-trained models
    if self.load_network:
        net = load_weight(net, self.save_net_path)
    net = torch.nn.DataParallel(net).to(device)

    # optimizers
    optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps)

    trainer = Engine(lambda engine, batch: self.train_step(
        net, batch, optimizer, device))
    valider = Engine(lambda engine, batch: self.infer_step(net, batch, device))

    infer_output = ['prob', 'true']

    if self.logging:
        checkpoint_handler = ModelCheckpoint(log_dir, self.chkpts_prefix,
                                             save_interval=1, n_saved=30,
                                             require_empty=False)
        # adding handlers using `trainer.add_event_handler` method API
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net})

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    timer.attach(valider, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # attach running average metrics computation
    # decay of EMA to 0.95 to match tensorpack default
    # TODO: refactor this
    RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')
    RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])
    pbar.attach(valider)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'net_exception': net})
        else:
            raise e

    # writer for tensorboard logging
    tfwriter = None  # HACK temporary
    if self.logging:
        tfwriter = SummaryWriter(log_dir)
        json_log_file = log_dir + '/stats.json'
        with open(json_log_file, 'w') as json_file:
            json.dump({}, json_file)  # create empty file

    # TODO: refactor again
    log_info_dict = {
        'logging': self.logging,
        'optimizer': optimizer,
        'tfwriter': tfwriter,
        'json_file': json_log_file,
        'nr_classes': self.nr_classes,
        'metric_names': infer_output,
        'infer_batch_size': self.infer_batch_size  # too cumbersome
    }
    # trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: scheduler.step())  # to change the lr
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_ema_results,
                              log_info_dict)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider,
                              valid_loader, log_info_dict)
    valider.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs)

    # Setup is done. Now let's run the training
    trainer.run(train_loader, self.nr_epochs)
    return
def attach_decorators(trainer, SR, feature_extractor, domain_classifier,
                      resolution_classifier, sr_classif_critic, optim, loader):
    timer = Timer(average=True)
    checkpoint_handler = ModelCheckpoint(
        args.output_dir + '/checkpoints/domain_adaptation_training/',
        'training', save_interval=1, n_saved=300, require_empty=False,
        iteration=args.epoch_c)

    monitoring_metrics = [
        'tgt_loss', 'src_loss', 'sr_loss', 'loss', 'GP', 'res_down_loss',
        'res_up_loss', 'tv_loss', 'vgg_loss'
    ]
    RunningAverage(alpha=0.98, output_transform=lambda x: x['tgt_loss']).attach(trainer, 'tgt_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['src_loss']).attach(trainer, 'src_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['sr_loss']).attach(trainer, 'sr_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['loss']).attach(trainer, 'loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['GP']).attach(trainer, 'GP')
    # RunningAverage(alpha=0.98, output_transform=lambda x: x['g_loss']).attach(trainer, 'g_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['res_down_loss']).attach(trainer, 'res_down_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['res_up_loss']).attach(trainer, 'res_up_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['tv_loss']).attach(trainer, 'tv_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['vgg_loss']).attach(trainer, 'vgg_loss')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={
            'feature_extractor': feature_extractor,
            'SR': SR,
            # 'optim_feature': optim_feature,
            # 'optim_domain_classif': optim_domain_classif,
            # 'optim_res_classif': optim_res_classif,
            'optim': optim,
            # 'optim_sr_critic': optim_sr_critic,
            'domain_D': domain_classifier,
            'res_D': resolution_classifier,
            'sr_D': sr_classif_critic
        })
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = os.path.join(args.output_dir, LOGS_FNAME)
            columns = engine.state.metrics.keys()
            values = [str(round(value, 5))
                      for value in engine.state.metrics.values()]
            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)
            i = (engine.state.iteration % len(loader))
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch, max_epoch=args.epochs, i=i,
                max_i=len(loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)
            pbar.log_message(message)

    @trainer.on(Events.ITERATION_COMPLETED)
    def save_real_example(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            if not os.path.exists(args.output_dir + '/imgs/domain_adaptation_training/'):
                os.makedirs(args.output_dir + '/imgs/domain_adaptation_training/')
            px, py, px2, py2, px_up, _, px2_up, _ = engine.state.batch

            img = SR(feature_extractor(px2.cuda()))
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                predtgt_IMG_FNAME.format(engine.state.epoch, engine.state.iteration))
            vutils.save_image(img, path)

            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                targetY_IMG_FNAME.format(engine.state.epoch, engine.state.iteration))
            vutils.save_image(py2, path)

            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                targetX_IMG_FNAME.format(engine.state.epoch, engine.state.iteration))
            vutils.save_image(px2, path)

            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                sourceX_IMG_FNAME.format(engine.state.epoch, engine.state.iteration))
            vutils.save_image(px, path)

            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                sourceY_IMG_FNAME.format(engine.state.epoch, engine.state.iteration))
            vutils.save_image(py, path)

            img = SR(feature_extractor(px.cuda()))
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                predsrc_IMG_FNAME.format(engine.state.epoch, engine.state.iteration))
            vutils.save_image(img, path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(
                engine, {
                    'feature_extractor_{}'.format(engine.state.iteration): feature_extractor,
                    'SR_{}'.format(engine.state.iteration): SR,
                    'DOMAIN_D_{}'.format(engine.state.iteration): domain_classifier,
                    'RES_D_{}'.format(engine.state.iteration): resolution_classifier,
                    'SR_D_{}'.format(engine.state.iteration): sr_classif_critic,
                    'OPTIM_{}'.format(engine.state.iteration): optim
                })
        else:
            raise e

    @trainer.on(Events.STARTED)
    def loaded(engine):
        if args.epoch_c != 0:
            engine.state.epoch = args.epoch_c
            engine.state.iteration = args.epoch_c * len(loader)
def __init__(self):
    self._dataflow_timer = Timer()
    self._processing_timer = Timer()
    self._event_handlers_timer = Timer()
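
# ---------------------------------------------------------------------------
# A hedged sketch of how these three timers might be attached to an engine to
# time dataflow and processing separately (ignite's BasicTimeProfiler is a
# complete implementation of this idea; the GET_BATCH_* events exist in
# ignite >= 0.3):
def _sketch_attach_profiler_timers(self, engine):
    from ignite.engine import Events

    # time spent fetching batches from the dataloader
    self._dataflow_timer.attach(engine,
                                start=Events.GET_BATCH_STARTED,
                                pause=Events.GET_BATCH_COMPLETED)
    # time spent inside the process function
    self._processing_timer.attach(engine,
                                  start=Events.ITERATION_STARTED,
                                  pause=Events.ITERATION_COMPLETED)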
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_directory", type=Path)
    parser.add_argument("--generator-weights", type=Path)
    parser.add_argument("--discriminator-weights", type=Path)
    args = parser.parse_args()

    generator = Generator(GENERATOR_FILTERS)
    if args.generator_weights is not None:
        LOGGER.info(f"Loading generator weights: {args.generator_weights}")
        generator.load_state_dict(torch.load(args.generator_weights))
    else:
        generator.weight_init(mean=0.0, std=0.02)

    discriminator = Discriminator(DISCRIMINATOR_FILTERS)
    if args.discriminator_weights is not None:
        LOGGER.info(f"Loading discriminator weights: {args.discriminator_weights}")
        discriminator.load_state_dict(torch.load(args.discriminator_weights))
    else:
        discriminator.weight_init(mean=0.0, std=0.02)

    dataset = XView2Dataset(args.data_directory)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [len(dataset) - 10, 10])
    # Create a dev train dataset with just 10 samples
    # train_dataset, _ = torch.utils.data.random_split(train_dataset, [10, len(train_dataset) - 10])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)

    generator.cuda()
    discriminator.cuda()
    generator.train()
    discriminator.train()

    BCE_loss = nn.BCELoss().cuda()
    L1_loss = nn.L1Loss().cuda()

    generator_optimizer = optim.Adam(generator.parameters(), lr=GENERATOR_LR,
                                     betas=(BETA_1, BETA_2))
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=DISCRIMINATOR_LR,
                                         betas=(BETA_1, BETA_2))

    def step(engine, batch):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        # Discriminator update: real pairs should score 1, fake pairs 0
        discriminator.zero_grad()
        discriminator_result = discriminator(x, y).squeeze()
        discriminator_real_loss = BCE_loss(
            discriminator_result, torch.ones(discriminator_result.size()).cuda())

        generator_result = generator(x)
        discriminator_result = discriminator(x, generator_result).squeeze()
        discriminator_fake_loss = BCE_loss(
            discriminator_result, torch.zeros(discriminator_result.size()).cuda())

        discriminator_train_loss = (discriminator_real_loss + discriminator_fake_loss) * 0.5
        discriminator_train_loss.backward()
        discriminator_optimizer.step()

        # Generator update: fool the discriminator and stay close to y in L1
        generator.zero_grad()
        generator_result = generator(x)
        # TODO Work out if the below time saving technique impacts training.
        # generator_result = generator_result.detach()
        discriminator_result = discriminator(x, generator_result).squeeze()
        l1_loss = L1_loss(generator_result, y)
        bce_loss = BCE_loss(discriminator_result,
                            torch.ones(discriminator_result.size()).cuda())
        G_train_loss = bce_loss + L1_LAMBDA * l1_loss
        G_train_loss.backward()
        generator_optimizer.step()

        return {
            'generator_train_loss': G_train_loss.item(),
            'discriminator_real_loss': discriminator_real_loss.item(),
            'discriminator_fake_loss': discriminator_fake_loss.item(),
        }

    trainer = Engine(step)

    tb_logger = TensorboardLogger(log_dir=f"tensorboard/logdir/{uuid4()}")
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training",
                                               output_transform=lambda out: out,
                                               metric_names='all'),
                     event_name=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def add_generated_images(engine):
        def min_max(image):
            return (image - image.min()) / (image.max() - image.min())

        for idx, (x, y) in enumerate(test_loader):
            generated = min_max(generator(x.cuda()).squeeze().cpu())
            real = min_max(y.squeeze())
            tb_logger.writer.add_image(
                f"generated_test_image_{idx}",
                # Concatenate the images into a single tiled image
                torch.cat([x.squeeze(), generated, real], 2),
                global_step=engine.state.epoch)

    checkpoint_handler = ModelCheckpoint("checkpoints/", "pix2pix", n_saved=1,
                                         require_empty=False, save_interval=1)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'generator': generator,
                                  'discriminator': discriminator
                              })

    timer = Timer(average=True)
    timer.attach(trainer, resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}] Duration[{}] Losses: {}".format(
            engine.state.epoch, engine.state.iteration, timer.value(),
            engine.state.output))

    trainer.run(train_loader, max_epochs=TRAIN_EPOCHS)
    tb_logger.close()
def do_train_with_feat(cfg, train_loader, valid, tr_comp: TrainComponent, saver):
    tb_log = TensorBoardXLog(cfg, saver.save_dir)

    device = cfg.MODEL.DEVICE

    trainer = create_supervised_trainer(
        tr_comp.model, tr_comp.optimizer, tr_comp.loss, device=device,
        apex=cfg.APEX.IF_ON, has_center=cfg.LOSS.IF_WITH_CENTER,
        center_criterion=tr_comp.loss.center,
        optimizer_center=tr_comp.optimizer_center,
        center_loss_weight=cfg.LOSS.CENTER_LOSS_WEIGHT)

    saver.to_save = {
        'trainer': trainer,
        'module': tr_comp.model,
        'optimizer': tr_comp.optimizer,
        'center_param': tr_comp.loss_center,
        'optimizer_center': tr_comp.optimizer_center
    }
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=cfg.SAVER.CHECKPOINT_PERIOD),
        saver.train_checkpointer, saver.to_save)

    # multi-valid-dataset validation
    validation_evaluator_map = get_valid_eval_map(cfg, device, tr_comp.model, valid)

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss.loss_function_map.keys())
    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    # TODO: start epoch
    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = 0

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        tr_comp.scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Lr: {tr_comp.scheduler.get_lr()[0]:.2e}, " + \
                  f"Loss: {engine.state.metrics['Loss']:.4f}, " + \
                  f"Acc: {engine.state.metrics['Acc']:.4f}, "
        for loss_name in tr_comp.loss.loss_function_map.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "
        message += f"feat: {engine.state.metrics['feat']:.4f}"
        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch, timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD), saver=saver)
    def log_validation_results(engine, saver):
        # train_evaluator.run(train_loader)
        # cmc, mAP = validation_evaluator.state.metrics['r1_mAP']
        # logger.info("Train Results - Epoch: {}".format(engine.state.epoch))
        # logger.info("mAP: {:.1%}".format(mAP))
        # for r in [1, 5, 10]:
        #     logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
        logger.info(f"Valid - Epoch: {engine.state.epoch}")
        sum_result = eval_multi_dataset(device, validation_evaluator_map, valid)
        if saver.best_result < sum_result:
            logger.info(f'Save best: {sum_result:.4f}')
            saver.save_best_value(sum_result)
            saver.best_checkpointer(engine, saver.to_save)
            saver.best_result = sum_result
        else:
            logger.info(f"Not best: {saver.best_result:.4f} > {sum_result:.4f}")
        logger.info('-' * 80)

    tb_log.attach_handler(trainer, tr_comp.model, tr_comp.optimizer)

    # self.tb_logger.attach(
    #     validation_evaluator,
    #     log_handler=ReIDOutputHandler(tag="valid", metric_names=["r1_mAP"], another_engine=trainer),
    #     event_name=Events.EPOCH_COMPLETED,
    # )

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)
    tb_log.close()
def run(config, plx_experiment):
    set_seed(config['seed'])

    device = "cuda"
    batch_size = config['batch_size']

    cutout_size = config['cutout_size']
    train_transforms = [DynamicCrop(32, 32), FlipLR(), DynamicCutout(cutout_size, cutout_size)]

    train_loader, test_loader = get_fast_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        num_workers=config['num_workers'],
        device=device,
        train_transforms=train_transforms)

    bn_kwargs = config['bn_kwargs']
    conv_kwargs = config['conv_kwargs']
    model = FastResRecNet(conv_kwargs=conv_kwargs, bn_kwargs=bn_kwargs,
                          final_weight=config['final_weight'])
    model = model.to(device)
    model = model.half()
    model_name = model.__class__.__name__

    criterion = nn.CrossEntropyLoss(reduction='sum').to(device)
    criterion = criterion.half()
    eval_criterion = criterion
    if config["enable_mixup"]:
        criterion = MixupCriterion(criterion)

    weight_decay = config['weight_decay']
    if not config['use_adam']:
        opt_kwargs = [("lr", 0.0), ("momentum", config['momentum']),
                      ("weight_decay", weight_decay), ("nesterov", True)]
        optimizer_cls = optim.SGD
    else:
        opt_kwargs = [("lr", 0.0), ("betas", (0.9, 0.999)), ("eps", 1e-08),
                      ("amsgrad", True), ("weight_decay", weight_decay)]
        optimizer_cls = optim.Adam

    optimizer = optimizer_cls([
        # conv + bn
        dict([("params", model.prep.parameters())] + opt_kwargs),
        # conv + bn
        dict([("params", model.layer1[0].parameters())] + opt_kwargs),
        # identity residual recurrent blocks
        dict([("params", model.layer1[-1].conv_rec.parameters())] + opt_kwargs),
        # conv + bn
        dict([("params", model.layer2[0].parameters())] + opt_kwargs),
        # identity residual recurrent blocks
        dict([("params", model.layer2[-1].conv_rec.parameters())] + opt_kwargs),
        # conv + bn
        dict([("params", model.layer3[0].parameters())] + opt_kwargs),
        # identity residual recurrent blocks
        dict([("params", model.layer3[-1].conv_rec.parameters())] + opt_kwargs),
        # linear
        dict([("params", model.classifier.parameters())] + opt_kwargs),
    ])

    num_iterations_per_epoch = len(train_loader)
    num_iterations = num_iterations_per_epoch * config['num_epochs']

    layerwise_milestones_lr_values = []
    for i in range(len(optimizer.param_groups)):
        key = "lr_param_group_{}".format(i)
        assert key in config, "{} not in config".format(key)
        milestones_values = config[key]
        layerwise_milestones_lr_values.append([
            (m * num_iterations_per_epoch, v / batch_size)
            for m, v in milestones_values
        ])

    lr_scheduler = get_layerwise_lr_scheduler(optimizer, layerwise_milestones_lr_values)

    def _prepare_batch_fp16(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking).half(),
                convert_tensor(y, device=device, non_blocking=non_blocking).long())

    def process_function(engine, batch):
        x, y = _prepare_batch_fp16(batch, device=device, non_blocking=True)

        if config['enable_mixup']:
            x, y = mixup_data(x, y, config['mixup_alpha'], config['mixup_proba'])

        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        if config["clip_gradients"] is not None:
            clip_grad_norm_(model.parameters(), config["clip_gradients"])
        optimizer.step()
        return loss.item()

    trainer = Engine(process_function)

    metrics = {
        "accuracy": Accuracy(),
        "loss": Loss(eval_criterion) / len(test_loader)
    }

    evaluator = create_supervised_evaluator(model, metrics,
                                            prepare_batch=_prepare_batch_fp16,
                                            device=device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics,
                                                  prepare_batch=_prepare_batch_fp16,
                                                  device=device, non_blocking=True)

    total_timer = Timer(average=False)
    train_timer = Timer(average=False)
    test_timer = Timer(average=False)

    table_logger = TableLogger()

    if config["use_tb_logger"]:
        path = "experiments/tb_logs" if "TB_LOGGER_PATH" not in os.environ \
            else os.environ["TB_LOGGER_PATH"]
        tb_logger = SummaryWriter(log_dir=path)

    test_timer.attach(evaluator, start=Events.EPOCH_STARTED)

    @trainer.on(Events.STARTED)
    def on_training_started(engine):
        print("Warming up cudnn on random inputs")
        for _ in range(5):
            for size in [batch_size, len(test_loader.dataset) % batch_size]:
                warmup_cudnn(model, criterion, size, config)
        total_timer.reset()

    @trainer.on(Events.EPOCH_STARTED)
    def on_epoch_started(engine):
        model.train()
        train_timer.reset()

        # Warm-up on small images
        if config['warmup_on_small_images']:
            if engine.state.epoch < config['warmup_duration']:
                train_loader.dataset.transforms[0].h = 20
                train_loader.dataset.transforms[0].w = 20
            elif engine.state.epoch == config['warmup_duration']:
                train_loader.dataset.transforms[0].h = 32
                train_loader.dataset.transforms[0].w = 32

        train_loader.dataset.set_random_choices()

        if config['reduce_cutout']:
            # after 15 epochs, remove the cutout augmentation
            if 14 <= engine.state.epoch < 16:
                train_loader.dataset.transforms[-1].h -= 1
                train_loader.dataset.transforms[-1].w -= 1
            elif engine.state.epoch == 16:
                train_loader.dataset.transforms.pop()

        if config['enable_mixup'] and config['mixup_max_epochs'] == engine.state.epoch - 1:
            config['mixup_proba'] = 0.0

    if config["use_tb_logger"]:
        @trainer.on(Events.ITERATION_COMPLETED)
        def on_iteration_completed(engine):
            # log learning rate
            param_name = "lr"
            if len(optimizer.param_groups) == 1:
                param = float(optimizer.param_groups[0][param_name])
                tb_logger.add_scalar(param_name, param * batch_size, engine.state.iteration)
            else:
                for i, param_group in enumerate(optimizer.param_groups):
                    param = float(param_group[param_name])
                    tb_logger.add_scalar("{}/{}/group_{}".format(param_name, model_name, i),
                                         param * batch_size, engine.state.iteration)

            # log training loss
            tb_logger.add_scalar("training/loss_vs_iterations",
                                 engine.state.output / batch_size,
                                 engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(engine):
        trainer.state.train_time = train_timer.value()

        if config["use_tb_logger"]:
            # Log |w|^2 and gradients
            for i, p in enumerate(model.parameters()):
                tb_logger.add_scalar("w2/{}/{}_{}".format(model_name, i, list(p.data.shape)),
                                     torch.norm(p.data), engine.state.epoch)
                tb_logger.add_scalar("mean_grad/{}/{}_{}".format(model_name, i, list(p.grad.shape)),
                                     torch.mean(p.grad), engine.state.epoch)

        for i, p in enumerate(model.parameters()):
            plx_experiment.log_metrics(
                step=engine.state.epoch,
                **{"w2/{}/{}_{}".format(model_name, i, list(p.data.shape)):
                   torch.norm(p.data).item()})

        evaluator.run(test_loader)

    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

    @evaluator.on(Events.COMPLETED)
    def log_results(engine):
        evaluator.state.test_time = test_timer.value()
        metrics = evaluator.state.metrics
        output = [("epoch", trainer.state.epoch)]
        output += [(key, trainer.state.param_history[key][-1][0] * batch_size)
                   for key in trainer.state.param_history]
        output += [("train time", trainer.state.train_time),
                   ("train loss", trainer.state.output / batch_size),
                   ("test time", evaluator.state.test_time),
                   ("test loss", metrics['loss'] / batch_size),
                   ("test acc", metrics['accuracy']),
                   ("total time", total_timer.value())]
        output = OrderedDict(output)
        table_logger.append(output)
        plx_experiment.log_metrics(step=trainer.state.epoch, **output)

        if config["use_tb_logger"]:
            tb_logger.add_scalar("training/total_time", total_timer.value(), trainer.state.epoch)
            tb_logger.add_scalar("test/loss", metrics['loss'] / batch_size, trainer.state.epoch)
            tb_logger.add_scalar("test/accuracy", metrics['accuracy'], trainer.state.epoch)

    @trainer.on(Events.COMPLETED)
    def on_training_completed(engine):
        if config["use_tb_logger"]:
            train_evaluator.run(train_loader)
            metrics = train_evaluator.state.metrics
            tb_logger.add_scalar("training/loss", metrics['loss'] / batch_size, 0)
            tb_logger.add_scalar("training/loss", metrics['loss'] / batch_size, trainer.state.epoch)
            tb_logger.add_scalar("training/accuracy", metrics['accuracy'], 0)
            tb_logger.add_scalar("training/accuracy", metrics['accuracy'], trainer.state.epoch)

    trainer.run(train_loader, max_epochs=config['num_epochs'])

    if config["use_tb_logger"]:
        tb_logger.close()
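# `get_layerwise_lr_scheduler` above is project-specific. A rough equivalent with
# stock ignite (a sketch, assuming a recent ignite.contrib that provides
# `PiecewiseLinear` with `param_group_index` and `ParamGroupScheduler`; names and
# wiring here are illustrative only) schedules each param group from its own
# (iteration, lr) milestones:
from ignite.contrib.handlers import ParamGroupScheduler, PiecewiseLinear

def layerwise_lr_scheduler_sketch(optimizer, milestones_per_group):
    schedulers = [
        PiecewiseLinear(optimizer, "lr", milestones_values=mv, param_group_index=i)
        for i, mv in enumerate(milestones_per_group)
    ]
    names = ["lr_group_{}".format(i) for i in range(len(schedulers))]
    return ParamGroupScheduler(schedulers, names)

# Usage would mirror the code above:
#   scheduler = layerwise_lr_scheduler_sketch(optimizer, layerwise_milestones_lr_values)
#   trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)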
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler, loss_fn):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("template_model.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={"accuracy": Accuracy(),
                                                     "ce_loss": Loss(loss_fn)},
                                            device=device)
    checkpointer = ModelCheckpoint(output_dir, "mnist", checkpoint_period,
                                   n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # Pass the live objects, not their state dicts: passing `model.state_dict()`
    # here would capture a single snapshot at registration time and re-save the
    # same stale weights every epoch.
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {"model": model, "optimizer": optimizer})
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    RunningAverage(output_transform=lambda x: x).attach(trainer, "avg_loss")

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, iter, len(train_loader),
                engine.state.metrics["avg_loss"]))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["ce_loss"]
        logger.info("Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                    .format(engine.state.epoch, avg_accuracy, avg_loss))

    if val_loader is not None:
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics["accuracy"]
            avg_loss = metrics["ce_loss"]
            logger.info("Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                        .format(engine.state.epoch, avg_accuracy, avg_loss))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info("Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
                    .format(engine.state.epoch,
                            timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
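# Sketch of the checkpoint contract assumed in the fix above (illustrative; the
# exact behavior depends on the ignite version, but the stale-snapshot point
# holds for both the old `save_interval` API used here and the newer one):
#   to_save = {"model": model, "optimizer": optimizer}   # fresh weights each save
#   to_save = {"model": model.state_dict()}              # one frozen snapshot,
#                                                        # captured at registration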
def engine_train(cfg):
    prepare_config_train(cfg)

    ckpt_nets = cfg.train.general.ckpt_nets
    ckpt_path = cfg.train.general.ckpt_path
    epochs = cfg.train.solver.epochs
    gpu = cfg.general.gpu
    lr = cfg.train.solver.lr
    lr_gamma = cfg.train.solver.lr_gamma
    lr_step = cfg.train.solver.lr_step
    optim = cfg.train.solver.optim
    renderer_lr = cfg.train.solver.renderer_lr
    root_path = cfg.log.root_path
    save_freq = cfg.train.solver.save_freq
    seed = cfg.general.seed

    eu.redirect_stdout(root_path, 'train')
    eu.print_config(cfg)
    eu.seed_random(seed)

    device = eu.get_device(gpu)
    dataloader = get_dataloader_train(cfg)
    num_batches = len(dataloader)

    render_model, desc_model = get_models(cfg)
    render_model.to(device)
    render_model.train_mode()
    render_model.print_params('render_model')
    desc_model.to(device)
    desc_model.train_mode()
    desc_model.print_params('desc_model')

    crit = get_criterion(cfg)
    print('[*] Loss Function:', crit.__class__.__name__)

    render_params = render_model.params(True, named=False, add_prefix=False)
    render_optimizer = eu.get_optimizer(optim, renderer_lr, render_params)
    render_lr_scheduler = eu.get_lr_scheduler(lr_step, lr_gamma, render_optimizer)

    desc_params = desc_model.params(True, named=False, add_prefix=False)
    desc_optimizer = eu.get_optimizer(optim, lr, desc_params)
    desc_lr_scheduler = eu.get_lr_scheduler(lr_step, lr_gamma, desc_optimizer)

    if eu.is_not_empty(ckpt_path):
        render_model.load(ckpt_path, ckpt_nets)
        desc_model.load(ckpt_path, ckpt_nets)

    tbwriter = eu.get_tbwriter(root_path)

    engine = Engine(functools.partial(step_train,
                                      render_model=render_model,
                                      desc_model=desc_model,
                                      render_optimizer=render_optimizer,
                                      desc_optimizer=desc_optimizer,
                                      criterion=crit,
                                      tbwriter=tbwriter,
                                      device=device,
                                      cfg=cfg))

    engine.add_event_handler(Events.EPOCH_COMPLETED, eu.step_lr_scheduler,
                             scheduler=render_lr_scheduler)
    engine.add_event_handler(Events.EPOCH_COMPLETED, eu.step_lr_scheduler,
                             scheduler=desc_lr_scheduler)

    ckpt_handler = ModelCheckpoint(root_path, eu.PTH_PREFIX, atomic=False,
                                   save_interval=save_freq,
                                   n_saved=epochs // save_freq,
                                   require_empty=False)
    render_subnets = render_model.subnet_dict()
    desc_subnets = desc_model.subnet_dict()
    engine.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler,
                             to_save={**render_subnets, **desc_subnets})

    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)
    engine.add_event_handler(Events.ITERATION_COMPLETED, eu.print_train_log,
                             timer=timer, num_batches=num_batches, cfg=cfg)

    engine.add_event_handler(Events.EXCEPTION_RAISED, eu.handle_exception)

    engine.run(dataloader, epochs)

    tbwriter.close()
    return root_path
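# Note (illustrative): `functools.partial` above pre-binds everything except the
# (engine, batch) pair that ignite passes to a process function; it is
# equivalent to wrapping `step_train` by hand:
#   def _step(engine, batch):
#       return step_train(engine, batch, render_model=render_model, ..., cfg=cfg)
#   engine = Engine(_step)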
def train_model(
        name="",
        resume="",
        base_dir=utils.BASE_DIR,
        model_name="v0",
        chosen_diseases=None,
        n_epochs=10,
        batch_size=4,
        oversample=False,
        max_os=None,
        shuffle=False,
        opt="sgd",
        opt_params={},
        loss_name="wbce",
        loss_params={},
        train_resnet=False,
        log_metrics=None,
        flush_secs=120,
        train_max_images=None,
        val_max_images=None,
        test_max_images=None,
        experiment_mode="debug",
        save=True,
        save_cms=True,  # Note that in this case, save_cms (to disk) includes write_cms (to TB)
        write_graph=False,
        write_emb=False,
        write_emb_img=False,
        write_img=False,
        image_format="RGB",
        multiple_gpu=False,
):
    # Choose GPU
    device = utilsT.get_torch_device()
    print("Using device: ", device)

    # Common folders
    dataset_dir = os.path.join(base_dir, "dataset")

    # Dataset handling
    print("Loading train dataset...")
    train_dataset, train_dataloader = utilsT.prepare_data(
        dataset_dir, "train", chosen_diseases, batch_size,
        oversample=oversample, max_os=max_os, shuffle=shuffle,
        max_images=train_max_images, image_format=image_format,
    )
    train_samples, _ = train_dataset.size()

    print("Loading val dataset...")
    val_dataset, val_dataloader = utilsT.prepare_data(
        dataset_dir, "val", chosen_diseases, batch_size,
        max_images=val_max_images, image_format=image_format,
    )
    val_samples, _ = val_dataset.size()

    # Should be the same as chosen_diseases
    chosen_diseases = list(train_dataset.classes)
    print("Chosen diseases: ", chosen_diseases)

    if resume:
        # Load model and optimizer
        model, model_name, optimizer, opt, loss_name, loss_params, chosen_diseases = \
            models.load_model(base_dir, resume, experiment_mode="", device=device)
        model.train(True)
    else:
        # Create model
        model = models.init_empty_model(model_name, chosen_diseases,
                                        train_resnet=train_resnet).to(device)
        # Create optimizer
        OptClass = optimizers.get_optimizer_class(opt)
        optimizer = OptClass(model.parameters(), **opt_params)
        # print("OPT: ", opt_params)

    # Allow multiple GPUs
    if multiple_gpu:
        model = DataParallel(model)

    # Tensorboard log options
    run_name = utils.get_timestamp()
    if name:
        run_name += "_{}".format(name)
    if len(chosen_diseases) == 1:
        run_name += "_{}".format(chosen_diseases[0])
    elif len(chosen_diseases) == 14:
        run_name += "_all"

    log_dir = get_log_dir(base_dir, run_name, experiment_mode=experiment_mode)
    print("Run name: ", run_name)
    print("Saved TB in: ", log_dir)

    writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)

    # Create validator engine
    validator = Engine(utilsT.get_step_fn(model, optimizer, device,
                                          loss_name, loss_params, False))
    val_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    val_loss.attach(validator, loss_name)
    utilsT.attach_metrics(validator, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(validator, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(validator, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(validator, chosen_diseases, "roc_auc", utilsT.RocAucMetric, False)
    utilsT.attach_metrics(validator, chosen_diseases, "cm", ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm, metric_args=(2,))
    utilsT.attach_metrics(validator, chosen_diseases, "positives", RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    # Create trainer engine
    trainer = Engine(utilsT.get_step_fn(model, optimizer, device,
                                        loss_name, loss_params, True))
    train_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    train_loss.attach(trainer, loss_name)
    utilsT.attach_metrics(trainer, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "roc_auc", utilsT.RocAucMetric, False)
    utilsT.attach_metrics(trainer, chosen_diseases, "cm", ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm, metric_args=(2,))
    utilsT.attach_metrics(trainer, chosen_diseases, "positives", RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED)

    # TODO: Early stopping
    # def score_function(engine):
    #     val_loss = engine.state.metrics[loss_name]
    #     return -val_loss
    # handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    # validator.add_event_handler(Events.COMPLETED, handler)

    # Metrics callbacks
    if log_metrics is None:
        log_metrics = list(ALL_METRICS)

    def _write_metrics(run_type, metrics, epoch, wall_time):
        loss = metrics.get(loss_name, 0)
        writer.add_scalar("Loss/" + run_type, loss, epoch, wall_time)
        for metric_base_name in log_metrics:
            for disease in chosen_diseases:
                metric_value = metrics.get("{}_{}".format(metric_base_name, disease), -1)
                writer.add_scalar("{}_{}/{}".format(metric_base_name, disease, run_type),
                                  metric_value, epoch, wall_time)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_write_metrics(trainer):
        epoch = trainer.state.epoch
        max_epochs = trainer.state.max_epochs

        # Run on evaluation
        validator.run(val_dataloader, 1)

        # Common time
        wall_time = time.time()

        # Log all metrics to TB
        _write_metrics("train", trainer.state.metrics, epoch, wall_time)
        _write_metrics("val", validator.state.metrics, epoch, wall_time)

        train_loss = trainer.state.metrics.get(loss_name, 0)
        val_loss = validator.state.metrics.get(loss_name, 0)

        tb_write_histogram(writer, model, epoch, wall_time)

        # note: `timer._elapsed()` reaches into a private Timer attribute
        print("Finished epoch {}/{}, loss {:.3f}, val loss {:.3f} (took {})".format(
            epoch, max_epochs, train_loss, val_loss,
            utils.duration_to_str(int(timer._elapsed()))))

    # Hparam dict
    hparam_dict = {
        "resume": resume,
        "n_diseases": len(chosen_diseases),
        "diseases": ",".join(chosen_diseases),
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "model_name": model_name,
        "opt": opt,
        "loss": loss_name,
        "samples (train, val)": "{},{}".format(train_samples, val_samples),
        "train_resnet": train_resnet,
        "multiple_gpu": multiple_gpu,
    }

    def copy_params(params_dict, base_name):
        for name, value in params_dict.items():
            hparam_dict["{}_{}".format(base_name, name)] = value

    copy_params(loss_params, "loss")
    copy_params(opt_params, "opt")

    print("HPARAM: ", hparam_dict)

    # Train
    print("-" * 50)
    print("Training...")
    trainer.run(train_dataloader, n_epochs)

    # Capture time
    secs_per_epoch = timer.value()
    duration_per_epoch = utils.duration_to_str(int(secs_per_epoch))
    print("Average time per epoch: ", duration_per_epoch)
    print("-" * 50)

    ## Write all hparams
    hparam_dict["duration_per_epoch"] = duration_per_epoch

    # FIXME: this is commented to avoid having too many hparams in the TB frontend
    # metrics
    # def copy_metrics(engine, engine_name):
    #     for metric_name, metric_value in engine.state.metrics.items():
    #         hparam_dict["{}_{}".format(engine_name, metric_name)] = metric_value
    # copy_metrics(trainer, "train")
    # copy_metrics(validator, "val")

    print("Writing TB hparams")
    writer.add_hparams(hparam_dict, {})

    # Save model to disk
    if save:
        print("Saving model...")
        models.save_model(base_dir, run_name, model_name, experiment_mode,
                          hparam_dict, trainer, model, optimizer)

    # Write graph to TB
    if write_graph:
        print("Writing TB graph...")
        tb_write_graph(writer, model, train_dataloader, device)

    # Write embeddings to TB
    if write_emb:
        print("Writing TB embeddings...")
        image_size = 256 if write_emb_img else 0

        # FIXME: be able to select images (balanced, train vs val, etc)
        image_list = list(train_dataset.label_index["FileName"])[:1000]
        # disease = chosen_diseases[0]
        # positive = train_dataset.label_index[train_dataset.label_index[disease] == 1]
        # negative = train_dataset.label_index[train_dataset.label_index[disease] == 0]
        # positive_images = list(positive["FileName"])[:25]
        # negative_images = list(negative["FileName"])[:25]
        # image_list = positive_images + negative_images

        all_images, all_embeddings, all_predictions, all_ground_truths = gen_embeddings(
            model, train_dataset, device, image_list=image_list, image_size=image_size)
        tb_write_embeddings(
            writer, chosen_diseases,
            all_images, all_embeddings, all_predictions, all_ground_truths,
            global_step=n_epochs,
            use_images=write_emb_img,
            tag="1000_{}".format("img" if write_emb_img else "no_img"),
        )

    # Save confusion matrices (it is expensive to calculate them afterwards)
    if save_cms:
        print("Saving confusion matrices...")
        # Assure folder
        cms_dir = os.path.join(base_dir, "cms", experiment_mode)
        os.makedirs(cms_dir, exist_ok=True)
        base_fname = os.path.join(cms_dir, run_name)

        n_diseases = len(chosen_diseases)

        def extract_cms(metrics):
            """Extract confusion matrices from a metrics dict."""
            cms = []
            for disease in chosen_diseases:
                key = "cm_" + disease
                if key not in metrics:
                    cm = np.array([[-1, -1], [-1, -1]])
                else:
                    cm = metrics[key].numpy()
                cms.append(cm)
            return np.array(cms)

        # Train confusion matrix
        train_cms = extract_cms(trainer.state.metrics)
        np.save(base_fname + "_train", train_cms)
        tb_write_cms(writer, "train", chosen_diseases, train_cms)

        # Validation confusion matrix
        val_cms = extract_cms(validator.state.metrics)
        np.save(base_fname + "_val", val_cms)
        tb_write_cms(writer, "val", chosen_diseases, val_cms)

        # All confusion matrix (train + val)
        all_cms = train_cms + val_cms
        np.save(base_fname + "_all", all_cms)

        # Print to console
        if len(chosen_diseases) == 1:
            print("Train CM: ")
            print(train_cms[0])
            print("Val CM: ")
            print(val_cms[0])
            # print("Train CM 2: ")
            # print(trainer.state.metrics["cm_" + chosen_diseases[0]])
            # print("Val CM 2: ")
            # print(validator.state.metrics["cm_" + chosen_diseases[0]])

    if write_img:
        # NOTE: this option is not recommended, use the Testing notebook to plot and analyze images
        print("Writing images to TB...")

        test_dataset, test_dataloader = utilsT.prepare_data(
            dataset_dir, "test", chosen_diseases, batch_size,
            max_images=test_max_images,
        )

        # TODO: add a way to select images?
        # image_list = list(test_dataset.label_index["FileName"])[:3]

        # Examples in test_dataset (with bboxes available):
        image_list = [
            # "00010277_000.png",  # (Effusion, Infiltrate, Mass, Pneumonia)
            # "00018427_004.png",  # (Atelectasis, Effusion, Mass)
            # "00021703_001.png",  # (Atelectasis, Effusion, Infiltrate)
            # "00028640_008.png",  # (Effusion, Infiltrate)
            # "00019124_104.png",  # (Pneumothorax)
            # "00019124_090.png",  # (Nodule)
            # "00020318_007.png",  # (Pneumothorax)
            "00000003_000.png",  # (0)
            # "00000003_001.png",  # (0)
            # "00000003_002.png",  # (0)
            "00000732_005.png",  # (Cardiomegaly, Pneumothorax)
            # "00012261_001.png",  # (Cardiomegaly, Pneumonia)
            # "00013249_033.png",  # (Cardiomegaly, Pneumonia)
            # "00029808_003.png",  # (Cardiomegaly, Pneumonia)
            # "00022215_012.png",  # (Cardiomegaly, Pneumonia)
            # "00011402_007.png",  # (Cardiomegaly, Pneumonia)
            # "00019018_007.png",  # (Cardiomegaly, Infiltrate)
            # "00021009_001.png",  # (Cardiomegaly, Infiltrate)
            # "00013670_151.png",  # (Cardiomegaly, Infiltrate)
            # "00005066_030.png",  # (Cardiomegaly, Infiltrate, Effusion)
            "00012288_000.png",  # (Cardiomegaly)
            "00008399_007.png",  # (Cardiomegaly)
            "00005532_000.png",  # (Cardiomegaly)
            "00005532_014.png",  # (Cardiomegaly)
            "00005532_016.png",  # (Cardiomegaly)
            "00005827_000.png",  # (Cardiomegaly)
            # "00006912_007.png",  # (Cardiomegaly)
            # "00007037_000.png",  # (Cardiomegaly)
            # "00007043_000.png",  # (Cardiomegaly)
            # "00012741_004.png",  # (Cardiomegaly)
            # "00007551_020.png",  # (Cardiomegaly)
            # "00007735_040.png",  # (Cardiomegaly)
            # "00008339_010.png",  # (Cardiomegaly)
            # "00008365_000.png",  # (Cardiomegaly)
            # "00012686_003.png",  # (Cardiomegaly)
        ]

        tb_write_images(writer, model, test_dataset, chosen_diseases,
                        n_epochs, device, image_list)

    # Close TB writer
    if experiment_mode != "debug":
        writer.close()

    # Run post_train
    print("-" * 50)
    print("Running post_train...")

    print("Loading test dataset...")
    test_dataset, test_dataloader = utilsT.prepare_data(
        dataset_dir, "test", chosen_diseases, batch_size, max_images=test_max_images)

    save_cms_with_names(run_name, experiment_mode, model, test_dataset,
                        test_dataloader, chosen_diseases)

    evaluate_model(run_name, model, optimizer, device, loss_name, loss_params,
                   chosen_diseases, test_dataloader,
                   experiment_mode=experiment_mode, base_dir=base_dir)

    # Return values for debugging
    model_run = ModelRun(model, run_name, model_name, chosen_diseases)
    if experiment_mode == "debug":
        model_run.save_debug_data(writer, trainer, validator,
                                  train_dataset, train_dataloader,
                                  val_dataset, val_dataloader)

    return model_run
def do_train(cfg, model, train_loader, val_loader, optimizer, loss_fns, n_fold=0):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    epochs = cfg.SOLVER.MAX_EPOCHS
    device = cfg.MODEL.DEVICE
    output_dir = cfg.OUTPUT_DIR

    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                     threshold=1e-3, patience=3, min_lr=5e-6,
                                     eps=1e-08, verbose=True)

    logger = logging.getLogger("MOA_MLP.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fns, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'log_loss': LogLoss(loss_fns),
                                                     'cv_score': CV_Score()},
                                            device=device)
    checkpointer = ModelCheckpoint(output_dir, 'moa_mlp_' + str(n_fold),
                                   n_saved=100, require_empty=False)
    timer = Timer(average=True)

    # automatically adding handlers via a special `attach` method of `RunningAverage` handler
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    # automatically adding handlers via a special `attach` method of `Checkpointer` handler
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_period),
                              checkpointer, {'model': model, 'optimizer': optimizer})

    # ReduceLROnPlateau must be *called* with the tracked metric; the original
    # lambda only referenced `lr_scheduler.step` without invoking it, so the
    # learning rate never changed.
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda engine: lr_scheduler.step(engine.state.metrics['avg_loss']))

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            logger.info("K:[{}] Epoch[{}] Iteration[{}/{}] LR: {} Log Loss: {:.3f}".format(
                n_fold, engine.state.epoch, iter, len(train_loader),
                optimizer.param_groups[0]['lr'], engine.state.metrics['avg_loss']))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        log_loss = metrics['log_loss']
        cv_score = metrics['cv_score']
        logger.info("Training Results - K:[{}] Epoch: {} LR: {} Log Loss: {:.3f} CV Score: {:.3f}"
                    .format(n_fold, engine.state.epoch,
                            optimizer.param_groups[0]['lr'], log_loss, cv_score))

    if val_loader is not None:
        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            log_loss = metrics['log_loss']
            cv_score = metrics['cv_score']
            logger.info("Validation Results - K:[{}] Epoch: {} LR: {} Log Loss: {:.3f} CV Score: {:.3f}"
                        .format(n_fold, engine.state.epoch,
                                optimizer.param_groups[0]['lr'], log_loss, cv_score))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('K:[{}] Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(n_fold, engine.state.epoch,
                            timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
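# Sketch (illustrative): ReduceLROnPlateau is more commonly driven by a
# *validation* metric than by the running training loss used above. With the
# `evaluator` already defined in `do_train`, that could look like:
#   @trainer.on(Events.EPOCH_COMPLETED)
#   def step_scheduler(engine):
#       evaluator.run(val_loader)
#       lr_scheduler.step(evaluator.state.metrics['log_loss'])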
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        loss_fn,
):
    log_period = cfg.SOLVER.LOG_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') \
        if cfg.MODEL.DEVICE == 'cuda' else 'cpu'
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger.info("Start training")
    logger.info("use {}".format(device))
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'ce_loss': Loss(loss_fn)},
                                            device=device)
    checkpointer = ModelCheckpoint(output_dir, 'mnist', n_saved=10, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpointer, {'model': model})
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, iter, len(train_loader),
                engine.state.metrics['avg_loss']))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['ce_loss']
        logger.info("Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                    .format(engine.state.epoch, avg_accuracy, avg_loss))

    if val_loader is not None:
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_loss = metrics['ce_loss']
            logger.info("Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                        .format(engine.state.epoch, avg_accuracy, avg_loss))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch,
                            timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        classes_list,
        optimizers,
        schedulers,
        loss_fn,
        start_epoch
):
    # 1. Pull the parameters out of cfg
    epochs = cfg.SOLVER.MAX_EPOCHS
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.SOLVER.OUTPUT_DIR
    device = cfg.MODEL.DEVICE

    # 2. Build the modules
    logger = logging.getLogger("fundus_prediction.train")
    logger.info("Start training")

    # TensorBoard setup
    writer_train = {}
    for i in range(len(optimizers)):
        writer_train[i] = SummaryWriter(cfg.SOLVER.OUTPUT_DIR + "/summary/train/" + str(i))
    writer_val = SummaryWriter(cfg.SOLVER.OUTPUT_DIR + "/summary/val")
    writer_train["graph"] = SummaryWriter(cfg.SOLVER.OUTPUT_DIR + "/summary/train/graph")

    try:
        # print(model)
        images, labels = next(iter(train_loader))
        grid = torchvision.utils.make_grid(images)
        writer_train["graph"].add_image('images', grid, 0)
        writer_train["graph"].add_graph(model, images)
        writer_train["graph"].flush()
    except Exception as e:
        print("Failed to save model graph: {}".format(e))

    # Set up the training metrics. Since class balancing moves the training set
    # away from the original sample distribution, only running averages are used.
    metrics_train = {
        "avg_total_loss": RunningAverage(output_transform=lambda x: x["total_loss"]),
        "avg_precision": RunningAverage(Precision(
            output_transform=lambda x: (x["scores"], x["labels"]))),
        "avg_accuracy": RunningAverage(Accuracy(
            output_transform=lambda x: (x["scores"], x["labels"]))),
    }

    lossKeys = cfg.LOSS.TYPE.split(" ")

    # Set up the loss-related metrics
    for lossName in lossKeys:
        if lossName == "similarity_loss":
            metrics_train["AVG-" + "similarity_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["similarity_loss"])
        elif lossName == "ranked_loss":
            metrics_train["AVG-" + "ranked_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["ranked_loss"])
        elif lossName == "cranked_loss":
            metrics_train["AVG-" + "cranked_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cranked_loss"])
        elif lossName == "cross_entropy_loss":
            metrics_train["AVG-" + "cross_entropy_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cross_entropy_loss"])
        elif lossName == "cluster_loss":
            metrics_train["AVG-" + "cluster_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cluster_loss"][0])
        elif lossName == "one_vs_rest_loss":
            metrics_train["AVG-" + "one_vs_rest_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["one_vs_rest_loss"])
        elif lossName == "attention_loss":
            metrics_train["AVG-" + "attention_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["attention_loss"])
        elif lossName == "class_predict_loss":
            metrics_train["AVG-" + "class_predict_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["class_predict_loss"])
        elif lossName == "kld_loss":
            metrics_train["AVG-" + "kld_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["kld_loss"])
        elif lossName == "margin_loss":
            metrics_train["AVG-" + "margin_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["margin_loss"])
        elif lossName == "cross_entropy_multilabel_loss":
            metrics_train["AVG-" + "cross_entropy_multilabel_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cross_entropy_multilabel_loss"])
        else:
            raise Exception('expected METRIC_LOSS_TYPE should be similarity_loss, '
                            'ranked_loss, cranked_loss but got {}'.format(cfg.LOSS.TYPE))

    trainer = create_supervised_trainer(model, optimizers, metrics_train, loss_fn,
                                        device=device)

    # CJY at 2019.9.26
    def output_transform(output):
        # `output` variable is returned by above `process_function`
        y_pred = output['scores']
        y = output['labels']
        return y_pred, y  # output format is according to `Accuracy` docs

    metrics_eval = {
        "overall_accuracy": Accuracy(output_transform=output_transform),
        "precision": Precision(output_transform=output_transform)
    }

    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period,
                                   n_saved=100, require_empty=False,
                                   start_step=start_epoch)
    timer = Timer(average=True)

    # 3. Attach the modules to the engine
    # CJY at 2019.9.23
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {'model': model, 'optimizer': optimizers[0]})
    # trainer.add_event_handler(Events.STARTED, checkpointer, {'model': model,
    #                                                          'optimizer': optimizers[0]})
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # 4. Event handler functions
    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch
        engine.state.iteration = engine.state.iteration + start_epoch * len(train_loader)
        """
        metrics = do_inference(cfg, model, val_loader, classes_list, loss_fn, plotFlag=False)
        step = 0  # len(train_loader) * (engine.state.epoch - 1) + engine.state.iteration
        for preKey in metrics['precision'].keys():
            writer_val.add_scalar("Precision/" + str(preKey), metrics['precision'][preKey], step)
        for recKey in metrics['recall'].keys():
            writer_val.add_scalar("Recall/" + str(recKey), metrics['recall'][recKey], step)
        for aucKey in metrics['roc_auc'].keys():
            writer_val.add_scalar("ROC_AUC/" + str(aucKey), metrics['roc_auc'][aucKey], step)
        writer_val.add_scalar("OverallAccuracy", metrics["overall_accuracy"], step)
        # writer.add_scalar("Val/"+"confusion_matrix", metrics['confusion_matrix'], step)
        # The confusion matrix and ROC curve can be stored as images
        roc_numpy = metrics["roc_figure"]
        writer_val.add_image("ROC", roc_numpy, step, dataformats='HWC')
        confusion_matrix_numpy = metrics["confusion_matrix_numpy"]
        writer_val.add_image("ConfusionMatrix", confusion_matrix_numpy, step, dataformats='HWC')
        writer_val.flush()
        #"""

    # Note: since PyTorch 1.2, scheduler.step() should be called after optimizer.step()
    @trainer.on(Events.EPOCH_COMPLETED)  # _STARTED
    def adjust_learning_rate(engine):
        """
        # if (engine.state.epoch - 1) % engine.state.epochs_traverse_optimizers == 0:
        if engine.state.epoch == 2:
            op_i_scheduler1 = WarmupMultiStepLR(optimizers[0], cfg.SOLVER.SCHEDULER.STEPS,
                                                cfg.SOLVER.SCHEDULER.GAMMA,
                                                cfg.SOLVER.SCHEDULER.WARMUP_FACTOR,
                                                cfg.SOLVER.SCHEDULER.WARMUP_ITERS,
                                                cfg.SOLVER.SCHEDULER.WARMUP_METHOD)
            op_i_scheduler2 = WarmupMultiStepLR(optimizers[1], cfg.SOLVER.SCHEDULER.STEPS,
                                                cfg.SOLVER.SCHEDULER.GAMMA,
                                                cfg.SOLVER.SCHEDULER.WARMUP_FACTOR,
                                                cfg.SOLVER.SCHEDULER.WARMUP_ITERS,
                                                cfg.SOLVER.SCHEDULER.WARMUP_METHOD)
            engine.state.schedulers = [op_i_scheduler1, op_i_scheduler2]
            print("copy")
        """
        schedulers[engine.state.schedulers_epochs_index][engine.state.optimizer_index].step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % (log_period * accumulation_steps) == 0:
            step = engine.state.iteration

            # Write the train summary
            # Record avg precision
            avg_precision = engine.state.metrics['avg_precision'].numpy().tolist()
            avg_precisions = {}
            ap_sum = 0
            for index, ap in enumerate(avg_precision):
                avg_precisions[index] = float("{:.2f}".format(ap))
                ap_sum += avg_precisions[index]
                scalarDict = {}
                for i in range(len(optimizers)):
                    if i != engine.state.optimizer_index:
                        scalarDict["optimizer" + str(i)] = 0
                    else:
                        scalarDict["optimizer" + str(i)] = avg_precisions[index]
                    writer_train[i].add_scalar("Precision/" + str(index),
                                               scalarDict["optimizer" + str(i)], step)
                    writer_train[i].flush()
            avg_precisions["avg_precision"] = float("{:.2f}".format(ap_sum / len(avg_precision)))

            # Record avg losses
            avg_losses = {}
            for lossName in lossKeys:
                avg_losses[lossName] = float("{:.3f}".format(
                    engine.state.metrics["AVG-" + lossName]))
                scalarDict = {}
                for i in range(len(optimizers)):
                    if i != engine.state.optimizer_index:
                        scalarDict["optimizer" + str(i)] = 0
                    else:
                        scalarDict["optimizer" + str(i)] = avg_losses[lossName]
                    writer_train[i].add_scalar("Loss/" + lossName,
                                               scalarDict["optimizer" + str(i)], step)
                    writer_train[i].flush()

            # Record the remaining scalars
            scalar_list = ["avg_accuracy", "avg_total_loss"]
            for scalar in scalar_list:
                scalarDict = {}
                for i in range(len(optimizers)):
                    if i != engine.state.optimizer_index:
                        scalarDict["optimizer" + str(i)] = 0
                    else:
                        scalarDict["optimizer" + str(i)] = engine.state.metrics[scalar]
                    writer_train[i].add_scalar("Train/" + scalar,
                                               scalarDict["optimizer" + str(i)], step)
                    writer_train[i].flush()

            # Record the learning rate
            LearningRateDict = {}
            for i in range(len(optimizers)):
                if i != engine.state.optimizer_index:
                    LearningRateDict["optimizer" + str(i)] = 0
                else:
                    LearningRateDict["optimizer" + str(i)] = \
                        schedulers[engine.state.schedulers_epochs_index][
                            engine.state.optimizer_index].get_lr()[0]
                writer_train[i].add_scalar("Train/" + "LearningRate",
                                           LearningRateDict["optimizer" + str(i)], step)
                writer_train[i].flush()

            # Record weights
            choose_list = ["base.conv1.weight", "base.bn1.weight",
                           "base.layer1.0.conv1.weight", "base.layer1.2.conv3.weight",
                           "base.layer2.0.conv1.weight", "base.layer2.3.conv3.weight",
                           "base.layer3.0.conv1.weight", "base.layer3.5.conv3.weight",
                           "base.layer4.0.conv1.weight", "base.layer4.2.conv1.weight",
                           "bottleneck.weight", "classifier.weight"]
            """
            # Recording parameter distributions is very slow
            params_dict = {}
            for name, parameters in model.named_parameters():
                # print(name, ':', parameters.size())
                params_dict[name] = parameters.detach().cpu().numpy()
            # print(len(params_dict))
            for cp in params_dict.keys():
                writer_train["graph"].add_histogram("Train/" + cp, params_dict[cp], step)
            writer_train["graph"].flush()
            #"""

            logger.info("Epoch[{}] Iteration[{}/{}] Training {} - ATLoss: {:.3f}, "
                        "AvgLoss: {}, Avg Pre: {}, Avg_Acc: {:.3f}, Base Lr: {:.2e}, step: {}"
                        .format(engine.state.epoch, ITER, len(train_loader),
                                engine.state.losstype,
                                engine.state.metrics['avg_total_loss'],
                                avg_losses, avg_precisions,
                                engine.state.metrics['avg_accuracy'],
                                schedulers[engine.state.schedulers_epochs_index][
                                    engine.state.optimizer_index].get_lr()[0],
                                step))
            # logger.info(engine.state.output["rf_loss"])

            if engine.state.output["losses"].get("cluster_loss") is not None:
                logger.info("Epoch[{}] Iteration[{}/{}] Center {} \n r_inter: {}, r_outer: {}, step: {}"
                            .format(engine.state.epoch, ITER, len(train_loader),
                                    engine.state.output["losses"]["cluster_loss"][-1]["center"].cpu().detach().numpy(),
                                    engine.state.output["losses"]["cluster_loss"][-1]["r_inter"].item(),
                                    engine.state.output["losses"]["cluster_loss"][-1]["r_outer"].item(),
                                    step))

        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch,
                            timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            metrics = do_inference(cfg, model, val_loader, classes_list, loss_fn)
            step = engine.state.iteration
            for preKey in metrics['precision'].keys():
                writer_val.add_scalar("Precision/" + str(preKey),
                                      metrics['precision'][preKey], step)
            for recKey in metrics['recall'].keys():
                writer_val.add_scalar("Recall/" + str(recKey),
                                      metrics['recall'][recKey], step)
            for aucKey in metrics['roc_auc'].keys():
                writer_val.add_scalar("ROC_AUC/" + str(aucKey),
                                      metrics['roc_auc'][aucKey], step)
            writer_val.add_scalar("OverallAccuracy", metrics["overall_accuracy"], step)
            # writer.add_scalar("Val/"+"confusion_matrix", metrics['confusion_matrix'], step)
            # The confusion matrix and ROC curve can be stored as images
            roc_numpy = metrics["roc_figure"]
            writer_val.add_image("ROC", roc_numpy, step, dataformats='HWC')
            confusion_matrix_numpy = metrics["confusion_matrix_numpy"]
            writer_val.add_image("ConfusionMatrix", confusion_matrix_numpy, step,
                                 dataformats='HWC')
            writer_val.flush()

    # 5. Run the engine
    trainer.run(train_loader, max_epochs=epochs)

    for key in writer_train.keys():
        writer_train[key].close()
    writer_val.close()
def do_train_with_center(
        cfg,
        model,
        center_criterion,
        train_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,
        loss_fn,
        num_query,
        start_epoch
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer_with_center(model, center_criterion,
                                                    optimizer, optimizer_center, loss_fn,
                                                    cfg.SOLVER.CENTER_LOSS_WEIGHT,
                                                    device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
        device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period,
                                   n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
    #                                                                  'optimizer': optimizer.state_dict(),
    #                                                                  'optimizer_center': optimizer_center.state_dict()})
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {'model': model,
                               'optimizer': optimizer,
                               'centerloss': center_criterion,
                               'optimizer_center': optimizer_center})
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on the trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, ITER, len(train_loader),
                                engine.state.metrics['avg_loss'],
                                engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch,
                            timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
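# Note (illustrative): the two RunningAverage transforms above assume the
# process function of `create_supervised_trainer_with_center` returns a
# (loss, acc) pair; `x[0]` / `x[1]` select the element to average. The
# commented-out checkpointer registration shows the state_dict pitfall noted
# earlier: the active registration passes the live objects instead.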
def get_trainer(model, optimizer, lr_scheduler=None, logger=None, writer=None,
                non_blocking=False, log_period=10, save_dir="checkpoints",
                prefix="model", gallery_loader=None, query_loader=None,
                eval_interval=None, dataset="sysu"):
    if logger is None:
        logger = logging.getLogger()
        logger.setLevel(logging.WARN)

    # trainer
    trainer = create_train_engine(model, optimizer, non_blocking)

    # checkpoint handler
    handler = ModelCheckpoint(save_dir, prefix, save_interval=eval_interval,
                              n_saved=3, create_dir=True,
                              save_as_state_dict=True, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})

    # metric
    timer = Timer(average=True)
    kv_metric = AutoKVMetric()

    # evaluator
    evaluator = None
    if not isinstance(eval_interval, int):
        raise TypeError("The parameter 'eval_interval' must be an int.")
    if eval_interval > 0 and gallery_loader is not None and query_loader is not None:
        evaluator = create_eval_engine(model, non_blocking)

    @trainer.on(Events.EPOCH_STARTED)
    def epoch_started_callback(engine):
        kv_metric.reset()
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def epoch_completed_callback(engine):
        epoch = engine.state.epoch

        if lr_scheduler is not None:
            lr_scheduler.step()

        if epoch % eval_interval == 0:
            logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch))

        if evaluator and epoch % eval_interval == 0:
            torch.cuda.empty_cache()

            # extract query features
            evaluator.run(query_loader)
            q_feats = torch.cat(evaluator.state.feat_list, dim=0)
            q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
            q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()

            # extract gallery features
            evaluator.run(gallery_loader)
            g_feats = torch.cat(evaluator.state.feat_list, dim=0)
            g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
            g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
            g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)

            if dataset == "sysu":
                perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root,
                                                'exp', 'rand_perm_cam.mat'))['rand_perm_cam']
                mAP, r1, r5, _, _ = eval_sysu(q_feats, q_ids, q_cams,
                                              g_feats, g_ids, g_cams, g_img_paths, perm)
            else:
                mAP, r1, r5, _, _ = eval_regdb(q_feats, q_ids, q_cams,
                                               g_feats, g_ids, g_cams)

            if writer is not None:
                writer.add_scalar('eval/mAP', mAP, epoch)
                writer.add_scalar('eval/r1', r1, epoch)
                writer.add_scalar('eval/r5', r5, epoch)

            evaluator.state.feat_list.clear()
            evaluator.state.id_list.clear()
            evaluator.state.cam_list.clear()
            evaluator.state.img_path_list.clear()
            del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams

            torch.cuda.empty_cache()

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_complete_callback(engine):
        timer.step()
        kv_metric.update(engine.state.output)

        epoch = engine.state.epoch
        iteration = engine.state.iteration
        iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader)

        if iter_in_epoch % log_period == 0:
            batch_size = engine.state.batch[0].size(0)
            speed = batch_size / timer.value()
            msg = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (epoch, iter_in_epoch, speed)
            metric_dict = kv_metric.compute()

            # log output information
            if logger is not None:
                for k in sorted(metric_dict.keys()):
                    msg += "\t%s: %.4f" % (k, metric_dict[k])
                    if writer is not None:
                        writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration)
                logger.info(msg)

            kv_metric.reset()
            timer.reset()

    return trainer
def run(opt):
    # logging.basicConfig(filename=os.path.join(opt.log_dir, opt.log_file), level=logging.INFO)
    # logger = logging.getLogger()
    # # logger.addHandler(logging.StreamHandler())
    # logger = logger.info
    log = Logger(filename=os.path.join(opt.log_dir, opt.log_file), level='debug')
    logger = log.logger.info

    # Decide what attrs to train
    attr, attr_name = get_tasks(opt)

    # Generate model based on tasks
    logger('Loading models')
    model, parameters, mean, std = generate_model(opt, attr)
    # parameters[0]['lr'] = 0
    # parameters[1]['lr'] = opt.lr / 3

    logger('Loading dataset')
    train_loader, val_loader = get_data(opt, attr, mean, std)
    writer = create_summary_writer(model, train_loader, opt.log_dir)
    # has to come after the writer
    model = nn.DataParallel(model, device_ids=None)

    # Learning configurations
    if opt.optimizer == 'sgd':
        optimizer = SGD(parameters, lr=opt.lr, momentum=opt.momentum,
                        weight_decay=opt.weight_decay, nesterov=opt.nesterov)
    elif opt.optimizer == 'adam':
        optimizer = Adam(parameters, lr=opt.lr, betas=opt.betas)
    else:
        raise Exception("Not supported")
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max',
                                               patience=opt.lr_patience,
                                               factor=opt.factor, min_lr=1e-6)

    # Loading checkpoint
    if opt.checkpoint:
        logger('loading checkpoint {}'.format(opt.checkpoint))
        checkpoint = torch.load(opt.checkpoint)
        opt.begin_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    device = 'cuda'
    loss_fns, metrics = get_losses_metrics(attr, opt.categorical_loss, opt.at, opt.at_loss)
    trainer = my_trainer(
        model, optimizer,
        lambda pred, target, epoch: multitask_loss(pred, target, loss_fns,
                                                   len(attr_name), opt.at_coe, epoch),
        device=device)
    train_evaluator = create_supervised_evaluator(
        model, metrics={'multitask': MultiAttributeMetric(metrics, attr_name)},
        device=device)
    val_evaluator = create_supervised_evaluator(
        model, metrics={'multitask': MultiAttributeMetric(metrics, attr_name)},
        device=device)

    # Training timer handlers
    model_timer, data_timer = Timer(average=True), Timer(average=True)
    model_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,
                       resume=Events.ITERATION_STARTED,
                       pause=Events.ITERATION_COMPLETED,
                       step=Events.ITERATION_COMPLETED)
    data_timer.attach(trainer,
                      start=Events.EPOCH_STARTED,
                      resume=Events.ITERATION_COMPLETED,
                      pause=Events.ITERATION_STARTED,
                      step=Events.ITERATION_STARTED)

    # Training log/plot handlers
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter_num = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter_num % opt.log_interval == 0:
            logger("Epoch[{}] Iteration[{}/{}] Sum Loss: {:.2f} Cls Loss: {:.2f} At Loss: {:.2f} "
                   "Coe: {:.2f} Model Process: {:.3f}s/batch Data Preparation: {:.3f}s/batch"
                   .format(engine.state.epoch, iter_num, len(train_loader),
                           engine.state.output['sum'], engine.state.output['cls'],
                           engine.state.output['at'], engine.state.output['coe'],
                           model_timer.value(), data_timer.value()))
            writer.add_scalar("training/loss", engine.state.output['sum'],
                              engine.state.iteration)

    # Log/Plot learning rate
    @trainer.on(Events.EPOCH_STARTED)
    def log_learning_rate(engine):
        lr = optimizer.param_groups[-1]['lr']
        logger('Epoch[{}] Starts with lr={}'.format(engine.state.epoch, lr))
        writer.add_scalar("learning_rate", lr, engine.state.epoch)

    # Checkpointing
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        if engine.state.epoch % opt.save_interval == 0:
            save_file_path = os.path.join(opt.log_dir,
                                          'save_{}.pth'.format(engine.state.epoch))
            states = {
                'epoch': engine.state.epoch,
                'arch': opt.model,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_file_path)
            # model.eval()
            # example = torch.rand(1, 3, 224, 224)
            # traced_script_module = torch.jit.trace(model, example)
            # traced_script_module.save(save_file_path)
            # model.train()
            # torch.save(model._modules.state_dict(), save_file_path)

    # val_evaluator event handlers
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        data_list = [train_loader, val_loader]
        name_list = ['train', 'val']
        eval_list = [train_evaluator, val_evaluator]

        for data, name, evl in zip(data_list, name_list, eval_list):
            evl.run(data)
            metrics_info = evl.state.metrics["multitask"]

            for m, val in metrics_info['metrics'].items():
                writer.add_scalar(name + '_metrics/{}'.format(m), val, engine.state.epoch)
            for m, val in metrics_info['summaries'].items():
                writer.add_scalar(name + '_summary/{}'.format(m), val, engine.state.epoch)

            logger(name + ": Validation Results - Epoch: {}".format(engine.state.epoch))
            print_summar_table(logger, attr_name, metrics_info['logger'])

            # Update learning rate
            if name == 'train':
                scheduler.step(metrics_info['logger']['attr']['ap'][-1])

    # kick everything off
    logger('Start training')
    trainer.run(train_loader, max_epochs=opt.n_epochs)

    writer.close()
def main(config, needs_save, study_name, k, n_splits, output_dir_path): if config.run.visible_devices: os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices seed = check_manual_seed(config.run.seed) print('Using seed: {}'.format(seed)) train_data_loader, test_data_loader, data_train = get_k_hold_data_loader( config.dataset, k=k, n_splits=n_splits, ) data_train = torch.from_numpy(data_train).float().cuda(non_blocking=True) data_train = torch.t(data_train) model = get_model(config.model) model.cuda() model = nn.DataParallel(model) criterion = nn.CrossEntropyLoss() if config.optimizer.optimizer_name == 'Adam': optimizer = optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), config.optimizer.lr, [0.9, 0.9999], weight_decay=config.optimizer.weight_decay, ) else: raise NotImplementedError # scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch) def update(engine, batch): model.train() x = batch['data'].float().cuda(non_blocking=True) y = batch['label'].long().cuda(non_blocking=True) if config.run.transposed_matrix == 'overall': x_t = data_train elif config.run.transposed_matrix == 'batch': x_t = torch.t(x) def closure(): optimizer.zero_grad() if 'MLP' in config.model.model_name: out, x_hat = model(x) else: out, x_hat = model(x, x_t) l_discriminative = criterion(out, y) l_feature = torch.tensor(0.0).cuda() if config.run.w_feature_selection: l_feature += config.run.w_feature_selection * torch.sum(torch.abs(model.module.Ue)) l_recon = torch.tensor(0.0).cuda() if config.run.w_reconstruction: l_recon += config.run.w_reconstruction * F.mse_loss(x, x_hat) l_total = l_discriminative + l_feature + l_recon l_total.backward() return l_total, l_discriminative, l_feature, l_recon, out l_total, l_discriminative, l_feature, l_recon, out = optimizer.step(closure) metrics = calc_metrics(out, y) metrics.update({ 'l_total': l_total.item(), 'l_discriminative': l_discriminative.item(), 'l_feature': l_feature.item(), 'l_recon': l_recon.item(), }) torch.cuda.synchronize() return metrics def inference(engine, batch): model.eval() x = batch['data'].float().cuda(non_blocking=True) y = batch['label'].long().cuda(non_blocking=True) if config.run.transposed_matrix == 'overall': x_t = data_train elif config.run.transposed_matrix == 'batch': x_t = torch.t(x) with torch.no_grad(): if 'MLP' in config.model.model_name: out, x_hat = model(x) else: out, x_hat = model(x, x_t) l_discriminative = criterion(out, y) l_feature = torch.tensor(0.0).cuda() if config.run.w_feature_selection: l_feature += config.run.w_feature_selection * torch.sum(torch.abs(model.module.Ue)) l_recon = torch.tensor(0.0).cuda() if config.run.w_reconstruction: l_recon += config.run.w_reconstruction * F.mse_loss(x, x_hat) l_total = l_discriminative + l_feature + l_recon metrics = calc_metrics(out, y) metrics.update({ 'l_total': l_total.item(), 'l_discriminative': l_discriminative.item(), 'l_feature': l_feature.item(), 'l_recon': l_recon.item(), }) torch.cuda.synchronize() return metrics trainer = Engine(update) evaluator = Engine(inference) timer = Timer(average=True) monitoring_metrics = ['l_total', 'l_discriminative', 'l_feature', 'l_recon', 'accuracy'] for metric in monitoring_metrics: RunningAverage( alpha=0.98, output_transform=partial(lambda x, metric: x[metric], metric=metric) ).attach(trainer, metric) for metric in monitoring_metrics: RunningAverage( alpha=0.98, output_transform=partial(lambda x, metric: x[metric], metric=metric) ).attach(evaluator, metric) pbar = ProgressBar() pbar.attach(trainer, 
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def events_started(engine):
        if needs_save:
            save_config(config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def switch_training_to_evaluation(engine):
        if needs_save:
            save_logs('train', k, n_splits, trainer,
                      trainer.state.epoch, trainer.state.iteration,
                      config, output_dir_path)
        evaluator.run(test_data_loader, max_epochs=1)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def switch_evaluation_to_training(engine):
        if needs_save:
            save_logs('val', k, n_splits, evaluator,
                      trainer.state.epoch, trainer.state.iteration,
                      config, output_dir_path)
            if trainer.state.epoch % 100 == 0:
                save_models(model, optimizer, k, n_splits,
                            trainer.state.epoch, trainer.state.iteration,
                            config, output_dir_path)
        # scheduler.step()

    @trainer.on(Events.EPOCH_COMPLETED)
    @evaluator.on(Events.EPOCH_COMPLETED)
    def show_logs(engine):
        columns = ['k', 'n_splits', 'epoch', 'iteration'] + list(engine.state.metrics.keys())
        values = [str(k), str(n_splits), str(engine.state.epoch), str(engine.state.iteration)] \
            + [str(value) for value in engine.state.metrics.values()]

        message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
            epoch=engine.state.epoch,
            max_epoch=config.run.n_epochs,
            i=engine.state.iteration,
            max_i=len(train_data_loader))
        for name, value in zip(columns, values):
            message += ' | {name}: {value}'.format(name=name, value=value)
        pbar.log_message(message)

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, config.run.n_epochs * len(train_data_loader)))

    trainer.run(train_data_loader, config.run.n_epochs)
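# A minimal sketch of the metric plumbing used above: an Engine whose update
# function returns a dict of values, a RunningAverage smoothing one entry, and
# a ProgressBar rendering it. The constant loss and dummy data are stand-in
# assumptions for illustration only.
#
# from ignite.engine import Engine
# from ignite.metrics import RunningAverage
# from ignite.contrib.handlers import ProgressBar
#
# def update(engine, batch):
#     return {'l_total': 0.5}               # placeholder training step
#
# trainer = Engine(update)
# RunningAverage(alpha=0.98,
#                output_transform=lambda x: x['l_total']).attach(trainer, 'l_total')
# ProgressBar().attach(trainer, metric_names=['l_total'])
# trainer.run([0] * 10, max_epochs=1)       # bar shows the smoothed 'l_total'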
def run(path, model_name, imgaugs, train_batch_size, val_batch_size,
        num_workers, epochs, optim, lr, lr_update_every, gamma, restart_every,
        restart_factor, init_lr_factor, lr_reduce_patience, early_stop_patience,
        log_interval, output, debug):

    print("--- Cifar10 Playground : Training ---")

    from datetime import datetime
    now = datetime.now()
    log_dir = os.path.join(
        output, "training_{}_{}".format(model_name, now.strftime("%Y%m%d_%H%M")))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    log_level = logging.INFO
    if debug:
        log_level = logging.DEBUG
        print("Activated debug mode")

    logger = logging.getLogger("Cifar10 Playground: Train")
    setup_logger(logger, log_dir, log_level)

    logger.debug("Setup tensorboard writer")
    writer = SummaryWriter(log_dir=os.path.join(log_dir, "tensorboard"))

    save_conf(logger, writer, model_name, imgaugs,
              train_batch_size, val_batch_size, num_workers, epochs,
              optim, lr, lr_update_every, gamma,
              restart_every, restart_factor, init_lr_factor,
              lr_reduce_patience, early_stop_patience, log_dir)

    device = 'cpu'
    if torch.cuda.is_available():
        logger.debug("CUDA is enabled")
        from torch.backends import cudnn
        cudnn.benchmark = True
        device = 'cuda'

    logger.debug("Setup model: {}".format(model_name))
    if not os.path.isfile(model_name):
        assert model_name in MODEL_MAP, "Model name not in {}".format(MODEL_MAP.keys())
        model = MODEL_MAP[model_name](num_classes=10)
    else:
        model = torch.load(model_name)
        model_name = model.__class__.__name__

    if 'cuda' in device:
        model = model.to(device)

    logger.debug("Setup train/val dataloaders")
    train_loader, val_loader = get_data_loaders(path, imgaugs,
                                                train_batch_size, val_batch_size,
                                                num_workers, device=device)

    write_model_graph(writer, model, train_loader, device=device)

    logger.debug("Setup optimizer")
    assert optim in OPTIMIZER_MAP, "Optimizer name not in {}".format(OPTIMIZER_MAP.keys())
    optimizer = OPTIMIZER_MAP[optim](model.parameters(), lr=lr)

    logger.debug("Setup criterion")
    criterion = nn.CrossEntropyLoss()
    if 'cuda' in device:
        criterion = criterion.cuda()

    lr_scheduler = ExponentialLR(optimizer, gamma=gamma)
    lr_scheduler_restarts = LRSchedulerWithRestart(lr_scheduler,
                                                   restart_every=restart_every,
                                                   restart_factor=restart_factor,
                                                   init_lr_factor=init_lr_factor)
    reduce_on_plateau = ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                          patience=lr_reduce_patience,
                                          threshold=0.01, verbose=True)

    logger.debug("Setup ignite trainer and evaluator")
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    metrics = {
        'accuracy': CategoricalAccuracy(),
        'precision': Precision(),
        'recall': Recall(),
        'nll': Loss(criterion)
    }
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    logger.debug("Setup handlers")
    # Setup timer to measure training time
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iteration = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration % log_interval == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.4f}".format(
                engine.state.epoch, iteration, len(train_loader), engine.state.output))
            writer.add_scalar("training/loss_vs_iterations",
                              engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_STARTED)
    def update_lr_schedulers(engine):
        if (engine.state.epoch - 1) % lr_update_every == 0:
            lr_scheduler_restarts.step()
    @trainer.on(Events.EPOCH_STARTED)
    def log_lrs(engine):
        if len(optimizer.param_groups) == 1:
            lr = float(optimizer.param_groups[0]['lr'])
            writer.add_scalar("learning_rate", lr, engine.state.epoch)
            logger.debug("Learning rate: {}".format(lr))
        else:
            for i, param_group in enumerate(optimizer.param_groups):
                lr = float(param_group['lr'])
                logger.debug("Learning rate (group {}): {}".format(i, lr))
                writer.add_scalar("learning_rate_group_{}".format(i), lr, engine.state.epoch)

    log_images_dir = os.path.join(log_dir, "figures")
    os.makedirs(log_images_dir)

    def log_precision_recall_results(metrics, epoch, mode):
        for metric_name in ['precision', 'recall']:
            value = metrics[metric_name]
            avg_value = torch.mean(value).item()
            writer.add_scalar("{}/avg_{}".format(mode, metric_name), avg_value, epoch)

            # Save metric-per-class figure
            sorted_values = value.to('cpu').numpy()
            indices = np.argsort(sorted_values)
            sorted_values = sorted_values[indices]
            n_classes = len(sorted_values)
            classes = np.array(["class_{}".format(i) for i in range(n_classes)])
            sorted_classes = classes[indices]
            fig = create_fig_param_per_class(sorted_values, metric_name,
                                             classes=sorted_classes,
                                             n_classes_per_fig=20)
            fname = os.path.join(
                log_images_dir, "{}_{}_{}_per_class.png".format(mode, epoch, metric_name))
            fig.savefig(fname)
            tag = "{}_{}".format(mode, metric_name)
            writer.add_figure(tag, fig, epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_metrics(engine):
        epoch = engine.state.epoch
        logger.info("One epoch training time (seconds): {}".format(timer.value()))
        metrics = train_evaluator.run(train_loader).metrics
        logger.info(
            "Training Results - Epoch: {} Avg accuracy: {:.4f} Avg loss: {:.4f}"
            .format(epoch, metrics['accuracy'], metrics['nll']))
        writer.add_scalar("training/avg_accuracy", metrics['accuracy'], epoch)
        writer.add_scalar("training/avg_error", 1.0 - metrics['accuracy'], epoch)
        writer.add_scalar("training/avg_loss", metrics['nll'], epoch)
        log_precision_recall_results(metrics, epoch, "training")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        epoch = engine.state.epoch
        metrics = val_evaluator.run(val_loader).metrics
        writer.add_scalar("validation/avg_loss", metrics['nll'], epoch)
        writer.add_scalar("validation/avg_accuracy", metrics['accuracy'], epoch)
        writer.add_scalar("validation/avg_error", 1.0 - metrics['accuracy'], epoch)
        logger.info(
            "Validation Results - Epoch: {} Avg accuracy: {:.4f} Avg loss: {:.4f}"
            .format(epoch, metrics['accuracy'], metrics['nll']))
        log_precision_recall_results(metrics, epoch, "validation")

    @val_evaluator.on(Events.COMPLETED)
    def update_reduce_on_plateau(engine):
        val_loss = engine.state.metrics['nll']
        reduce_on_plateau.step(val_loss)

    def score_function(engine):
        val_loss = engine.state.metrics['nll']
        # Objects with the highest scores are retained, so negate the loss.
        return -val_loss

    # Setup early stopping:
    handler = EarlyStopping(patience=early_stop_patience,
                            score_function=score_function,
                            trainer=trainer)
    setup_logger(handler._logger, log_dir, log_level)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)

    # Setup model checkpoint:
    best_model_saver = ModelCheckpoint(log_dir,
                                       filename_prefix="model",
                                       score_name="val_loss",
                                       score_function=score_function,
                                       n_saved=5,
                                       atomic=True,
                                       create_dir=True)
    val_evaluator.add_event_handler(Events.COMPLETED, best_model_saver, {model_name: model})

    last_model_saver = ModelCheckpoint(log_dir,
                                       filename_prefix="checkpoint",
                                       save_interval=1,
                                       n_saved=1,
                                       atomic=True,
                                       create_dir=True)
    trainer.add_event_handler(Events.COMPLETED, last_model_saver, {model_name: model})

    logger.info("Start training: {} epochs".format(epochs))
    try:
        trainer.run(train_loader, max_epochs=epochs)
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt -> exit")
    except Exception:  # noqa
        logger.exception("")
        if debug:
            try:
                # Open an IPython shell if possible
                import IPython
                IPython.embed()  # noqa
            except ImportError:
                print("Failed to start IPython console")

    logger.debug("Training is ended")
    writer.close()
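# A minimal sketch of the early-stopping wiring used above: the score function
# negates the validation loss so that "higher is better", and the handler stops
# the trainer once the score stops improving for `patience` evaluations. The
# two bare engines and the fallback metric value are stand-in assumptions.
#
# from ignite.engine import Engine, Events
# from ignite.handlers import EarlyStopping
#
# trainer = Engine(lambda engine, batch: None)     # stand-in training engine
# evaluator = Engine(lambda engine, batch: None)   # stand-in validation engine
#
# def score_function(engine):
#     # Negate: a lower loss must yield a higher score for EarlyStopping.
#     return -engine.state.metrics.get('nll', 1.0)
#
# handler = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)
# evaluator.add_event_handler(Events.COMPLETED, handler)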
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                        device=device,
                                        gamma=cfg.MODEL.GAMMA,
                                        margin=cfg.SOLVER.MARGIN,
                                        beta=cfg.MODEL.BETA)
    if cfg.TEST.PAIR == "no":
        evaluator = create_supervised_evaluator(
            model,
            metrics={'r1_mAP': R1_mAP(1, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
            device=device)
    elif cfg.TEST.PAIR == "yes":
        evaluator = create_supervised_evaluator(
            model,
            metrics={'r1_mAP': R1_mAP_pair(1, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
            device=device)

    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period,
                                   n_saved=10, require_empty=False)
    # checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {'model': model, 'optimizer': optimizer})
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1
        if ITER % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, ITER, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        # multi_person_training_info2()
        train_loader, val_loader, num_query, num_classes = make_data_loader_train(cfg)
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # if engine.state.epoch % eval_period == 0:
        if engine.state.epoch >= eval_period:
            all_cmc = []
            all_AP = []
            num_valid_q = 0
            q_pids = []
            for query_index in tqdm(range(num_query)):
                val_loader = make_data_loader_val(cfg, query_index, dataset)
                evaluator.run(val_loader)
                cmc, AP, q_pid = evaluator.state.metrics['r1_mAP']
                if AP >= 0:
                    # Skip queries whose gallery is too small for a full CMC curve.
                    if cmc.shape[0] < 50:
                        continue
                    num_valid_q += 1
                    all_cmc.append(cmc)
                    all_AP.append(AP)
                    q_pids.append(int(q_pid))
                else:
                    continue
            all_cmc = np.asarray(all_cmc).astype(np.float32)
            cmc = all_cmc.sum(0) / num_valid_q
            mAP = np.mean(all_AP)
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
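# A minimal sketch of periodic loss logging without the module-level ITER
# counter used above: the within-epoch iteration can be recovered from
# engine.state.iteration instead, which avoids global state. The dummy data
# and trivial process function are stand-in assumptions.
#
# from ignite.engine import Engine, Events
#
# trainer = Engine(lambda engine, batch: float(batch))  # output = the "loss"
# log_period = 2
# data = [0.1, 0.2, 0.3, 0.4]
#
# @trainer.on(Events.ITERATION_COMPLETED)
# def log_loss(engine):
#     it = (engine.state.iteration - 1) % len(data) + 1  # 1-based, per epoch
#     if it % log_period == 0:
#         print("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}".format(
#             engine.state.epoch, it, len(data), engine.state.output))
#
# trainer.run(data, max_epochs=1)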
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, output_dir, saved_optimizer, warmup):

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )
    model = model.to(device)

    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction="none")
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")
        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, "glow", save_interval=1,
                                         n_saved=2, require_empty=False)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(trainer, "total_loss")

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)
    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        scheduler.step()
        metrics = evaluator.state.metrics
        losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]")
        timer.reset()

    trainer.run(train_loader, epochs)
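# A minimal sketch of the data-dependent initialisation pattern above:
# concatenate the first few batches and run one forward pass under no_grad so
# layers can initialise from data statistics (Glow's actnorm layers do this on
# their first forward). The toy linear model and fake loader are stand-in
# assumptions.
#
# from itertools import islice
# import torch
#
# loader = [torch.randn(8, 3) for _ in range(10)]   # stand-in DataLoader
# n_init_batches = 4
# model = torch.nn.Linear(3, 3)                     # stand-in for Glow
#
# with torch.no_grad():
#     init_batches = torch.cat(list(islice(loader, n_init_batches)))
#     assert init_batches.shape[0] == n_init_batches * 8
#     model(init_batches)                           # one init-only forward pass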
def run(cfg, train_loader, tr_comp, saver, trainer, evaler, tb_log=None):
    # saver.checkpoint_params = {'trainer': trainer,
    #                            'model': tr_comp.model,
    #                            'optimizer': tr_comp.optimizer,
    #                            'center_param': tr_comp.loss_center,
    #                            'optimizer_center': tr_comp.optimizer_center}
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=cfg.SAVER.CHECKPOINT_PERIOD),
        saver.train_checkpointer,
        saver.checkpoint_params)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metrics to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss.loss_function_map.keys())
    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    # TODO start epoch
    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = 0

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        tr_comp.loss.scheduler_step()
        tr_comp.scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Base Lr: {tr_comp.scheduler.get_last_lr()[0]:.2e}, "

        if tr_comp.loss.xent and tr_comp.loss.xent.learning_weight:
            message += f"xentWeight: {tr_comp.loss.xent.uncertainty.mean().item():.4f}, "

        for loss_name in engine.state.metrics.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "
        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD), saver=saver)
    def log_validation_results(engine, saver):
        logger.info(f"Valid - Epoch: {engine.state.epoch}")
        sum_result = evaler.eval_multi_dataset()
        if saver.best_result < sum_result:
            logger.info(f'Save best: {sum_result:.4f}')
            saver.save_best_value(sum_result)
            saver.best_checkpointer(engine, saver.checkpoint_params)
            saver.best_result = sum_result
        else:
            logger.info(f"Not best: {saver.best_result:.4f} > {sum_result:.4f}")
        logger.info('-' * 80)

    if tb_log:
        tb_log.attach_handler(trainer, tr_comp.model, tr_comp.optimizer)
        # self.tb_logger.attach(
        #     validation_evaluator,
        #     log_handler=ReIDOutputHandler(tag="valid", metric_names=["r1_mAP"], another_engine=trainer),
        #     event_name=Events.EPOCH_COMPLETED,
        # )

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)

    if tb_log:
        tb_log.close()
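# A minimal sketch of the Timer wiring that recurs across the scripts above:
# pausing on each iteration boundary makes timer.value() the average time per
# batch, so throughput is batch_size / timer.value(). The sleep is a stand-in
# workload and the dummy data is an assumption.
#
# import time
# from ignite.engine import Engine, Events
# from ignite.handlers import Timer
#
# trainer = Engine(lambda engine, batch: time.sleep(0.01))  # fake training step
# timer = Timer(average=True)
# timer.attach(trainer,
#              start=Events.EPOCH_STARTED,
#              resume=Events.ITERATION_STARTED,
#              pause=Events.ITERATION_COMPLETED,
#              step=Events.ITERATION_COMPLETED)
#
# @trainer.on(Events.EPOCH_COMPLETED)
# def report(engine):
#     print('Time per batch: {:.4f}[s]'.format(timer.value()))
#     timer.reset()
#
# trainer.run(range(5), max_epochs=1)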