# Example #1
# 0
def run(config):
    """Train a model described by ``config``, checkpointing it every epoch.

    Builds the preprocessor and the train/dev loaders and instantiates the
    model class named by ``config["model"]["class"]`` (looked up on the
    ``models`` module). It then trains with SGD for
    ``config["optimizer"]["epochs"]`` epochs. After every epoch it:

    - evaluates on the dev set,
    - logs ``dev_loss`` / ``dev_cer`` to tensorboard,
    - saves the model to ``config["save_path"]``,
    - additionally saves it with ``tag="best"`` whenever the dev CER improves.

    Args:
        config (dict): training configuration with ``"optimizer"``, ``"data"``,
            ``"model"`` and ``"save_path"`` entries.

    Note:
        The device is chosen from a ``use_cuda`` name that is not defined in
        this function; it must exist at module scope.
    """
    opt_cfg = config["optimizer"]
    data_cfg = config["data"]
    model_cfg = config["model"]

    # Loaders
    batch_size = opt_cfg["batch_size"]
    preproc = loader.Preprocessor(data_cfg["train_set"],
                                  start_and_end=data_cfg["start_and_end"])
    train_ldr = loader.make_loader(data_cfg["train_set"], preproc, batch_size)
    dev_ldr = loader.make_loader(data_cfg["dev_set"], preproc, batch_size)

    # Model: resolve the (possibly dotted) class name as an attribute path on
    # the ``models`` module rather than eval()-ing text from the config.
    model_class = models
    for attr in model_cfg["class"].split("."):
        model_class = getattr(model_class, attr)
    model = model_class(preproc.input_dim, preproc.vocab_size, model_cfg)
    if use_cuda:
        model.cuda()
    else:
        model.cpu()

    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt_cfg["learning_rate"],
                                momentum=opt_cfg["momentum"])

    # run_state is threaded through run_epoch and returned updated each epoch
    run_state = (0, 0)
    best_so_far = float("inf")
    for e in range(opt_cfg["epochs"]):
        start = time.time()

        run_state = run_epoch(model, optimizer, train_ldr, *run_state)

        msg = "Epoch {} completed in {:.2f} (s)."
        print(msg.format(e, time.time() - start))

        dev_loss, dev_cer = eval_dev(model, dev_ldr, preproc)

        # Log for tensorboard
        tb.log_value("dev_loss", dev_loss, e)
        tb.log_value("dev_cer", dev_cer, e)

        speech.save(model, preproc, config["save_path"])

        # Save the best model on the dev set (lower CER is better)
        if dev_cer < best_so_far:
            best_so_far = dev_cer
            speech.save(model, preproc,
                        config["save_path"], tag="best")
# Example #2
# 0
def run(config, use_cuda, *, load_pre=True):
    """Train (or fine-tune) a model, periodically evaluating on the dev set.

    When ``load_pre`` is true, the model and preprocessor are loaded from the
    ``"ctc_best"`` checkpoint with ``tag="best"``. Otherwise a new
    preprocessor is built and the model class named by
    ``config["model"]["class"]`` is looked up on the ``models`` module and
    instantiated.

    Training uses SGD for ``config["optimizer"]["epochs"]`` epochs. The model
    and optimizer are saved to ``config["save_path"]`` after every epoch. Dev
    evaluation (with tensorboard logging) runs on epoch 0, on every 10th
    epoch, and on the last epoch. A ``tag="best"`` checkpoint is written
    whenever an evaluation improves the dev CER.

    Args:
        config (dict): training configuration with ``"optimizer"``, ``"data"``,
            ``"model"``, ``"audio"`` and ``"save_path"`` entries.
        use_cuda (bool): move the model to the GPU if true, else to the CPU.
        load_pre (bool): load the pretrained ``"ctc_best"`` checkpoint instead
            of building a fresh model. Defaults to ``True`` (the previous
            hard-coded behavior).
    """
    opt_cfg = config["optimizer"]
    data_cfg = config["data"]
    model_cfg = config["model"]
    aud_cfg = config['audio']
    batch_size = opt_cfg["batch_size"]
    num_epochs = opt_cfg["epochs"]

    if load_pre:
        # TODO: check that the pretrained checkpoint actually exists; if not,
        # fall back to initializing a fresh model and preprocessor.
        model, _, preproc = speech.load("ctc_best", tag="best")
    else:
        preproc = loader.Preprocessor(data_cfg["train_set"], aud_cfg,
                                      start_and_end=data_cfg["start_and_end"])
        # Resolve the (possibly dotted) class name on the ``models`` module
        # instead of eval()-ing config text.
        model_class = models
        for attr in model_cfg["class"].split("."):
            model_class = getattr(model_class, attr)
        model = model_class(preproc.input_dim, preproc.vocab_size, model_cfg)

    model = model.cuda() if use_cuda else model.cpu()
    optimizer = torch.optim.SGD(model.parameters(), lr=opt_cfg["learning_rate"],
                                momentum=opt_cfg["momentum"])
    # Iterable loaders built by the project's ``loader`` module
    train_ldr = loader.make_loader(data_cfg["train_set"], preproc, batch_size)
    dev_ldr = loader.make_loader(data_cfg["dev_set"], preproc, batch_size)

    print("Epochs to train:", num_epochs)
    run_state = (0, 0)
    best_so_far = float("inf")
    for e in range(num_epochs):
        start = time.time()

        run_state = run_epoch(model, optimizer, train_ldr, *run_state)

        msg = "Epoch {} completed in {:.2f} (s)."
        print(msg.format(e, time.time() - start))

        # Evaluate only every 10 epochs and on the final epoch
        evaluated = (e % 10 == 0) or (e == num_epochs - 1)
        if evaluated:
            dev_loss, dev_cer = eval_dev(model, dev_ldr, preproc)

            # Log for tensorboard
            tb.log_value("dev_loss", dev_loss, e)
            tb.log_value("dev_cer", dev_cer, e)

        speech.save(model, optimizer, preproc, config["save_path"])

        # Save the best model on the dev set. Only compare on epochs that
        # actually ran an evaluation, so a stale dev_cer is never used.
        if evaluated and dev_cer < best_so_far:
            best_so_far = dev_cer
            speech.save(model, optimizer, preproc, config["save_path"], tag="best")
# Example #3
# 0
def run(local_rank: int, config: dict) -> None:
    """Define the data, optimizer, and model objects, then run distributed training and evaluation.

    Every process joins an NCCL process group and trains a DDP-wrapped
    ``CTC_train`` model. Only the rank-0 process does the following:

    - logs, and writes tensorboard scalars;
    - evaluates the dev sets;
    - saves checkpoints locally and uploads them to GCS.

    Training resumes from ``train_state.pickle`` in ``ckpt_cfg['gcs_dir']``
    when it exists. That restores the start epoch, run state, best dev PER
    and learning rate; otherwise these values come from the optimizer config.

    Args:
        local_rank (int): rank of the process on the GPU; used as the CUDA
            device index and as the DDP device id.
        config (dict): training configuration dict with ``data``, ``logger``,
            ``preproc``, ``optimizer``, ``model``, ``training`` and
            ``checkpoint`` sections.

    Raises:
        Exception: ``'Failure in run_epoch'`` if an epoch of training raises;
            the original error is chained as ``__cause__``.
    """
    # unpacking the config
    data_cfg = config["data"]
    log_cfg = config["logger"]
    preproc_cfg = config["preproc"]
    opt_cfg = config["optimizer"]
    model_cfg = config["model"]
    train_cfg = config['training']
    ckpt_cfg = config['checkpoint']

    gcs_ckpt_handler = GCSCheckpointHandler(ckpt_cfg)

    # save the config locally and upload it to gcs
    # NOTE(review): this runs in every process (before the process group is
    # initialized), so every rank writes and uploads the same file.
    os.makedirs(ckpt_cfg['local_save_path'], exist_ok=True)
    config_path = os.path.join(ckpt_cfg['local_save_path'], "ctc_config.yaml")
    with open(config_path, 'w') as fid:
        yaml.dump(config, fid)
    gcs_ckpt_handler.upload_to_gcs("ctc_config.yaml")

    # setting up the distributed training environment
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)
    print(
        f"local_rank: {local_rank}, dist.get_rank: {torch.distributed.get_rank()}"
    )
    is_rank_0 = (torch.distributed.get_rank() == 0)

    # defining the logging and debugging modes: file logging only in rank 0,
    # anomaly detection (if enabled) in every rank
    use_log = log_cfg["use_log"] and is_rank_0
    debug_mode = log_cfg["debug_mode"]
    if debug_mode:
        torch.autograd.set_detect_anomaly(True)

    # create a logger, rank_0 boolean is contained in `use_log`
    logger = get_logger("train_log", log_cfg['log_file'],
                        log_cfg['level']) if use_log else None

    # creates tensorboardX writer in rank_0 process
    tbX_writer = SummaryWriter(
        logdir=ckpt_cfg["local_save_path"]) if is_rank_0 else None

    # Load previous train state: dict with contents:
    # {start_epoch: int, run_state: (int, float), best_so_far: float, learning_rate: float}
    train_state_path = gcs_ckpt_handler.download_from_gcs_bucket(
        os.path.join(ckpt_cfg['gcs_dir'], "train_state.pickle"))
    if train_state_path:
        print(f"load train_state from: {train_state_path}")
        train_state = read_pickle(train_state_path)
    else:
        # if train_path doesn't exist, create empty dict so every value below
        # falls back to the optimizer config
        print(f"load train_state from config")
        train_state = dict()

    # the get-statements will load from train_state if key exists, and from opt_cfg otherwise
    run_state = train_state.get('run_state', opt_cfg['run_state'])
    best_so_far = train_state.get('best_so_far', opt_cfg['best_so_far'])
    start_epoch = train_state.get('start_epoch', opt_cfg['start_epoch'])

    # create the preproc object and data loaders
    batch_size = opt_cfg["batch_size"]
    preproc = loader.Preprocessor(data_cfg["train_set"],
                                  preproc_cfg,
                                  logger,
                                  start_and_end=data_cfg["start_and_end"])

    train_ldr = loader.make_ddp_loader(data_cfg["train_set"],
                                       preproc,
                                       batch_size,
                                       num_workers=data_cfg["num_workers"])

    # create the dev-set loaders in the rank_0 process (only rank 0 evaluates)
    if is_rank_0:
        dev_ldr_dict = dict()
        for dev_name, dev_path in data_cfg["dev_sets"].items():
            dev_ldr_dict[dev_name] = loader.make_loader(
                dev_path,
                preproc,
                batch_size=8,
                num_workers=data_cfg["num_workers"])

    # Model
    # add the blank_idx to model_cfg
    model_cfg.update({'blank_idx': preproc_cfg['blank_idx']})
    model = CTC_train(preproc.input_dim, preproc.vocab_size, model_cfg)

    # load a model from checkpoint, if it exists
    model_ckpt_path = gcs_ckpt_handler.download_from_gcs_bucket(
        os.path.join(ckpt_cfg['gcs_dir'], "ckpt_model_state_dict.pth"))
    if model_ckpt_path:
        model_cfg['local_trained_path'] = model_ckpt_path
        model = load_from_trained(model, model_cfg)
        print(
            f"Succesfully loaded weights from checkpoint: {ckpt_cfg['gcs_dir']}"
        )
    # if a model checkpoint doesn't exist, load from trained if selected and possible
    elif model_cfg["load_trained"]:
        local_trained_path = gcs_ckpt_handler.download_from_gcs_bucket(
            model_cfg['gcs_trained_path'])
        if local_trained_path:
            model_cfg['local_trained_path'] = local_trained_path
            model = load_from_trained(model, model_cfg)
            print(
                f"Succesfully loaded weights from trained model: {model_cfg['gcs_trained_path']}"
            )
        else:
            print(
                f"no model found at gcs location: {model_cfg['gcs_trained_path']}"
            )
    else:
        print("model trained from scratch")

    # Optimizer and learning rate scheduler.
    # Resume from the learning rate saved in train_state (written at the end
    # of every epoch below); fall back to the config value on a fresh run.
    learning_rate = train_state.get('learning_rate', opt_cfg['learning_rate'])
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=learning_rate,  # from train_state or opt_config
        momentum=opt_cfg["momentum"],
        dampening=opt_cfg["dampening"])

    # NOTE(review): on resume the StepLR step count restarts at 0, so the
    # next decay happens `sched_step` epochs after start_epoch.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=opt_cfg["sched_step"],
        gamma=opt_cfg["sched_gamma"])

    # gradient scaler, too large a value for init_scale produces NaN gradients
    scaler = GradScaler(enabled=train_cfg['amp'], init_scale=16)

    # call the ddp wrappers
    model.cuda(local_rank)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[local_rank],
                                                output_device=local_rank)

    if use_log:
        logger.info(
            f"train: ====== Model, loaders, optimimzer created =======")
        logger.info(f"train: model: {model}")
        logger.info(f"train: preproc: {preproc}")
        logger.info(f"train: optimizer: {optimizer}")
        logger.info(f"train: config: {config}")

    # printing to the output file
    if is_rank_0:
        print(f"====== Model, loaders, optimimzer created =======")
        print(f"model: {model}")
        print(f"preproc: {preproc}")
        print(f"optimizer: {optimizer}")
        print(f"config: {config}")

    # training loop
    for epoch in range(start_epoch, opt_cfg["epochs"]):

        start = time.time()
        for group in optimizer.param_groups:
            if is_rank_0:
                print(f'learning rate: {group["lr"]}')
            if use_log:
                logger.info(f"train: learning rate: {group['lr']}")

        try:
            run_state = run_epoch(model, optimizer, train_ldr, logger,
                                  debug_mode, tbX_writer, *run_state,
                                  local_rank, train_cfg['loss_name'],
                                  ckpt_cfg['local_save_path'],
                                  gcs_ckpt_handler, scaler)
        except Exception as err:
            if use_log:
                logger.error(f"Exception raised: {err}")
                logger.error(f"train: ====In except block====")
                logger.error(f"train: state_dict: {model.module.state_dict()}")
                log_model_grads(model.module.named_parameters(), logger)
            # chain the original error so its type, message and traceback
            # are preserved as the direct cause
            raise Exception('Failure in run_epoch') from err
        finally:  # used to ensure that plots are closed even if exception raised
            plt.close('all')

        # update the learning rate
        lr_scheduler.step()

        if use_log:
            logger.info(f"train: ====== Run_state finished =======")
            logger.info(f"train: preproc type: {type(preproc)}")
        if is_rank_0:
            msg = "Epoch {} completed in {:.2f} (hr)."
            epoch_time_hr = (time.time() - start) / 60 / 60
            print(msg.format(epoch, epoch_time_hr))
            if use_log:
                logger.info(msg.format(epoch, epoch_time_hr))
            tbX_writer.add_scalars('train/stats',
                                   {"epoch_time_hr": epoch_time_hr}, epoch)

            # the logger needs to be removed to save the model
            if use_log:
                preproc.logger = None
            speech.save(model.module, preproc, ckpt_cfg["local_save_path"])
            gcs_ckpt_handler.upload_to_gcs("model_state_dict.pth")
            gcs_ckpt_handler.upload_to_gcs("preproc.pyc")

            if use_log:
                logger.info(f"train: ====== model saved =======")
                preproc.logger = logger

            # creating the dictionaries that hold the PER and loss values
            dev_loss_dict = dict()
            dev_per_dict = dict()
            # iterating through the dev-set loaders to calculate the PER/loss
            for dev_name, dev_ldr in dev_ldr_dict.items():
                print(f"evaluating devset: {dev_name}")
                if use_log:
                    logger.info(f"train: === evaluating devset: {dev_name} ==")
                dev_loss, dev_per = eval_dev(model.module, dev_ldr, preproc,
                                             logger, train_cfg['loss_name'])

                dev_loss_dict[dev_name] = dev_loss
                dev_per_dict[dev_name] = dev_per

                if use_log:
                    logger.info(
                        f"train: ====== eval_dev {dev_name} finished =======")

                # Save the best model, judged only on the reference dev set
                if dev_name == data_cfg['dev_set_save_reference']:
                    print(
                        f"dev_reference {dev_name}: current PER: {dev_per} vs. best_so_far: {best_so_far}"
                    )

                    if use_log:
                        logger.info(
                            f"dev_reference {dev_name}: current PER: {dev_per} vs. best_so_far: {best_so_far}"
                        )
                    if dev_per < best_so_far:
                        if use_log:
                            preproc.logger = None  # remove the logger to save the model
                        best_so_far = dev_per
                        speech.save(model.module,
                                    preproc,
                                    ckpt_cfg["local_save_path"],
                                    tag="best")
                        gcs_ckpt_handler.upload_to_gcs(
                            "best_model_state_dict.pth")
                        gcs_ckpt_handler.upload_to_gcs("best_preproc.pyc")

                        if use_log:
                            preproc.logger = logger
                            logger.info(
                                f"model saved based per on: {dev_name} dataset"
                            )

                        print(
                            f"UPDATED: best_model based on PER {best_so_far} for {dev_name} devset"
                        )

            per_diff_dict = calc_per_difference(dev_per_dict)

            tbX_writer.add_scalars('dev/loss', dev_loss_dict, epoch)
            tbX_writer.add_scalars('dev/per', dev_per_dict, epoch)
            tbX_writer.add_scalars('dev/per/diff', per_diff_dict, epoch)
            gcs_ckpt_handler.upload_tensorboard_ckpt()

            # save the current state of training so a restart resumes from
            # the next epoch with the current learning rate
            learning_rate = optimizer.param_groups[0]["lr"]
            train_state = {
                "start_epoch": epoch + 1,
                "run_state": run_state,
                "best_so_far": best_so_far,
                "learning_rate": learning_rate
            }
            write_pickle(
                os.path.join(ckpt_cfg["local_save_path"],
                             "train_state.pickle"), train_state)
            gcs_ckpt_handler.upload_to_gcs("train_state.pickle")