]  # this is the DNA channel for all FOVs

n_train_images = int(n_images_to_download * train_fraction)
df_train = df[:n_train_images]
df_test = df[n_train_images:]
df_test.to_csv(data_save_path_test, index=False)
df_train.to_csv(data_save_path_train, index=False)

################################################
# Run the label-free stuff (don't change this)
################################################

prefs_save_path = Path(prefs_save_path)
save_default_train_options(prefs_save_path)
with open(prefs_save_path, "r") as fp:
    prefs = json.load(fp)

prefs["n_iter"] = 50000  # takes about 16 hours; go up to 250,000 for full training
prefs["interval_checkpoint"] = 10000

prefs["dataset_train"] = "fnet.data.MultiChTiffDataset"
prefs["dataset_train_kwargs"] = {"path_csv": data_save_path_train}
prefs["dataset_val"] = "fnet.data.MultiChTiffDataset"
prefs["dataset_val_kwargs"] = {"path_csv": data_save_path_test}

# This Fnet call will be updated as a python API becomes available
with open(prefs_save_path, "w") as fp:
    json.dump(prefs, fp)
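# Optional sanity check (illustrative sketch, not part of the original script):
# re-read the prefs file to confirm the options set above were written as
# expected. Keys not set here (e.g. path_save_dir, gpu_ids) keep the defaults
# created by save_default_train_options; the training entry point below reads
# them from this same file.
with open(prefs_save_path, "r") as fp:
    written_prefs = json.load(fp)
print(written_prefs["n_iter"], written_prefs["dataset_train"])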
def main(args: Optional[argparse.Namespace] = None):
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    args.path_json = Path(args.json)
    if args.path_json and not args.path_json.exists():
        save_default_train_options(args.path_json)
        return
    with open(args.path_json, "r") as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    add_logging_file_handler(Path(args.path_save_dir, "train_model.log"))
    logger.info(f"Started training at: {datetime.datetime.now()}")
    set_seeds(args.seed)
    log_training_options(vars(args))

    path_model = os.path.join(args.path_save_dir, "model.p")
    model = fnet.models.load_or_init_model(path_model, args.path_json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, "losses.csv")
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info(f"History loaded from: {path_losses_csv}")
    else:
        fnetlogger = fnet.FnetLogger(
            columns=["num_iter", "loss_train", "loss_val"])

    if (args.n_iter - model.count_iter) <= 0:
        # Stop if no more iterations needed
        return

    # Get patch pair providers
    bpds_train = get_bpds_train(args)
    bpds_val = get_bpds_val(args)

    # MAIN LOOP
    for idx_iter in range(model.count_iter, args.n_iter):
        do_save = ((idx_iter + 1) % args.interval_save == 0) or (
            (idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(
            *bpds_train.get_batch(args.batch_size))
        loss_val = None
        if do_save and bpds_val is not None:
            loss_val = model.test_on_iterator(
                [bpds_val.get_batch(args.batch_size) for _ in range(4)])
        fnetlogger.add({
            "num_iter": idx_iter + 1,
            "loss_train": loss_train,
            "loss_val": loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                "BufferedPatchDataset buffer history: %s",
                bpds_train.get_buffer_history(),
            )
            logger.info(f"Loss log saved to: {path_losses_csv}")
            logger.info(f"Model saved to: {path_model}")
            logger.info(f"Elapsed time: {time.time() - time_start:.1f} s")
        if ((idx_iter + 1) in args.iter_checkpoint) or (
                (idx_iter + 1) % args.interval_checkpoint == 0):
            path_checkpoint = os.path.join(
                args.path_save_dir, "checkpoints",
                "model_{:06d}.p".format(idx_iter + 1))
            model.save(path_checkpoint)
            logger.info(f"Saved model checkpoint: {path_checkpoint}")
            vu.plot_loss(
                args.path_save_dir,
                path_save=os.path.join(args.path_save_dir, "loss_curves.png"),
            )
    return model
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    if not os.path.exists(args.json):
        save_default_train_options(args.json)
        return
    with open(args.json, 'r') as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    add_logging_file_handler(Path(args.path_save_dir, 'train_model.log'))
    logger.info('Started training at: %s', datetime.datetime.now())
    set_seeds(args.seed)
    log_training_options(vars(args))

    path_model = os.path.join(args.path_save_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, args.json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_train', 'loss_val'])

    bpds_train = get_bpds_train(args)
    bpds_val = get_bpds_val(args)

    for idx_iter in range(model.count_iter, args.n_iter):
        x_batch, y_batch = bpds_train.get_batch(args.batch_size)
        do_save = ((idx_iter + 1) % args.interval_save == 0) or \
            ((idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(x_batch, y_batch)
        loss_val = None
        if do_save and bpds_val is not None:
            loss_val = model.test_on_iterator(
                [bpds_val.get_batch(args.batch_size) for _ in range(4)])
        fnetlogger.add({
            'num_iter': idx_iter + 1,
            'loss_train': loss_train,
            'loss_val': loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                'BufferedPatchDataset buffer history: %s',
                bpds_train.get_buffer_history(),
            )
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() - time_start))
        if ((idx_iter + 1) in args.iter_checkpoint) or \
                ((idx_iter + 1) % args.interval_checkpoint == 0):
            path_checkpoint = os.path.join(
                args.path_save_dir,
                'checkpoints',
                'model_{:06d}.p'.format(idx_iter + 1),
            )
            model.save(path_checkpoint)
            logger.info('Saved model checkpoint: %s', path_checkpoint)
            vu.plot_loss(
                args.path_save_dir,
                path_save=os.path.join(args.path_save_dir, 'loss_curves.png'),
            )
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    if not os.path.exists(args.json):
        save_default_train_options(args.json)
        return
    with open(args.json, 'r') as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    print('*** Training options ***')
    pprint.pprint(vars(args))

    # Make checkpoint directory if necessary
    if args.iter_checkpoint or args.interval_checkpoint:
        path_checkpoint_dir = os.path.join(args.path_save_dir, 'checkpoints')
        if not os.path.exists(path_checkpoint_dir):
            os.makedirs(path_checkpoint_dir)

    logger = init_logger(path_save=os.path.join(args.path_save_dir, 'run.log'))

    # Set random seed
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Instantiate Model
    path_model = os.path.join(args.path_save_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, args.json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_train', 'loss_val'])

    n_remaining_iterations = max(0, (args.n_iter - model.count_iter))
    dataloader_train = get_dataloaders(args, n_remaining_iterations)
    dataloader_val = get_dataloaders(
        args,
        n_remaining_iterations,
        validation=True,
    )
    for idx_iter, (x_batch, y_batch) in enumerate(
            dataloader_train, model.count_iter):
        do_save = ((idx_iter + 1) % args.interval_save == 0) or \
            ((idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(x_batch, y_batch)
        loss_val = None
        if do_save and dataloader_val is not None:
            loss_val = model.test_on_iterator(dataloader_val)
        fnetlogger.add({
            'num_iter': idx_iter + 1,
            'loss_train': loss_train,
            'loss_val': loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                'BufferedPatchDataset buffer history: %s',
                dataloader_train.dataset.get_buffer_history(),
            )
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() - time_start))
        if ((idx_iter + 1) in args.iter_checkpoint) or \
                ((idx_iter + 1) % args.interval_checkpoint == 0):
            path_save_checkpoint = os.path.join(
                path_checkpoint_dir, 'model_{:06d}.p'.format(idx_iter + 1))
            model.save(path_save_checkpoint)
            logger.info('Saved model checkpoint: %s', path_save_checkpoint)
            vu.plot_loss(
                args.path_save_dir,
                path_save=os.path.join(args.path_save_dir, 'loss_curves.png'),
            )
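# Usage sketch (an illustrative assumption, not part of the original module):
# because `main` also accepts a pre-built argparse.Namespace and loads the
# remaining options from the file named by its `json` attribute, a training
# run can be launched programmatically with the prefs JSON prepared earlier.
# The path below is hypothetical. Running the module as a script (args is
# None) instead parses the same option from the command line via
# add_parser_arguments.
if __name__ == '__main__':
    train_args = argparse.Namespace(json='path/to/prefs.json')  # hypothetical path
    main(train_args)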