def main(args: Optional[argparse.Namespace] = None) -> None:
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    if not os.path.exists(args.json):
        save_default_train_options(args.json)
        return
    with open(args.json, 'r') as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    add_logging_file_handler(Path(args.path_save_dir, 'train_model.log'))
    logger.info('Started training at: %s', datetime.datetime.now())
    set_seeds(args.seed)
    log_training_options(vars(args))

    path_model = os.path.join(args.path_save_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, args.json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_train', 'loss_val'])

    bpds_train = get_bpds_train(args)
    bpds_val = get_bpds_val(args)
    for idx_iter in range(model.count_iter, args.n_iter):
        x_batch, y_batch = bpds_train.get_batch(args.batch_size)
        do_save = ((idx_iter + 1) % args.interval_save == 0) or \
            ((idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(x_batch, y_batch)
        loss_val = None
        if do_save and bpds_val is not None:
            loss_val = model.test_on_iterator(
                [bpds_val.get_batch(args.batch_size) for _ in range(4)])
        fnetlogger.add({
            'num_iter': idx_iter + 1,
            'loss_train': loss_train,
            'loss_val': loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                'BufferedPatchDataset buffer history: %s',
                bpds_train.get_buffer_history(),
            )
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() - time_start))
        if ((idx_iter + 1) in args.iter_checkpoint) or \
           ((idx_iter + 1) % args.interval_checkpoint == 0):
            path_checkpoint = os.path.join(
                args.path_save_dir,
                'checkpoints',
                'model_{:06d}.p'.format(idx_iter + 1),
            )
            model.save(path_checkpoint)
            logger.info('Saved model checkpoint: %s', path_checkpoint)
    vu.plot_loss(
        args.path_save_dir,
        path_save=os.path.join(args.path_save_dir, 'loss_curves.png'),
    )
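
# Hedged usage sketch (not part of the original module): this version of
# main() is driven entirely by a JSON options file. Calling it with a path to
# a JSON file that does not yet exist writes default options via
# save_default_train_options() and returns; calling it again with the same
# path loads those options and trains. The 'demo' directory below is
# illustrative, and the JSON is assumed to supply the remaining fields
# (path_save_dir, n_iter, batch_size, interval_save, seed, gpu_ids, ...).
def _example_two_phase_run():
    import argparse
    args = argparse.Namespace(json='saved_models/demo/train_options.json')
    main(args)  # first run: writes default train_options.json and returns
    # ... edit saved_models/demo/train_options.json as needed ...
    main(args)  # second run: loads the edited options and trains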
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=24,
                        help='size of each batch')
    parser.add_argument('--bpds_kwargs', type=json.loads, default={},
                        help='kwargs to be passed to BufferedPatchDataset')
    parser.add_argument('--dataset_class', default='fnet.data.CziDataset',
                        help='Dataset class')
    parser.add_argument('--dataset_kwargs', type=json.loads, default={},
                        help='kwargs to be passed to Dataset class')
    parser.add_argument('--fnet_model_class', default='fnet.models.Model',
                        help='FnetModel class')
    parser.add_argument('--fnet_model_kwargs', type=json.loads, default={},
                        help='kwargs to be passed to fnet model class')
    parser.add_argument('--gpu_ids', type=int, nargs='+', default=0,
                        help='GPU ID')
    parser.add_argument('--interval_checkpoint', type=int, default=50000,
                        help='intervals at which to save checkpoints of model')
    parser.add_argument('--interval_save', type=int, default=500,
                        help='iterations between saving log/model')
    parser.add_argument('--iter_checkpoint', nargs='+', type=int, default=[],
                        help='iterations at which to save checkpoints of model')
    parser.add_argument('--n_iter', type=int, default=50000,
                        help='number of training iterations')
    parser.add_argument('--path_dataset_csv', type=str,
                        help='path to csv for constructing Dataset')
    parser.add_argument(
        '--path_dataset_val_csv', type=str,
        help='path to csv for constructing validation Dataset '
             '(evaluated every time the model is saved)')
    parser.add_argument('--path_run_dir', default='saved_models',
                        help='base directory for saved models')
    parser.add_argument('--seed', type=int, help='random seed')
    args = parser.parse_args()

    time_start = time.time()
    if not os.path.exists(args.path_run_dir):
        os.makedirs(args.path_run_dir)
    if len(args.iter_checkpoint) > 0 or args.interval_checkpoint is not None:
        path_checkpoint_dir = os.path.join(args.path_run_dir, 'checkpoints')
        if not os.path.exists(path_checkpoint_dir):
            os.makedirs(path_checkpoint_dir)
    path_options = os.path.join(args.path_run_dir, 'train_options.json')
    with open(path_options, 'w') as fo:
        json.dump(vars(args), fo, indent=4, sort_keys=True)

    # Setup logging
    logger = logging.getLogger('model training')
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(os.path.join(args.path_run_dir, 'run.log'), mode='a')
    sh = logging.StreamHandler(sys.stdout)
    fh.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(fh)
    logger.addHandler(sh)

    # Set random seed
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Instantiate Model
    path_model = os.path.join(args.path_run_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, path_options)
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_run_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_batch', 'loss_val'])

    n_remaining_iterations = max(0, (args.n_iter - model.count_iter))
    dataloader_train = get_dataloader(args, n_remaining_iterations)
    dataloader_val = get_dataloader(
        args, n_remaining_iterations, validation=True)
    for i, (signal, target) in enumerate(dataloader_train, model.count_iter):
        do_save = ((i + 1) % args.interval_save == 0) or \
            ((i + 1) == args.n_iter)
        loss_batch = model.train_on_batch(signal, target)
        loss_val = get_loss_val(model, dataloader_val) if do_save else None
        fnetlogger.add({
            'num_iter': i + 1,
            'loss_batch': loss_batch,
            'loss_val': loss_val,
        })
        print('num_iter: {:6d} | loss_batch: {:.3f} | loss_val: {}'.format(
            i + 1, loss_batch, loss_val))
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info('BufferedPatchDataset buffer history: {}'.format(
                dataloader_train.dataset.get_buffer_history()))
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() - time_start))
        if ((i + 1) in args.iter_checkpoint) or \
           ((i + 1) % args.interval_checkpoint == 0):
            path_save_checkpoint = os.path.join(
                path_checkpoint_dir, 'model_{:06d}.p'.format(i + 1))
            model.save(path_save_checkpoint)
            logger.info(
                'model checkpoint saved to: {:s}'.format(path_save_checkpoint))
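
# Plausible sketch (an assumption, not shown in this excerpt) of the
# set_seeds() helper called by the other versions of main(), based on the
# inline seeding block used above; the actual implementation in the repo
# may differ.
import numpy as np
import torch

def set_seeds(seed):
    """Seed NumPy and PyTorch RNGs; do nothing if seed is None."""
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)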
def main(args: Optional[argparse.Namespace] = None):
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    args.path_json = Path(args.json)
    if args.path_json and not args.path_json.exists():
        save_default_train_options(args.path_json)
        return
    with open(args.path_json, "r") as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    add_logging_file_handler(Path(args.path_save_dir, "train_model.log"))
    logger.info(f"Started training at: {datetime.datetime.now()}")
    set_seeds(args.seed)
    log_training_options(vars(args))

    path_model = os.path.join(args.path_save_dir, "model.p")
    model = fnet.models.load_or_init_model(path_model, args.path_json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, "losses.csv")
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info(f"History loaded from: {path_losses_csv}")
    else:
        fnetlogger = fnet.FnetLogger(
            columns=["num_iter", "loss_train", "loss_val"])

    if (args.n_iter - model.count_iter) <= 0:
        # Stop if no more iterations needed
        return

    # Get patch pair providers
    bpds_train = get_bpds_train(args)
    bpds_val = get_bpds_val(args)

    # MAIN LOOP
    for idx_iter in range(model.count_iter, args.n_iter):
        do_save = ((idx_iter + 1) % args.interval_save == 0) or (
            (idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(
            *bpds_train.get_batch(args.batch_size))
        loss_val = None
        if do_save and bpds_val is not None:
            loss_val = model.test_on_iterator(
                [bpds_val.get_batch(args.batch_size) for _ in range(4)])
        fnetlogger.add({
            "num_iter": idx_iter + 1,
            "loss_train": loss_train,
            "loss_val": loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                "BufferedPatchDataset buffer history: %s",
                bpds_train.get_buffer_history(),
            )
            logger.info(f"Loss log saved to: {path_losses_csv}")
            logger.info(f"Model saved to: {path_model}")
            logger.info(f"Elapsed time: {time.time() - time_start:.1f} s")
        if ((idx_iter + 1) in args.iter_checkpoint) or (
                (idx_iter + 1) % args.interval_checkpoint == 0):
            path_checkpoint = os.path.join(
                args.path_save_dir, "checkpoints",
                "model_{:06d}.p".format(idx_iter + 1))
            model.save(path_checkpoint)
            logger.info(f"Saved model checkpoint: {path_checkpoint}")
    vu.plot_loss(
        args.path_save_dir,
        path_save=os.path.join(args.path_save_dir, "loss_curves.png"),
    )
    return model
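
# Small worked example (hypothetical settings, not from the repo) of the
# checkpoint condition used in the loop above: a checkpoint is written when
# the 1-based iteration count appears in iter_checkpoint or is a multiple of
# interval_checkpoint.
def _example_checkpoint_iterations():
    iter_checkpoint = [100, 250]
    interval_checkpoint = 50000
    n_iter = 100000
    saved = [i for i in range(1, n_iter + 1)
             if i in iter_checkpoint or i % interval_checkpoint == 0]
    assert saved == [100, 250, 50000, 100000]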
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    if not os.path.exists(args.json):
        save_default_train_options(args.json)
        return
    with open(args.json, 'r') as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    print('*** Training options ***')
    pprint.pprint(vars(args))

    # Make checkpoint directory if necessary
    if args.iter_checkpoint or args.interval_checkpoint:
        path_checkpoint_dir = os.path.join(args.path_save_dir, 'checkpoints')
        if not os.path.exists(path_checkpoint_dir):
            os.makedirs(path_checkpoint_dir)

    logger = init_logger(path_save=os.path.join(args.path_save_dir, 'run.log'))

    # Set random seed
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Instantiate Model
    path_model = os.path.join(args.path_save_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, args.json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_train', 'loss_val'])

    n_remaining_iterations = max(0, (args.n_iter - model.count_iter))
    dataloader_train = get_dataloaders(args, n_remaining_iterations)
    dataloader_val = get_dataloaders(
        args, n_remaining_iterations, validation=True,
    )
    for idx_iter, (x_batch, y_batch) in enumerate(
            dataloader_train, model.count_iter):
        do_save = ((idx_iter + 1) % args.interval_save == 0) or \
            ((idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(x_batch, y_batch)
        loss_val = None
        if do_save and dataloader_val is not None:
            loss_val = model.test_on_iterator(dataloader_val)
        fnetlogger.add({
            'num_iter': idx_iter + 1,
            'loss_train': loss_train,
            'loss_val': loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                'BufferedPatchDataset buffer history: %s',
                dataloader_train.dataset.get_buffer_history(),
            )
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() - time_start))
        if ((idx_iter + 1) in args.iter_checkpoint) or \
           ((idx_iter + 1) % args.interval_checkpoint == 0):
            path_save_checkpoint = os.path.join(
                path_checkpoint_dir, 'model_{:06d}.p'.format(idx_iter + 1))
            model.save(path_save_checkpoint)
            logger.info('Saved model checkpoint: %s', path_save_checkpoint)
    vu.plot_loss(
        args.path_save_dir,
        path_save=os.path.join(args.path_save_dir, 'loss_curves.png'),
    )
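
# A minimal entry-point guard (an assumption; the excerpt does not show one)
# that would run this version of main() when the module is executed directly,
# with command-line arguments coming from add_parser_arguments().
if __name__ == '__main__':
    main()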