Example #1
def maybe_load_checkpoint(model: Vicl, model_optimizer: LocalSgd, moptim_scheduler: ExponentialLR, task: int, models_dir: str):
    epoch = 0
    halo = Halo(text='Trying to load a checkpoint', spinner='dots').start()
    load_name = f'vicl-task-{task}-cp.pt'
    load_path = os.path.join(models_dir, load_name)

    try:
        checkpoint = torch.load(load_path, map_location=model.device())
    except Exception as e:
        halo.fail(f'No checkpoints found for this run: {e}')
    else:
        model.load_state(checkpoint['model'])
        model_optimizer.load_state_dict(checkpoint['model_optimizer'])
        moptim_scheduler.load_state_dict(checkpoint['moptim_scheduler'])
        epoch = checkpoint['epoch']
        halo.succeed(f'Found a checkpoint (epoch: {epoch})')

    return epoch
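
All of these examples share the same checkpoint contract: the model, optimizer, and scheduler state_dicts are saved together with the epoch and restored in the same order before training resumes. A minimal, self-contained sketch of that round trip (the toy model, optimizer, and file name below are stand-ins for the project-specific Vicl / LocalSgd objects, not taken from the example):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

# Save everything needed to resume: weights, optimizer state, scheduler state, epoch.
torch.save({
    'model': model.state_dict(),
    'model_optimizer': optimizer.state_dict(),
    'moptim_scheduler': scheduler.state_dict(),
    'epoch': 3,
}, 'vicl-task-0-cp.pt')

# Restore in the same order, then continue training from checkpoint['epoch'].
checkpoint = torch.load('vicl-task-0-cp.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['model_optimizer'])
scheduler.load_state_dict(checkpoint['moptim_scheduler'])
start_epoch = checkpoint['epoch']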
Example #2
def main(rank, args):

    # Distributed setup

    if args.distributed:
        setup_distributed(rank, args.world_size)

    not_main_rank = args.distributed and rank != 0

    logging.info("Start time: %s", datetime.now())

    # Explicitly set seed to make sure models created in separate processes
    # start from same random weights and biases
    torch.manual_seed(args.seed)

    # Empty CUDA cache
    torch.cuda.empty_cache()

    # Change backend for flac files
    torchaudio.set_audio_backend("soundfile")

    # Transforms

    melkwargs = {
        "n_fft": args.win_length,
        "n_mels": args.n_bins,
        "hop_length": args.hop_length,
    }

    sample_rate_original = 16000

    if args.type == "mfcc":
        transforms = torch.nn.Sequential(
            torchaudio.transforms.MFCC(
                sample_rate=sample_rate_original,
                n_mfcc=args.n_bins,
                melkwargs=melkwargs,
            ),
        )
        num_features = args.n_bins
    elif args.type == "waveform":
        transforms = torch.nn.Sequential(UnsqueezeFirst())
        num_features = 1
    else:
        raise ValueError("Model type not supported")

    if args.normalize:
        transforms = torch.nn.Sequential(transforms, Normalize())

    augmentations = torch.nn.Sequential()
    if args.freq_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.FrequencyMasking(
                freq_mask_param=args.freq_mask),
        )
    if args.time_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
        )

    # Text preprocessing

    char_blank = "*"
    char_space = " "
    char_apostrophe = "'"
    labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
    language_model = LanguageModel(labels, char_blank, char_space)

    # Dataset

    training, validation = split_process_librispeech(
        [args.dataset_train, args.dataset_valid],
        [transforms, transforms],
        language_model,
        root=args.dataset_root,
        folder_in_archive=args.dataset_folder_in_archive,
    )

    # Decoder

    if args.decoder == "greedy":
        decoder = GreedyDecoder()
    else:
        raise ValueError("Selected decoder not supported")

    # Model

    model = Wav2Letter(
        num_classes=language_model.length,
        input_type=args.type,
        num_features=num_features,
    )

    if args.jit:
        model = torch.jit.script(model)

    if args.distributed:
        n = torch.cuda.device_count() // args.world_size
        devices = list(range(rank * n, (rank + 1) * n))
        model = model.to(devices[0])
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=devices)
    else:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
        model = model.to(devices[0], non_blocking=True)
        model = torch.nn.DataParallel(model)

    n = count_parameters(model)
    logging.info("Number of parameters: %s", n)

    # Optimizer

    if args.optimizer == "adadelta":
        optimizer = Adadelta(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            eps=args.eps,
            rho=args.rho,
        )
    elif args.optimizer == "sgd":
        optimizer = SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        optimizer = Adam(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adamw":
        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    else:
        raise ValueError("Selected optimizer not supported")

    if args.scheduler == "exponential":
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == "reduceonplateau":
        scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
    else:
        raise ValueError("Selected scheduler not supported")

    criterion = torch.nn.CTCLoss(blank=language_model.mapping[char_blank],
                                 zero_infinity=False)

    # Data Loader

    collate_fn_train = collate_factory(model_length_function, augmentations)
    collate_fn_valid = collate_factory(model_length_function)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": True,
        "shuffle": True,
        "drop_last": True,
    }
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    loader_training = DataLoader(
        training,
        batch_size=args.batch_size,
        collate_fn=collate_fn_train,
        **loader_training_params,
    )
    loader_validation = DataLoader(
        validation,
        batch_size=args.batch_size,
        collate_fn=collate_fn_valid,
        **loader_validation_params,
    )

    # Setup checkpoint

    best_loss = 1.0

    load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint)

    if args.distributed:
        torch.distributed.barrier()

    if load_checkpoint:
        logging.info("Checkpoint: loading %s", args.checkpoint)
        checkpoint = torch.load(args.checkpoint)

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

        logging.info("Checkpoint: loaded '%s' at epoch %s", args.checkpoint,
                     checkpoint["epoch"])
    else:
        logging.info("Checkpoint: not found")

        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            False,
            args.checkpoint,
            not_main_rank,
        )

    if args.distributed:
        torch.distributed.barrier()

    torch.autograd.set_detect_anomaly(False)

    for epoch in range(args.start_epoch, args.epochs):

        logging.info("Epoch: %s", epoch)

        train_one_epoch(
            model,
            criterion,
            optimizer,
            scheduler,
            loader_training,
            decoder,
            language_model,
            devices[0],
            epoch,
            args.clip_grad,
            not_main_rank,
            not args.reduce_lr_valid,
        )

        loss = evaluate(
            model,
            criterion,
            loader_validation,
            decoder,
            language_model,
            devices[0],
            epoch,
            not_main_rank,
        )

        if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            is_best,
            args.checkpoint,
            not_main_rank,
        )

    logging.info("End time: %s", datetime.now())

    if args.distributed:
        torch.distributed.destroy_process_group()
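
Note how the two scheduler branches above are driven differently: ExponentialLR decays on every step() call with no metric, while ReduceLROnPlateau must be fed the monitored validation loss (as done after evaluate() when reduce_lr_valid is set) and only decays after `patience` stalled evaluations. A small self-contained illustration of the difference (the toy model, shorter patience, and fabricated loss values are assumptions):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau

model = torch.nn.Linear(4, 2)

opt_exp = SGD(model.parameters(), lr=0.1)
exp = ExponentialLR(opt_exp, gamma=0.9)
opt_exp.step()
exp.step()                                  # unconditional decay: 0.1 -> 0.09
print(exp.get_last_lr())

opt_plateau = SGD(model.parameters(), lr=0.1)
plateau = ReduceLROnPlateau(opt_plateau, patience=2, threshold=1e-3)
for val_loss in [1.0, 1.0, 1.0, 1.0]:       # validation loss stops improving
    plateau.step(val_loss)                  # reduced (x0.1 by default) only after the stall
print(opt_plateau.param_groups[0]['lr'])    # 0.01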
Example #3
# Optimizer and LR scheduler

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay_lr)
schedular = ExponentialLR(optimizer, gamma=args.decay_lr)


# Main training loop
best_loss = np.inf

# Resume training
if args.load_model is not None:
    if os.path.isfile(args.load_model):
        checkpoint = torch.load(args.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        schedular.load_state_dict(checkpoint['schedular'])
        best_loss = checkpoint['val_loss']
        epoch = checkpoint['epoch']
        print('Loading model: {}. Resuming from epoch: {}'.format(args.load_model, epoch))
    else:
        print('Model: {} not found'.format(args.load_model))

for epoch in range(args.epochs):
    v_loss = execute_graph(model, loader, optimizer, schedular, epoch, use_cuda)

    if v_loss < best_loss:
        best_loss = v_loss
        print('Writing model checkpoint')
        state = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'schedular': schedular.state_dict(),
            'val_loss': v_loss,
        }
Example #4
def main() -> None:
    """Entrypoint.
    """
    config: Any = importlib.import_module(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data = tx.data.MonoTextData(config.train_data_hparams, device=device)
    val_data = tx.data.MonoTextData(config.val_data_hparams, device=device)
    test_data = tx.data.MonoTextData(config.test_data_hparams, device=device)

    iterator = tx.data.DataIterator({
        "train": train_data,
        "valid": val_data,
        "test": test_data
    })

    opt_vars = {
        'learning_rate': config.lr_decay_hparams["init_lr"],
        'best_valid_nll': 1e100,
        'steps_not_improved': 0,
        'kl_weight': config.kl_anneal_hparams["start"]
    }

    decay_cnt = 0
    max_decay = config.lr_decay_hparams["max_decay"]
    decay_factor = config.lr_decay_hparams["decay_factor"]
    decay_ts = config.lr_decay_hparams["threshold"]

    save_dir = f"./models/{config.dataset}"

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    suffix = f"{config.dataset}_{config.decoder_type}Decoder.ckpt"

    save_path = os.path.join(save_dir, suffix)

    # KL term annealing rate
    anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] *
                      (len(train_data) / config.batch_size))

    vocab = train_data.vocab
    model = VAE(train_data.vocab.size, config)
    model.to(device)

    start_tokens = torch.full((config.batch_size, ),
                              vocab.bos_token_id,
                              dtype=torch.long).to(device)
    end_token = vocab.eos_token_id
    optimizer = tx.core.get_optimizer(params=model.parameters(),
                                      hparams=config.opt_hparams)
    scheduler = ExponentialLR(optimizer, decay_factor)

    def _run_epoch(epoch: int, mode: str, display: int = 10) \
            -> Tuple[Tensor, float]:
        iterator.switch_to_dataset(mode)

        if mode == 'train':
            model.train()
            opt_vars["kl_weight"] = min(1.0, opt_vars["kl_weight"] + anneal_r)

            kl_weight = opt_vars["kl_weight"]
        else:
            model.eval()
            kl_weight = 1.0
        step = 0
        start_time = time.time()
        num_words = 0
        nll_total = 0.

        avg_rec = tx.utils.AverageRecorder()
        for batch in iterator:
            ret = model(batch, kl_weight, start_tokens, end_token)
            if mode == "train":
                opt_vars["kl_weight"] = min(1.0,
                                            opt_vars["kl_weight"] + anneal_r)
                kl_weight = opt_vars["kl_weight"]
                ret["nll"].backward()
                optimizer.step()
                optimizer.zero_grad()

            batch_size = len(ret["lengths"])
            num_words += torch.sum(ret["lengths"]).item()
            nll_total += ret["nll"].item() * batch_size
            avg_rec.add([
                ret["nll"].item(), ret["kl_loss"].item(),
                ret["rc_loss"].item()
            ], batch_size)
            if step % display == 0 and mode == 'train':
                nll = avg_rec.avg(0)
                klw = opt_vars["kl_weight"]
                KL = avg_rec.avg(1)
                rc = avg_rec.avg(2)
                log_ppl = nll_total / num_words
                ppl = math.exp(log_ppl)
                time_cost = time.time() - start_time

                print(
                    f"{mode}: epoch {epoch}, step {step}, nll {nll:.4f}, "
                    f"klw {klw:.4f}, KL {KL:.4f}, rc {rc:.4f}, "
                    f"log_ppl {log_ppl:.4f}, ppl {ppl:.4f}, "
                    f"time_cost {time_cost:.1f}",
                    flush=True)

            step += 1

        nll = avg_rec.avg(0)
        KL = avg_rec.avg(1)
        rc = avg_rec.avg(2)
        log_ppl = nll_total / num_words
        ppl = math.exp(log_ppl)
        print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
              f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
        return nll, ppl  # type: ignore

    @torch.no_grad()
    def _generate(start_tokens: torch.LongTensor,
                  end_token: int,
                  filename: Optional[str] = None):
        ckpt = torch.load(args.model)
        model.load_state_dict(ckpt['model'])
        model.eval()

        batch_size = train_data.batch_size

        dst = MultivariateNormalDiag(loc=torch.zeros(batch_size,
                                                     config.latent_dims),
                                     scale_diag=torch.ones(
                                         batch_size, config.latent_dims))

        latent_z = dst.rsample().to(device)

        helper = model.decoder.create_helper(decoding_strategy='infer_sample',
                                             start_tokens=start_tokens,
                                             end_token=end_token)
        outputs = model.decode(helper=helper,
                               latent_z=latent_z,
                               max_decoding_length=100)

        sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

        if filename is None:
            fh = sys.stdout
        else:
            fh = open(filename, 'w', encoding='utf-8')

        for sent in sample_tokens:
            sent = tx.utils.compat_as_text(list(sent))
            end_id = len(sent)
            if vocab.eos_token in sent:
                end_id = sent.index(vocab.eos_token)
            fh.write(' '.join(sent[:end_id + 1]) + '\n')

        print('Output done')
        fh.close()

    if args.mode == "predict":
        _generate(start_tokens, end_token, args.out)
        return
    # Counts trainable parameters
    total_parameters = sum(param.numel() for param in model.parameters())
    print(f"{total_parameters} total parameters")

    best_nll = best_ppl = 0.

    for epoch in range(config.num_epochs):
        _, _ = _run_epoch(epoch, 'train', display=200)
        val_nll, _ = _run_epoch(epoch, 'valid')
        test_nll, test_ppl = _run_epoch(epoch, 'test')

        if val_nll < opt_vars['best_valid_nll']:
            opt_vars['best_valid_nll'] = val_nll
            opt_vars['steps_not_improved'] = 0
            best_nll = test_nll
            best_ppl = test_ppl

            states = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict()
            }
            torch.save(states, save_path)
        else:
            opt_vars['steps_not_improved'] += 1
            if opt_vars['steps_not_improved'] == decay_ts:
                old_lr = opt_vars['learning_rate']
                opt_vars['learning_rate'] *= decay_factor
                opt_vars['steps_not_improved'] = 0
                new_lr = opt_vars['learning_rate']
                ckpt = torch.load(save_path)
                model.load_state_dict(ckpt['model'])
                optimizer.load_state_dict(ckpt['optimizer'])
                scheduler.load_state_dict(ckpt['scheduler'])
                scheduler.step()
                print(f"-----\nchange lr, old lr: {old_lr}, "
                      f"new lr: {new_lr}\n-----")

                decay_cnt += 1
                if decay_cnt == max_decay:
                    break

    print(f"\nbest testing nll: {best_nll:.4f}, "
          f"best testing ppl {best_ppl:.4f}\n")
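
The KL weight in this example is annealed linearly: anneal_r = 1 / (warm_up * steps_per_epoch), and every training step raises kl_weight by anneal_r until it saturates at 1.0, so full weight is reached after roughly warm_up epochs. A tiny self-contained check of that schedule (warm_up, steps per epoch, and the starting weight are assumed values, not read from the config):

warm_up = 10           # config.kl_anneal_hparams["warm_up"] (assumed)
steps_per_epoch = 500  # len(train_data) / config.batch_size (assumed)
start = 0.1            # config.kl_anneal_hparams["start"] (assumed)

anneal_r = 1.0 / (warm_up * steps_per_epoch)
kl_weight = start
for _ in range(warm_up * steps_per_epoch):
    kl_weight = min(1.0, kl_weight + anneal_r)
print(f"kl_weight after {warm_up} epochs: {kl_weight:.3f}")  # saturates at 1.000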
Example #5
class Parser(object):

    NAME = None
    MODEL = None

    def __init__(self, args, model, transform):
        self.args = args
        self.model = model
        self.transform = transform

    def train(self,
              train,
              dev,
              test,
              buckets=32,
              batch_size=5000,
              update_steps=1,
              clip=5.0,
              epochs=5000,
              patience=100,
              **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.train()
        batch_size = batch_size // update_steps
        if dist.is_initialized():
            batch_size = batch_size // dist.get_world_size()
        logger.info("Loading the data")
        train = Dataset(self.transform, args.train,
                        **args).build(batch_size, buckets, True,
                                      dist.is_initialized())
        dev = Dataset(self.transform, args.dev).build(batch_size, buckets)
        test = Dataset(self.transform, args.test).build(batch_size, buckets)
        logger.info(
            f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

        if args.encoder == 'lstm':
            self.optimizer = Adam(self.model.parameters(), args.lr,
                                  (args.mu, args.nu), args.eps,
                                  args.weight_decay)
            self.scheduler = ExponentialLR(self.optimizer,
                                           args.decay**(1 / args.decay_steps))
        else:
            from transformers import AdamW, get_linear_schedule_with_warmup
            steps = len(train.loader) * epochs // args.update_steps
            self.optimizer = AdamW(
                [{'params': p,
                  'lr': args.lr * (1 if n.startswith('encoder') else args.lr_rate)}
                 for n, p in self.model.named_parameters()],
                args.lr)
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer, int(steps * args.warmup), steps)

        if dist.is_initialized():
            self.model = DDP(self.model,
                             device_ids=[args.local_rank],
                             find_unused_parameters=True)

        self.epoch, self.best_e, self.patience = 1, 1, patience
        self.best_metric, self.elapsed = Metric(), timedelta()
        if self.args.checkpoint:
            self.optimizer.load_state_dict(
                self.checkpoint_state_dict.pop('optimizer_state_dict'))
            self.scheduler.load_state_dict(
                self.checkpoint_state_dict.pop('scheduler_state_dict'))
            set_rng_state(self.checkpoint_state_dict.pop('rng_state'))
            for k, v in self.checkpoint_state_dict.items():
                setattr(self, k, v)
            train.loader.batch_sampler.epoch = self.epoch

        for epoch in range(self.epoch, args.epochs + 1):
            start = datetime.now()

            logger.info(f"Epoch {epoch} / {args.epochs}:")
            self._train(train.loader)
            loss, dev_metric = self._evaluate(dev.loader)
            logger.info(f"{'dev:':5} loss: {loss:.4f} - {dev_metric}")
            loss, test_metric = self._evaluate(test.loader)
            logger.info(f"{'test:':5} loss: {loss:.4f} - {test_metric}")

            t = datetime.now() - start
            self.epoch += 1
            self.patience -= 1
            self.elapsed += t

            if dev_metric > self.best_metric:
                self.best_e, self.patience, self.best_metric = epoch, patience, dev_metric
                if is_master():
                    self.save_checkpoint(args.path)
                logger.info(f"{t}s elapsed (saved)\n")
            else:
                logger.info(f"{t}s elapsed\n")
            if self.patience < 1:
                break
        parser = self.load(**args)
        loss, metric = parser._evaluate(test.loader)
        parser.save(args.path)

        logger.info(f"Epoch {self.best_e} saved")
        logger.info(f"{'dev:':5} {self.best_metric}")
        logger.info(f"{'test:':5} {metric}")
        logger.info(f"{self.elapsed}s elapsed, {self.elapsed / epoch}s/epoch")

    def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.train()
        logger.info("Loading the data")
        dataset = Dataset(self.transform, data)
        dataset.build(batch_size, buckets)
        logger.info(f"\n{dataset}")

        logger.info("Evaluating the dataset")
        start = datetime.now()
        loss, metric = self._evaluate(dataset.loader)
        elapsed = datetime.now() - start
        logger.info(f"loss: {loss:.4f} - {metric}")
        logger.info(
            f"{elapsed}s elapsed, {len(dataset)/elapsed.total_seconds():.2f} Sents/s"
        )

        return loss, metric

    def predict(self,
                data,
                pred=None,
                lang=None,
                buckets=8,
                batch_size=5000,
                prob=False,
                **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.eval()
        if args.prob:
            self.transform.append(Field('probs'))

        logger.info("Loading the data")
        dataset = Dataset(self.transform, data, lang=lang)
        dataset.build(batch_size, buckets)
        logger.info(f"\n{dataset}")

        logger.info("Making predictions on the dataset")
        start = datetime.now()
        preds = self._predict(dataset.loader)
        elapsed = datetime.now() - start

        for name, value in preds.items():
            setattr(dataset, name, value)
        if pred is not None and is_master():
            logger.info(f"Saving predicted results to {pred}")
            self.transform.save(pred, dataset.sentences)
        logger.info(
            f"{elapsed}s elapsed, {len(dataset) / elapsed.total_seconds():.2f} Sents/s"
        )

        return dataset

    def _train(self, loader):
        raise NotImplementedError

    @torch.no_grad()
    def _evaluate(self, loader):
        raise NotImplementedError

    @torch.no_grad()
    def _predict(self, loader):
        raise NotImplementedError

    @classmethod
    def build(cls, path, **kwargs):
        raise NotImplementedError

    @classmethod
    def load(cls,
             path,
             reload=False,
             src='github',
             checkpoint=False,
             **kwargs):
        r"""
        Loads a parser with data fields and pretrained model parameters.

        Args:
            path (str):
                - a string with the shortcut name of a pretrained model defined in ``supar.MODEL``
                  to load from cache or download, e.g., ``'biaffine-dep-en'``.
                - a local path to a pretrained model, e.g., ``./<path>/model``.
            reload (bool):
                Whether to discard the existing cache and force a fresh download. Default: ``False``.
            src (str):
                Specifies where to download the model.
                ``'github'``: github release page.
                ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
                Default: ``'github'``.
            checkpoint (bool):
                If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
            kwargs (dict):
                A dict holding unconsumed arguments for updating training configs and initializing the model.

        Examples:
            >>> from supar import Parser
            >>> parser = Parser.load('biaffine-dep-en')
            >>> parser = Parser.load('./ptb.biaffine.dep.lstm.char')
        """

        args = Config(**locals())
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        state = torch.load(path if os.path.exists(path) else download(
            supar.MODEL[src].get(path, path), reload=reload))
        cls = supar.PARSER[state['name']] if cls.NAME is None else cls
        args = state['args'].update(args)
        model = cls.MODEL(**args)
        model.load_pretrained(state['pretrained'])
        model.load_state_dict(state['state_dict'], False)
        model.to(args.device)
        transform = state['transform']
        parser = cls(args, model, transform)
        parser.checkpoint_state_dict = state[
            'checkpoint_state_dict'] if args.checkpoint else None
        return parser

    def save(self, path):
        model = self.model
        if hasattr(model, 'module'):
            model = self.model.module
        args = model.args
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        pretrained = state_dict.pop('pretrained.weight', None)
        state = {
            'name': self.NAME,
            'args': args,
            'state_dict': state_dict,
            'pretrained': pretrained,
            'transform': self.transform
        }
        torch.save(state, path, pickle_module=dill)

    def save_checkpoint(self, path):
        model = self.model
        if hasattr(model, 'module'):
            model = self.model.module
        args = model.args
        checkpoint_state_dict = {
            k: getattr(self, k)
            for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']
        }
        checkpoint_state_dict.update({
            'optimizer_state_dict':
            self.optimizer.state_dict(),
            'scheduler_state_dict':
            self.scheduler.state_dict(),
            'rng_state':
            get_rng_state()
        })
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        pretrained = state_dict.pop('pretrained.weight', None)
        state = {
            'name': self.NAME,
            'args': args,
            'state_dict': state_dict,
            'pretrained': pretrained,
            'checkpoint_state_dict': checkpoint_state_dict,
            'transform': self.transform
        }
        torch.save(state, path, pickle_module=dill)
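
In the LSTM branch of Parser.train above, the scheduler uses gamma = args.decay ** (1 / args.decay_steps), so after decay_steps scheduler steps the learning rate has been multiplied by exactly decay. A self-contained check with toy values (the model, learning rate, and decay settings are assumptions):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

decay, decay_steps = 0.75, 5000
model = torch.nn.Linear(4, 2)
optimizer = Adam(model.parameters(), lr=2e-3)
scheduler = ExponentialLR(optimizer, gamma=decay ** (1 / decay_steps))

for _ in range(decay_steps):
    optimizer.step()      # no gradients here; the step only drives the schedule
    scheduler.step()
print(scheduler.get_last_lr())  # ~[0.0015] == 2e-3 * 0.75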
Example #6
def main():
    """Entrypoint.
    """
    config: Any = importlib.import_module(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # train_data = tx.data.MonoTextData(config.train_data_hparams, device=device)
    # val_data = tx.data.MonoTextData(config.val_data_hparams, device=device)
    # test_data = tx.data.MonoTextData(config.test_data_hparams, device=device)

    train_data = tx.data.MonoTextData(config.train_data_hparams,
                                      device=torch.device("cpu"))
    val_data = tx.data.MonoTextData(config.val_data_hparams,
                                    device=torch.device("cpu"))
    test_data = tx.data.MonoTextData(config.test_data_hparams,
                                     device=torch.device("cpu"))

    iterator = tx.data.DataIterator({
        "train": train_data,
        "valid": val_data,
        "test": test_data
    })

    opt_vars = {
        'learning_rate': config.lr_decay_hparams["init_lr"],
        'best_valid_nll': 1e100,
        'steps_not_improved': 0,
        'kl_weight': config.kl_anneal_hparams["start"]
    }

    decay_cnt = 0
    max_decay = config.lr_decay_hparams["max_decay"]
    decay_factor = config.lr_decay_hparams["decay_factor"]
    decay_ts = config.lr_decay_hparams["threshold"]

    if 'pid' in args.model_name:
        save_dir = f"{args.model_name}_{config.dataset}_KL{args.exp_kl}"
    elif 'cost' in args.model_name:
        save_dir = f"{args.model_name}_{config.dataset}_step{args.anneal_steps}"
    elif 'cyclical' in args.model_name:
        save_dir = f"{args.model_name}_{config.dataset}_cyc_{args.cycle}"
    else:
        raise ValueError(f"Unsupported model_name: {args.model_name}")

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    suffix = f"{config.dataset}_{config.decoder_type}Decoder.ckpt"

    save_path = os.path.join(save_dir, suffix)

    # KL term annealing rate warm_up=10
    ## replace it with sigmoid function
    anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] *
                      (len(train_data) / config.batch_size))

    vocab = train_data.vocab
    model = VAE(train_data.vocab.size, config)
    model.to(device)

    start_tokens = torch.full((config.batch_size, ),
                              vocab.bos_token_id,
                              dtype=torch.long).to(device)
    end_token = vocab.eos_token_id
    optimizer = tx.core.get_optimizer(params=model.parameters(),
                                      hparams=config.opt_hparams)
    scheduler = ExponentialLR(optimizer, decay_factor)

    ## max iteration
    max_iter = config.num_epochs * len(train_data) / config.batch_size
    max_iter = min(max_iter, args.max_steps)
    print('max steps:', max_iter)
    pbar = tqdm(total=int(max_iter))

    if args.mode == "train":
        outFile = os.path.join(save_dir, 'train.log')
        fw_log = open(outFile, "w")

    global_steps = {}
    global_steps['step'] = 0
    pid = PIDControl()
    opt_vars["kl_weight"] = 0.0
    Kp = args.Kp
    Ki = args.Ki
    exp_kl = args.exp_kl

    ## train model
    def _run_epoch(epoch: int, mode: str, display: int = 10) \
            -> Tuple[Tensor, float]:
        iterator.switch_to_dataset(mode)

        if mode == 'train':
            model.train()
            kl_weight = opt_vars["kl_weight"]
        else:
            model.eval()
            kl_weight = 1.0
            # kl_weight = opt_vars["kl_weight"]

        start_time = time.time()
        num_words = 0
        nll_total = 0.

        avg_rec = tx.utils.AverageRecorder()
        for batch in iterator:
            ## run model to get loss function
            if global_steps['step'] >= args.max_steps:
                break
            ret = model(batch, kl_weight, start_tokens, end_token)
            if mode == "train":
                pbar.update(1)
                global_steps['step'] += 1
                kl_loss = ret['kl_loss'].item()
                rec_loss = ret['rc_loss'].item()
                total_loss = ret["nll"].item()
                if 'cost' in args.model_name:
                    kl_weight = _cost_annealing(global_steps['step'], 1.0,
                                                args.anneal_steps)
                elif 'pid' in args.model_name:
                    kl_weight = pid.pid(exp_kl, kl_loss, Kp, Ki)
                elif 'cyclical' in args.model_name:
                    kl_weight = _cyclical_annealing(global_steps['step'],
                                                    max_iter / args.cycle)

                opt_vars["kl_weight"] = kl_weight

                ## total loss
                ret["nll"].backward()
                optimizer.step()
                optimizer.zero_grad()
                fw_log.write('epoch:{0} global_step:{1} total_loss:{2:.3f} kl_loss:{3:.3f} rec_loss:{4:.3f} kl_weight:{5:.4f}\n'\
                            .format(epoch, global_steps['step'], total_loss, kl_loss, rec_loss, kl_weight))
                fw_log.flush()

            batch_size = len(ret["lengths"])
            num_words += torch.sum(ret["lengths"]).item()
            nll_total += ret["nll"].item() * batch_size
            avg_rec.add([
                ret["nll"].item(), ret["kl_loss"].item(),
                ret["rc_loss"].item()
            ], batch_size)

            if global_steps['step'] % display == 1 and mode == 'train':
                nll = avg_rec.avg(0)
                klw = opt_vars["kl_weight"]
                KL = avg_rec.avg(1)
                rc = avg_rec.avg(2)
                writer.add_scalar(f'Loss/Rec_loss_{args.model_name}', rc,
                                  global_steps['step'])
                writer.add_scalar(f'Loss/KL_diverg_{args.model_name}', KL,
                                  global_steps['step'])
                writer.add_scalar(f'Loss/KL_weight_{args.model_name}', klw,
                                  global_steps['step'])

        nll = avg_rec.avg(0)
        KL = avg_rec.avg(1)
        rc = avg_rec.avg(2)
        if num_words > 0:
            log_ppl = nll_total / num_words
            ppl = math.exp(log_ppl)
        else:
            log_ppl = 100
            ppl = math.exp(log_ppl)
            nll = 1000
            KL = args.exp_kl

        print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
              f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
        return nll, ppl  # type: ignore

    args.model = save_path

    @torch.no_grad()
    def _generate(start_tokens: torch.LongTensor,
                  end_token: int,
                  filename: Optional[str] = None):
        ckpt = torch.load(args.model)
        model.load_state_dict(ckpt['model'])
        model.eval()

        batch_size = train_data.batch_size

        dst = MultivariateNormalDiag(loc=torch.zeros(batch_size,
                                                     config.latent_dims),
                                     scale_diag=torch.ones(
                                         batch_size, config.latent_dims))

        # latent_z = dst.rsample().to(device)
        latent_z = torch.FloatTensor(batch_size,
                                     config.latent_dims).uniform_(-1,
                                                                  1).to(device)
        # latent_z = torch.randn(batch_size, config.latent_dims).to(device)

        helper = model.decoder.create_helper(decoding_strategy='infer_sample',
                                             start_tokens=start_tokens,
                                             end_token=end_token)
        outputs = model.decode(helper=helper,
                               latent_z=latent_z,
                               max_decoding_length=100)

        if config.decoder_type == "transformer":
            outputs = outputs[0]

        sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

        if filename is None:
            fh = sys.stdout
        else:
            fh = open(filename, 'a', encoding='utf-8')

        for sent in sample_tokens:
            sent = tx.utils.compat_as_text(list(sent))
            end_id = len(sent)
            if vocab.eos_token in sent:
                end_id = sent.index(vocab.eos_token)
            fh.write(' '.join(sent[:end_id + 1]) + '\n')

        print('Output done')
        fh.close()

    if args.mode == "predict":
        out_path = os.path.join(save_dir, 'results.txt')
        for _ in range(10):
            _generate(start_tokens, end_token, out_path)
        return

    # Counts trainable parameters
    total_parameters = sum(param.numel() for param in model.parameters())
    print(f"{total_parameters} total parameters")

    best_nll = best_ppl = 0.

    ## start running model
    for epoch in range(config.num_epochs):
        _, _ = _run_epoch(epoch, 'train', display=200)
        val_nll, _ = _run_epoch(epoch, 'valid')
        test_nll, test_ppl = _run_epoch(epoch, 'test')

        if val_nll < opt_vars['best_valid_nll']:
            opt_vars['best_valid_nll'] = val_nll
            opt_vars['steps_not_improved'] = 0
            best_nll = test_nll
            best_ppl = test_ppl

            states = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict()
            }
            torch.save(states, save_path)
        else:
            opt_vars['steps_not_improved'] += 1
            if opt_vars['steps_not_improved'] == decay_ts:
                old_lr = opt_vars['learning_rate']
                opt_vars['learning_rate'] *= decay_factor
                opt_vars['steps_not_improved'] = 0
                new_lr = opt_vars['learning_rate']
                ckpt = torch.load(save_path)
                model.load_state_dict(ckpt['model'])
                optimizer.load_state_dict(ckpt['optimizer'])
                scheduler.load_state_dict(ckpt['scheduler'])
                scheduler.step()
                print(f"-----\nchange lr, old lr: {old_lr}, "
                      f"new lr: {new_lr}\n-----")

                decay_cnt += 1
                if decay_cnt == max_decay:
                    break
        if global_steps['step'] >= args.max_steps:
            break

    print(f"\nbest testing nll: {best_nll:.4f}, "
          f"best testing ppl {best_ppl:.4f}\n")

    if args.mode == "train":
        fw_log.write(f"\nbest testing nll: {best_nll:.4f}, "
                     f"best testing ppl {best_ppl:.4f}\n")
        fw_log.close()
Example #7
def main_worker(train_loader, val_loader, ntokens, args, device):
    global best_ppl

    model_kwargs = {
        'dropout': args.dropout,
        'tie_weights': not args.not_tied,
        'norm': args.norm_mode,
        'alpha_fwd': args.afwd,
        'alpha_bkw': args.abkw,
        'batch_size': args.batch_size,  # Deprecated
        'ecm': args.ecm,
        'cell_norm': args.cell_norm,
    }

    # create model
    print("=> creating model: '{}'".format(args.ru_type))
    model = models.RNNModel(args.ru_type, ntokens, args.emsize, args.nhid,
                            args.nlayers, **model_kwargs).to(device)

    print(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                weight_decay=args.weight_decay)

    scheduler = ExponentialLR(optimizer, gamma=1 / args.lr_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_ppl = checkpoint['best_ppl']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = False if args.seed else True

    if args.evaluate:
        validate(val_loader, model, criterion, device, args, ntokens)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if epoch:
            scheduler.step()

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, device, args,
              ntokens)

        # evaluate on validation set
        ppl = validate(val_loader, model, criterion, device, args, ntokens)

        # remember best ppl and save checkpoint
        is_best = ppl < best_ppl
        best_ppl = min(ppl, best_ppl)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_ppl': best_ppl,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, args)
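
Two details of the loop above are easy to miss: gamma = 1 / args.lr_decay, so an lr_decay greater than 1 shrinks the learning rate, and the `if epoch` guard skips the decay on the very first epoch, so epoch 0 runs at the initial rate. A toy trace of that schedule (the learning rate and lr_decay values are assumptions):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

lr, lr_decay = 0.1, 1.25
optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=lr)
scheduler = ExponentialLR(optimizer, gamma=1 / lr_decay)

for epoch in range(3):
    if epoch:
        scheduler.step()                    # decay at the start of every epoch but the first
    print(epoch, scheduler.get_last_lr())   # [0.1], [0.08], [0.064]
    optimizer.step()                        # stands in for one epoch of training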