def main(args: Optional[argparse.Namespace] = None) -> None:
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    if not os.path.exists(args.json):
        save_default_train_options(args.json)
        return
    with open(args.json, 'r') as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    add_logging_file_handler(Path(args.path_save_dir, 'train_model.log'))
    logger.info('Started training at: %s', datetime.datetime.now())

    set_seeds(args.seed)
    log_training_options(vars(args))
    path_model = os.path.join(args.path_save_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, args.json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_train', 'loss_val'])

    bpds_train = get_bpds_train(args)
    bpds_val = get_bpds_val(args)

    for idx_iter in range(model.count_iter, args.n_iter):
        x_batch, y_batch = bpds_train.get_batch(args.batch_size)
        do_save = ((idx_iter + 1) % args.interval_save == 0) or \
                  ((idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(x_batch, y_batch)
        loss_val = None
        if do_save and bpds_val is not None:
            loss_val = model.test_on_iterator(
                [bpds_val.get_batch(args.batch_size) for _ in range(4)])
        fnetlogger.add({
            'num_iter': idx_iter + 1,
            'loss_train': loss_train,
            'loss_val': loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                'BufferedPatchDataset buffer history: %s',
                bpds_train.get_buffer_history(),
            )
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() -
                                                        time_start))
        if ((idx_iter + 1) in args.iter_checkpoint) or \
           ((idx_iter + 1) % args.interval_checkpoint == 0):
            path_checkpoint = os.path.join(
                args.path_save_dir,
                'checkpoints',
                'model_{:06d}.p'.format(idx_iter + 1),
            )
            model.save(path_checkpoint)
            logger.info('Saved model checkpoint: %s', path_checkpoint)
            vu.plot_loss(
                args.path_save_dir,
                path_save=os.path.join(args.path_save_dir, 'loss_curves.png'),
            )
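A minimal invocation sketch for the example above (the save-directory path is an assumption, not from the source): if the options JSON does not exist yet, save_default_train_options writes a template and main returns, so a typical workflow calls main twice and edits the generated train_options.json in between.

import argparse

args = argparse.Namespace(json='saved_models/demo/train_options.json')
main(args)  # first call: writes default train options and returns
main(args)  # second call: loads the (possibly edited) options and trains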
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size',
                        type=int,
                        default=24,
                        help='size of each batch')
    parser.add_argument('--bpds_kwargs',
                        type=json.loads,
                        default={},
                        help='kwargs to be passed to BufferedPatchDataset')
    parser.add_argument('--dataset_class',
                        default='fnet.data.CziDataset',
                        help='Dataset class')
    parser.add_argument('--dataset_kwargs',
                        type=json.loads,
                        default={},
                        help='kwargs to be passed to Dataset class')
    parser.add_argument('--fnet_model_class',
                        default='fnet.models.Model',
                        help='FnetModel class')
    parser.add_argument('--fnet_model_kwargs',
                        type=json.loads,
                        default={},
                        help='kwargs to be passed to fnet model class')
    parser.add_argument('--gpu_ids',
                        type=int,
                        nargs='+',
                        default=0,
                        help='GPU ID(s)')
    parser.add_argument('--interval_checkpoint',
                        type=int,
                        default=50000,
                        help='iterations between model checkpoints')
    parser.add_argument('--interval_save',
                        type=int,
                        default=500,
                        help='iterations between saving log/model')
    parser.add_argument(
        '--iter_checkpoint',
        nargs='+',
        type=int,
        default=[],
        help='specific iterations at which to save model checkpoints')
    parser.add_argument('--n_iter',
                        type=int,
                        default=50000,
                        help='number of training iterations')
    parser.add_argument('--path_dataset_csv',
                        type=str,
                        help='path to csv for constructing Dataset')
    parser.add_argument(
        '--path_dataset_val_csv',
        type=str,
        help='path to csv for constructing validation Dataset '
             '(evaluated every time the model is saved)')
    parser.add_argument('--path_run_dir',
                        default='saved_models',
                        help='base directory for saved models')
    parser.add_argument('--seed', type=int, help='random seed')
    args = parser.parse_args()

    time_start = time.time()
    if not os.path.exists(args.path_run_dir):
        os.makedirs(args.path_run_dir)
    if len(args.iter_checkpoint) > 0 or args.interval_checkpoint is not None:
        path_checkpoint_dir = os.path.join(args.path_run_dir, 'checkpoints')
        if not os.path.exists(path_checkpoint_dir):
            os.makedirs(path_checkpoint_dir)

    path_options = os.path.join(args.path_run_dir, 'train_options.json')
    with open(path_options, 'w') as fo:
        json.dump(vars(args), fo, indent=4, sort_keys=True)

    # Setup logging
    logger = logging.getLogger('model training')
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(os.path.join(args.path_run_dir, 'run.log'),
                             mode='a')
    sh = logging.StreamHandler(sys.stdout)
    fh.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(fh)
    logger.addHandler(sh)

    # Set random seed
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Instantiate Model
    path_model = os.path.join(args.path_run_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, path_options)
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_run_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_batch', 'loss_val'])

    n_remaining_iterations = max(0, (args.n_iter - model.count_iter))
    dataloader_train = get_dataloader(args, n_remaining_iterations)
    dataloader_val = get_dataloader(args,
                                    n_remaining_iterations,
                                    validation=True)
    for i, (signal, target) in enumerate(dataloader_train, model.count_iter):
        do_save = ((i + 1) % args.interval_save == 0) or \
                  ((i + 1) == args.n_iter)
        loss_batch = model.train_on_batch(signal, target)
        loss_val = get_loss_val(model, dataloader_val) if do_save else None
        fnetlogger.add({
            'num_iter': i + 1,
            'loss_batch': loss_batch,
            'loss_val': loss_val
        })
        print('num_iter: {:6d} | loss_batch: {:.3f} | loss_val: {}'.format(
            i + 1, loss_batch, loss_val))
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info('BufferedPatchDataset buffer history: {}'.format(
                dataloader_train.dataset.get_buffer_history()))
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() -
                                                        time_start))
        if ((i + 1) in args.iter_checkpoint) or \
           ((i + 1) % args.interval_checkpoint == 0):
            path_save_checkpoint = os.path.join(path_checkpoint_dir,
                                                'model_{:06d}.p'.format(i + 1))
            model.save(path_save_checkpoint)
            logger.info(
                'model checkpoint saved to: {:s}'.format(path_save_checkpoint))
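With the parser defaults above, the train_options.json written by this example should look roughly like the following (unset paths serialize as null; keys are sorted because of sort_keys=True):

{
    "batch_size": 24,
    "bpds_kwargs": {},
    "dataset_class": "fnet.data.CziDataset",
    "dataset_kwargs": {},
    "fnet_model_class": "fnet.models.Model",
    "fnet_model_kwargs": {},
    "gpu_ids": 0,
    "interval_checkpoint": 50000,
    "interval_save": 500,
    "iter_checkpoint": [],
    "n_iter": 50000,
    "path_dataset_csv": null,
    "path_dataset_val_csv": null,
    "path_run_dir": "saved_models",
    "seed": null
}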
Example #3
def main(args: Optional[argparse.Namespace] = None):
    """Trains a model."""
    time_start = time.time()

    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()

    args.path_json = Path(args.json)

    if args.path_json and not args.path_json.exists():
        save_default_train_options(args.path_json)
        return

    with open(args.path_json, "r") as fi:
        train_options = json.load(fi)

    args.__dict__.update(train_options)
    add_logging_file_handler(Path(args.path_save_dir, "train_model.log"))
    logger.info(f"Started training at: {datetime.datetime.now()}")

    set_seeds(args.seed)
    log_training_options(vars(args))
    path_model = os.path.join(args.path_save_dir, "model.p")
    model = fnet.models.load_or_init_model(path_model, args.path_json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, "losses.csv")
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info(f"History loaded from: {path_losses_csv}")
    else:
        fnetlogger = fnet.FnetLogger(
            columns=["num_iter", "loss_train", "loss_val"])

    if (args.n_iter - model.count_iter) <= 0:
        # Stop if no more iterations needed
        return

    # Get patch pair providers
    bpds_train = get_bpds_train(args)
    bpds_val = get_bpds_val(args)

    # MAIN LOOP
    for idx_iter in range(model.count_iter, args.n_iter):
        do_save = ((idx_iter + 1) % args.interval_save == 0) or \
                  ((idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(
            *bpds_train.get_batch(args.batch_size))
        loss_val = None
        if do_save and bpds_val is not None:
            loss_val = model.test_on_iterator(
                [bpds_val.get_batch(args.batch_size) for _ in range(4)])
        fnetlogger.add({
            "num_iter": idx_iter + 1,
            "loss_train": loss_train,
            "loss_val": loss_val
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                "BufferedPatchDataset buffer history: %s",
                bpds_train.get_buffer_history(),
            )
            logger.info(f"Loss log saved to: {path_losses_csv}")
            logger.info(f"Model saved to: {path_model}")
            logger.info(f"Elapsed time: {time.time() - time_start:.1f} s")
        if ((idx_iter + 1) in args.iter_checkpoint) or (
            (idx_iter + 1) % args.interval_checkpoint == 0):
            path_checkpoint = os.path.join(
                args.path_save_dir, "checkpoints",
                "model_{:06d}.p".format(idx_iter + 1))
            model.save(path_checkpoint)
            logger.info(f"Saved model checkpoint: {path_checkpoint}")
            vu.plot_loss(
                args.path_save_dir,
                path_save=os.path.join(args.path_save_dir, "loss_curves.png"),
            )

    return model
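The loop in this variant only calls get_batch and get_buffer_history on the patch providers, so a duck-typed stand-in is enough for a smoke test without real data; the class name and patch shape below are assumptions, not part of fnet.

import torch

class FakePatchProvider:
    """Assumed interface of the objects returned by get_bpds_train/get_bpds_val."""

    def get_batch(self, batch_size):
        # (signal, target) pair shaped like a batch of single-channel 3D patches
        x = torch.rand(batch_size, 1, 16, 32, 32)
        y = torch.rand(batch_size, 1, 16, 32, 32)
        return x, y

    def get_buffer_history(self):
        return []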
Example #4
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Trains a model."""
    time_start = time.time()
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    if not os.path.exists(args.json):
        save_default_train_options(args.json)
        return
    with open(args.json, 'r') as fi:
        train_options = json.load(fi)
    args.__dict__.update(train_options)
    print('*** Training options ***')
    pprint.pprint(vars(args))

    # Make checkpoint directory if necessary
    if args.iter_checkpoint or args.interval_checkpoint:
        path_checkpoint_dir = os.path.join(args.path_save_dir, 'checkpoints')
        if not os.path.exists(path_checkpoint_dir):
            os.makedirs(path_checkpoint_dir)
    logger = init_logger(path_save=os.path.join(args.path_save_dir, 'run.log'))

    # Set random seed
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Instantiate Model
    path_model = os.path.join(args.path_save_dir, 'model.p')
    model = fnet.models.load_or_init_model(path_model, args.json)
    init_cuda(args.gpu_ids[0])
    model.to_gpu(args.gpu_ids)
    logger.info(model)

    path_losses_csv = os.path.join(args.path_save_dir, 'losses.csv')
    if os.path.exists(path_losses_csv):
        fnetlogger = fnet.FnetLogger(path_losses_csv)
        logger.info('History loaded from: {:s}'.format(path_losses_csv))
    else:
        fnetlogger = fnet.FnetLogger(
            columns=['num_iter', 'loss_train', 'loss_val'])

    n_remaining_iterations = max(0, (args.n_iter - model.count_iter))
    dataloader_train = get_dataloaders(args, n_remaining_iterations)
    dataloader_val = get_dataloaders(
        args,
        n_remaining_iterations,
        validation=True,
    )
    for idx_iter, (x_batch, y_batch) in enumerate(dataloader_train,
                                                  model.count_iter):
        do_save = ((idx_iter + 1) % args.interval_save == 0) or \
                  ((idx_iter + 1) == args.n_iter)
        loss_train = model.train_on_batch(x_batch, y_batch)
        loss_val = None
        if do_save and dataloader_val is not None:
            loss_val = model.test_on_iterator(dataloader_val)
        fnetlogger.add({
            'num_iter': idx_iter + 1,
            'loss_train': loss_train,
            'loss_val': loss_val,
        })
        print(f'iter: {fnetlogger.data["num_iter"][-1]:6d} | '
              f'loss_train: {fnetlogger.data["loss_train"][-1]:.4f}')
        if do_save:
            model.save(path_model)
            fnetlogger.to_csv(path_losses_csv)
            logger.info(
                'BufferedPatchDataset buffer history: %s',
                dataloader_train.dataset.get_buffer_history(),
            )
            logger.info('loss log saved to: {:s}'.format(path_losses_csv))
            logger.info('model saved to: {:s}'.format(path_model))
            logger.info('elapsed time: {:.1f} s'.format(time.time() -
                                                        time_start))
        if ((idx_iter + 1) in args.iter_checkpoint) or \
           ((idx_iter + 1) % args.interval_checkpoint == 0):
            path_save_checkpoint = os.path.join(
                path_checkpoint_dir, 'model_{:06d}.p'.format(idx_iter + 1))
            model.save(path_save_checkpoint)
            logger.info('Saved model checkpoint: %s', path_save_checkpoint)
            vu.plot_loss(
                args.path_save_dir,
                path_save=os.path.join(args.path_save_dir, 'loss_curves.png'),
            )
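All four variants resume interrupted runs the same way: the training loop starts at model.count_iter and FnetLogger reloads losses.csv when it already exists, so re-running main with the same save directory continues from the last saved model. A hedged sketch for extending a finished run of the JSON-driven variants (file path assumed, imports as in the examples above):

with open('saved_models/demo/train_options.json') as fi:
    options = json.load(fi)
options['n_iter'] = 100000  # raise the budget past the saved count_iter
with open('saved_models/demo/train_options.json', 'w') as fo:
    json.dump(options, fo, indent=4)
main(argparse.Namespace(json='saved_models/demo/train_options.json'))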