Example #1
    def predict(self):
        # create a dataset given opt.dataset_mode and other options
        dataset = create_dataset(self.opt)
        # create a model given opt.model and other options
        model = create_model(self.opt)
        # regular setup: load and print networks; create schedulers
        model.setup(self.opt)

        img_rec = None

        if self.opt.eval:
            model.eval()
        for i, data in enumerate(dataset):
            model.set_input(data)  # unpack data from data loader
            model.test()  # run inference
            visuals = model.get_current_visuals()  # get image results

            im = visuals['fake_B']
            # overwritten each iteration, so the last image's result is returned
            img_rec = util.tensor2im(im)

        return img_rec
Example #2
def main():
    """
    Performs training, validation and testing.
    """
    args = setup_train_args()

    args.cuda = torch.cuda.is_available() \
        and not args.no_cuda

    model_dir = join(args.model_dir, args.model, args.name)

    os.makedirs(model_dir, exist_ok=True)

    logger = create_logger(model_dir=model_dir)

    if args.mixed and not APEX_INSTALLED:
        logger.warning('--mixed passed but apex is not installed.')

    args.mixed = args.mixed and APEX_INSTALLED \
        and args.cuda

    master_process = args.local_rank in [0, -1]
    args.distributed = args.local_rank != -1

    if args.distributed:
        # use distributed training if local rank is given
        # and GPU training is requested
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)

        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://',
                                             rank=args.local_rank)

    else:
        device = torch.device('cuda' if args.cuda else 'cpu')

    # creating dataset and storing dataset splits
    # as individual variables for convenience
    datasets, tokenizer = create_dataset(args=args,
                                         master_process=master_process)

    pad_idx = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    vocab_size = len(tokenizer)

    # TODO fix xlnet nan with mixed precision
    if 'xlnet' in args.model:
        args.mixed = False

    model = create_model(args=args, model_dir=model_dir, vocab_size=vocab_size)

    model = model.to(device)

    optimizer = create_optimizer(args=args, parameters=model.parameters())

    if master_process:
        writer = SummaryWriter(logdir=model_dir, flush_secs=100)

    # loading previous state of the training
    best_val_loss, init_epoch, step = load_state(model_dir=model_dir,
                                                 model=model,
                                                 optimizer=optimizer,
                                                 logger=logger,
                                                 device=device)

    if args.mixed:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    world_size = int(os.environ.get('WORLD_SIZE', 1))

    train, valid, test = [(split, ceil(size / args.batch_size / world_size))
                          for split, size in datasets]

    # computing the sizes of the dataset splits
    train_dataset, num_train_steps = train
    valid_dataset, num_valid_steps = valid
    test_dataset, num_test_steps = test

    patience, skip, loss, acc = 0, 0, 0, 0

    def reduce_tensor(tensor):
        """
        Averages a tensor across gpus.
        """
        reduced = tensor.clone()
        all_reduce(reduced, op=ReduceOp.SUM)
        reduced /= world_size

        return reduced

    def forward_step(batch):
        """
        Applies forward pass with the given batch.
        """
        inputs, targets = batch

        outputs = model(inputs=inputs, half=args.mixed)

        # converting targets from ndarray
        targets = torch.as_tensor(targets)
        targets = targets.long().to(device)

        loss, accuracy = compute_loss(outputs=outputs,
                                      targets=targets,
                                      ignore_idx=pad_idx)

        if args.distributed:
            # reducing accuracy accross devices
            # for more accurate logging
            accuracy = reduce_tensor(accuracy)

        return loss, accuracy.item()

    def train_step(batch):
        """
        Performs a single step of training.
        """
        nonlocal step, skip

        loss, accuracy = forward_step(batch)

        if torch.isnan(loss).item():
            logger.debug('skipping step (nan)')
            # returning None values when a NaN loss
            # is encountered and skipping backprop
            # so model grads will not be corrupted
            skip += 1
            return None, None

        loss /= args.grad_accum_steps

        backward(loss)
        clip_grad_norm(1.0)

        step += 1

        if step % args.grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        if args.distributed:
            # reducing loss accross devices for
            # more accurate logging
            loss = reduce_tensor(loss)

        return loss.item(), accuracy

    def backward(loss):
        """
        Backpropagates the loss in either mixed or
        normal precision mode.
        """
        # cuda is required for mixed precision training.
        if args.mixed:
            with amp.scale_loss(loss, optimizer) as scaled:
                scaled.backward()
        else:
            loss.backward()

    def clip_grad_norm(max_norm):
        """
        Applies gradient clipping.
        """
        if args.mixed:
            clip_grad_norm_(amp.master_params(optimizer), max_norm)
        else:
            clip_grad_norm_(model.parameters(), max_norm)

    def evaluate(dataset, num_steps):
        """
        Constructs a validation loader and evaluates
        the model.
        """
        loop = tqdm(dataset(),
                    total=num_steps,
                    disable=not master_process,
                    desc='Eval')

        model.eval()

        for batch in loop:
            loss, acc = forward_step(batch)

            loop.set_postfix(
                ordered_dict=OrderedDict(loss=loss.item(), acc=acc))

            yield loss.item()

    def save_state():
        """
        Saves the model and optimizer state.
        """
        model_path = join(model_dir, 'model.pt')

        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'val_loss': best_val_loss,
            'epoch': epoch + 1,
            'step': step
        }

        logger.info('Saving model to {}'.format(model_path))
        # making sure the model saving is not left in a
        # corrupted state after a keyboard interrupt
        while True:
            try:
                torch.save(state, model_path)
                break
            except KeyboardInterrupt:
                pass

    scheduler = LambdaLR(optimizer, compute_lr)

    if master_process:
        logger.info(str(vars(args)))

    for epoch in range(init_epoch, args.max_epochs):
        # running training loop
        loop = tqdm(train_dataset(),
                    total=num_train_steps,
                    disable=not master_process,
                    desc='Train {}'.format(epoch))

        train_loss = []

        model.train()

        for batch in loop:
            try:
                loss, acc = train_step(batch)

                if master_process and loss is not None:
                    train_loss.append(loss)

                    # logging to tensorboard
                    writer.add_scalar('train/loss', loss, step)
                    writer.add_scalar('train/acc', acc, step)

                if not step % args.eval_every_step:
                    with torch.no_grad():
                        val_loss = mean(
                            evaluate(dataset=valid_dataset,
                                     num_steps=num_valid_steps))

                    # switching back to training
                    model.train()

                    if master_process:
                        logger.info('val loss: {:.4}'.format(val_loss))

                        # logging to tensorboard
                        writer.add_scalar('val/loss', val_loss, step)

                    if val_loss < best_val_loss:
                        patience = 0
                        best_val_loss = val_loss

                        if master_process:
                            save_state()

                    else:
                        patience += 1
                        if patience == args.patience:
                            # terminate when max patience
                            # level is hit
                            break

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    logger.debug('skipping step (oom)')
                    skip += 1
                else:
                    raise

            loop.set_postfix(
                ordered_dict=OrderedDict(loss=loss, acc=acc, skip=skip))

        if len(train_loss) > 0:
            train_loss = mean(train_loss)
        else:
            train_loss = 0.0

        if master_process:
            logger.info('train loss: {:.4}'.format(train_loss))

        scheduler.step()

    if master_process:
        writer.close()

    with torch.no_grad():
        test_loss = mean(
            evaluate(dataset=test_dataset, num_steps=num_test_steps))

    if master_process:
        logger.info('test loss: {:.4}'.format(test_loss))
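
A note on the pattern above: train_step divides the loss by args.grad_accum_steps and only calls optimizer.step() every args.grad_accum_steps batches, so gradients from several small batches are accumulated into one update. A minimal, self-contained sketch of that gradient-accumulation pattern follows; the model, loader, and grad_accum_steps below are placeholders, not objects from the example.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(6)]
grad_accum_steps = 3

for step, (inputs, targets) in enumerate(loader, start=1):
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    # scale the loss so the accumulated gradient matches one large batch
    (loss / grad_accum_steps).backward()

    if step % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()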
Example #3
"""
from src.options.test_options import TestOptions
from src.data import create_dataset
from src.models import create_model
from src.util.visualizer import save_images_to_path
import sys

if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0  # test code only supports num_threads = 0
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1  # no visdom display; the test code saves the results to a HTML file.
    # create a dataset given opt.dataset_mode and other options
    dataset = create_dataset(opt)
    # create a model given opt.model and other options
    model = create_model(opt)
    # regular setup: load and print networks; create schedulers
    model.setup(opt)

    if opt.force_test_output == '':
        print('--force_test_output must be defined.')
        sys.exit()

    if opt.eval:
        model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)  # unpack data from data loader
Example #4
def main():
    args = setup_interact_args()
    args.distributed = False

    device = torch.device('cuda' if args.cuda else 'cpu')

    model_dir = join(args.model_dir, args.model_name)

    state_dict = torch.load(join(model_dir, 'model.pt'), map_location=device)

    _, tokenizer = create_dataset(args=args)

    vocab_size = len(tokenizer)

    model = create_model(args, vocab_size)
    model = model.to(device)

    model.load_state_dict(state_dict['model'])
    model.eval()

    history = []

    select_fn = METHODS[args.method]

    special_ids = tokenizer.convert_tokens_to_ids([
        SP1,
        SP2,
        tokenizer.bos_token,
        tokenizer.eos_token,
        HST,
        RSP,
    ])

    @torch.no_grad()
    def respond(text):
        """
        Responds to the given text.
        """
        history.append(tokenizer.encode(text))

        inputs = transform_dialog(history[:args.max_hist],
                                  special_ids=special_ids)

        preds = decode(args=args,
                       model=model,
                       inputs=inputs,
                       tokenizer=tokenizer,
                       select_fn=select_fn,
                       device=device)

        history.append(preds)

        # last token is the end token
        return tokenizer.decode(preds[:-1])

    print('Type a sentence to translate. CTRL + C to escape.')

    while True:
        try:
            print()
            text = input()
            output = respond(text)
            print(output)
            print()

        except KeyboardInterrupt:
            break
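
The respond helper above is decorated with @torch.no_grad() so that decoding never tracks gradients. A minimal sketch of that inference pattern, using a placeholder model instead of the example's:

import torch

model = torch.nn.Linear(4, 2)

@torch.no_grad()
def classify(features):
    # no autograd graph is built inside this call
    return model(features).argmax(dim=-1)

print(classify(torch.randn(1, 4)))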
Example #5
def main(cfg):
    """
    Performs training, validation and testing.
    """
    assert isdir(cfg.data_dir), \
        '`data_dir` must be a valid path.'

    cfg.cuda = torch.cuda.is_available() \
        and not cfg.no_cuda

    cfg.model_dir = os.getcwd()

    # setting random seed for reproducibility
    if cfg.seed: set_random_seed(cfg)

    device = torch.device('cuda' if cfg.cuda else 'cpu')

    os.makedirs(cfg.model_dir, exist_ok=True)

    label2id = create_label2id(cfg)
    cfg.num_labels = len(label2id)

    xlmr = create_pretrained(cfg.model_type, cfg.force_download)

    # creating dataset split loaders
    datasets = create_dataset(cfg, xlmr, label2id)

    train_dataset, valid_dataset = datasets

    def compute_loss(batch):
        """
        Computes the forward pass and returns the
        cross entropy loss.
        """
        inputs, labels = [
            torch.from_numpy(tensor).to(device).long() for tensor in batch
        ]

        logits = model(inputs)

        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)

        loss = torch.nn.functional.cross_entropy(logits,
                                                 labels,
                                                 ignore_index=-1)

        return loss

    def train_step(engine, batch):
        """
        Propagates the inputs forward and updates
        the parameters.
        """
        step = engine.state.iteration

        model.train()

        loss = compute_loss(batch)

        # scaling the loss for gradient accumulation
        loss /= cfg.grad_accum_steps

        backward(loss)

        if cfg.clip_grad_norm is not None:
            clip_grad_norm(cfg.clip_grad_norm)

        if step % cfg.grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        # restoring the averaged loss across steps
        loss *= cfg.grad_accum_steps

        return loss.item()

    def eval_step(engine, batch):
        """
        Propagates the inputs forward without
        storing any gradients.
        """
        model.eval()

        with torch.no_grad():
            loss = compute_loss(batch)

        return loss.item()

    def backward(loss):
        """
        Backpropagates the loss in either mixed or
        normal precision mode.
        """
        if cfg.fp16:
            with amp.scale_loss(loss, optimizer) as sc:
                sc.backward()

        else:
            loss.backward()

    def clip_grad_norm(max_norm):
        """
        Applies gradient clipping.
        """
        if cfg.fp16:
            params = amp.master_params(optimizer)
        else:
            params = model.parameters()

        torch.nn.utils.clip_grad_norm_(params, max_norm)

    trainer = Engine(train_step)
    validator = Engine(eval_step)

    checkpoint = ModelCheckpoint(
        cfg.model_dir,
        cfg.model_type,
        n_saved=5,
        save_as_state_dict=True,
        score_function=lambda e: -e.state.metrics['loss'])

    last_ckpt_path = cfg.ckpt_path

    if last_ckpt_path is not None:
        msg = 'Loading state from {}'
        print(msg.format(basename(last_ckpt_path)))

        last_state = torch.load(last_ckpt_path, map_location=device)

    model = create_model(xlmr, len(label2id), cfg)
    model = model.to(device)

    del xlmr.model

    optimizer = create_optimizer(cfg, model)

    scheduler = create_scheduler(cfg, optimizer, len(train_dataset))

    # using apex if required and loading its state
    if cfg.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

        if last_ckpt_path is not None and \
                'amp' in last_state:
            amp.load_state_dict(last_state['amp'])

    if last_ckpt_path is not None:
        model.load_state_dict(last_state['model'])
        optimizer.load_state_dict(last_state['optimizer'])
        scheduler.load_state_dict(last_state['scheduler'])

    checkpoint_dict = {
        'model': model,
        'optimizer': optimizer,
        'scheduler': scheduler
    }

    if cfg.fp16: checkpoint_dict['amp'] = amp

    validator.add_event_handler(Events.COMPLETED, checkpoint, checkpoint_dict)

    metric = RunningAverage(output_transform=lambda x: x)
    metric.attach(trainer, 'loss')
    metric.attach(validator, 'loss')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=['loss'])

    history_path = join(cfg.model_dir, 'history.json')
    history = collections.defaultdict(list)
    headers = ['epoch', 'train_loss', 'valid_loss']

    if exists(history_path):
        with open(history_path, 'r') as fh:
            history = json.load(fh)

    def record_history(results):
        """
        Records the results to the history.
        """
        for header in headers:
            history[header].append(results[header])

        with open(history_path, 'w') as fh:
            json.dump(history, fh)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_results(engine):
        """
        Logs the training results.
        """
        validator.run(valid_dataset)

        record_history({
            'epoch': engine.state.epoch,
            'train_loss': engine.state.metrics['loss'],
            'valid_loss': validator.state.metrics['loss']
        })

        data = list(zip(*[history[h] for h in headers]))
        table = tabulate(data, headers, floatfmt='.3f')

        print(table.split('\n')[-1])

    data = list(zip(*[history[h] for h in headers]))

    print()
    print(cfg.pretty())

    print()
    print('***** Running training *****')

    print()
    print(tabulate(data, headers, floatfmt='.3f'))

    trainer.run(train_dataset, cfg.max_epochs)
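
Example #5 drives training and validation through pytorch-ignite Engines, with a RunningAverage metric attached for the loss. A minimal sketch of that wiring, assuming pytorch-ignite is installed; the toy model, data, and train_step here are placeholders rather than the example's own objects:

import torch
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train_step(engine, batch):
    inputs, labels = batch
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

@trainer.on(Events.EPOCH_COMPLETED)
def log_loss(engine):
    print('epoch {}: loss {:.3f}'.format(engine.state.epoch,
                                          engine.state.metrics['loss']))

data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(10)]
trainer.run(data, max_epochs=2)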
Example #6
def main():
    """
    Performs training, validation and testing.
    """
    args = setup_train_args()

    if args.notebook:
        from tqdm import tqdm_notebook as tqdm
    else:
        from tqdm import tqdm

    # if config is provided, then load it
    if args.config is not None:
        with open(args.config, 'r') as fh:
            config = json.load(fh)

        for arg in config:
            setattr(args, arg, config[arg])

    args.cuda = torch.cuda.is_available() \
        and not args.no_cuda

    # setting random seed for reproducibility
    if args.seed:
        set_random_seed(args)

    model_dir = join(args.model_dir, args.model, args.name)

    os.makedirs(model_dir, exist_ok=True)
    logger = create_logger(model_dir=model_dir)

    if args.fp16 and not APEX_INSTALLED:
        logger.warning('--fp16 passed but apex is not installed.')

    args.fp16 = args.fp16 and APEX_INSTALLED \
        and args.cuda

    master_process = args.local_rank in [0, -1]
    args.distributed = args.local_rank != -1

    if args.distributed:
        # use distributed training if local rank is given
        # and GPU training is requested
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)

        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://',
                                             rank=args.local_rank)

    else:
        device = torch.device('cuda' if args.cuda else 'cpu')

    # creating dataset and storing dataset splits
    # as individual variables for convenience

    if args.distributed:
        # creating the dataset and model only on
        # a single process ( downloading )
        if master_process:
            _, tokenizer, _ = create_dataset(args, master_process)

            vocab_size = len(tokenizer)

            create_model(args, model_dir, vocab_size)

        # other threads are waiting for the data init
        barrier()

    datasets, tokenizer, max_len = create_dataset(
        args=args, master_process=master_process)

    pad_idx = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    vocab_size = len(tokenizer)

    model = create_model(args, model_dir, vocab_size)
    model = model.to(device)

    # TODO fix xlnet nan with mixed precision
    if 'xlnet' in args.model:
        args.fp16 = False

    optimizer = create_optimizer(args=args, parameters=model.parameters())

    if master_process:
        writer = SummaryWriter(logdir=model_dir, flush_secs=100)

    # loading previous state of the training
    best_valid_loss, init_epoch, step = load_state(model_dir=model_dir,
                                                   model=model,
                                                   optimizer=optimizer,
                                                   logger=logger,
                                                   device=device)

    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    d_model = model.config.d_model if 'xlnet' in \
        args.model else model.config.n_embd

    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    world_size = int(os.environ.get('WORLD_SIZE', 1))

    train, valid, test = [(split, ceil(size / args.batch_size / world_size))
                          for split, size in datasets]

    # computing the sizes of the dataset splits
    train_dataset, num_train_steps = train
    valid_dataset, num_valid_steps = valid
    test_dataset, num_test_steps = test

    patience, skip, loss, accuracy = 0, 1, 0, 0

    set_lr_fn = partial(set_lr,
                        optimizer=optimizer,
                        schedule=args.schedule,
                        lr=args.lr,
                        warmup_steps=args.warmup_steps,
                        d_model=d_model)

    if master_process:
        # loading history for training logs
        history_path = join(model_dir, 'history.json')

        history = defaultdict(list)

        # NOTE the hardcoded values to keep track of
        # in the history
        metrics = ['loss', 'acc', 'ppl']
        headers = ['epoch'] + \
            ['train_' + m for m in metrics] + \
            ['valid_' + m for m in metrics]

        if exists(history_path):
            with open(history_path, 'r') as fh:
                history = json.load(fh)

    def print_results(results):
        """
        Prints the history to the standard output.
        """
        data = list(zip(*[history[h] for h in headers]))

        table = tabulate(tabular_data=data, headers=headers, floatfmt='.3f')

        # computing the tabular table string and
        # printing only the last element
        print(table.split('\n')[-1])

        msg = ', '.join('{}: {}'.format(n, r) for n, r in results.items())

        logger.info(msg)

    def record_history(results):
        """
        Records the results and prints them.
        """
        # saving history and handling unexpected
        # keyboard interrupt
        for header in headers:
            history[header].append(results[header])

        while True:
            try:
                with open(history_path, 'w') as fh:
                    json.dump(history, fh)
                break
            except KeyboardInterrupt:
                pass

    @contextmanager
    def skip_error():
        """
        Convenience function for skipping errors.
        """
        nonlocal skip

        try:
            # checking out of memory error and
            # proceeding if only a single GPU
            # is used for the training
            yield

        except RuntimeError as e:
            if 'out of memory' in str(e):
                if args.distributed:
                    raise e
                skip += 1

    def reduce_tensor(tensor):
        """
        Averages a tensor across gpus.
        """
        reduced = tensor.clone()
        all_reduce(reduced, op=ReduceOp.SUM)
        reduced /= world_size

        return reduced

    def forward_step(batch):
        """
        Applies forward pass with the given batch.
        """
        inputs, targets = batch

        outputs = model(inputs, half=args.fp16)

        # converting targets from ndarray
        targets = torch.as_tensor(targets)
        targets = targets.long().to(device)

        loss, acc, ppl = compute_loss(outputs=outputs,
                                      targets=targets,
                                      ignore_idx=pad_idx)

        if args.distributed:
            # reducing accuracy accross devices
            # for more accurate logging
            acc = reduce_tensor(acc)

        return loss, acc.item(), ppl

    def train_step(batch):
        """
        Performs a single step of training.
        """
        nonlocal step, skip

        loss, acc, ppl = forward_step(batch)

        if torch.isnan(loss).item():
            # during distributed training NaN
            # values are not handled
            if args.distributed:
                raise ValueError('NaN values encountered.')

            logger.debug('skipping step (nan)')
            # returning None values when a NaN loss
            # is encountered and skipping backprop
            # so model grads will not be corrupted

            skip += 1
            return {'loss': None, 'acc': None, 'ppl': None}

        loss /= args.grad_accum_steps

        backward(loss)

        if args.clip_grad is not None:
            clip_grad_norm(args.clip_grad)

        if step % args.grad_accum_steps == 0:
            set_lr_fn(step)
            optimizer.step()
            optimizer.zero_grad()

        if args.distributed:
            # reducing loss accross devices for
            # more accurate logging
            loss = reduce_tensor(loss)

        step += 1

        return {'loss': loss.item(), 'acc': acc, 'ppl': ppl}

    def backward(loss):
        """
        Backpropagates the loss in either mixed or
        normal precision mode.
        """
        # cuda is required for mixed precision training.
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled:
                scaled.backward()
        else:
            loss.backward()

    def clip_grad_norm(max_norm):
        """
        Applies gradient clipping.
        """
        if args.fp16:
            clip_grad_norm_(amp.master_params(optimizer), max_norm)
        else:
            clip_grad_norm_(model.parameters(), max_norm)

    def evaluate(dataset, num_steps):
        """
        Constructs a validation loader and evaluates
        the model.
        """
        loop = tqdm(dataset(),
                    desc='eval',
                    total=num_steps,
                    leave=False,
                    disable=not master_process)

        model.eval()

        for batch in loop:
            with skip_error():
                loss, accuracy, ppl = forward_step(batch)

                loop.set_postfix(
                    OrderedDict(loss=loss.item(), ppl=ppl, acc=accuracy))

                yield loss.item(), accuracy, ppl

    def save_state(name):
        """
        Saves the model and optimizer state.
        """
        model_path = join(model_dir, name + '.pt')

        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_valid_loss': best_valid_loss,
            'valid_loss': valid_loss,
            'epoch': epoch + 1,
            'step': step
        }

        logger.info('Saving model to {}'.format(model_path))
        # making sure the model saving is not left in a
        # corrupted state after a keyboard interrupt
        while True:
            try:
                torch.save(state, model_path)
                break
            except KeyboardInterrupt:
                pass

    if master_process:
        train_args = vars(args)
        logger.info(str(train_args))

        print()
        print(tabulate(train_args.items(), tablefmt='presto'))
        print()

    try:
        # initializing cuda buffer to avoid OOM errors
        dummy_batch = create_dummy_batch(args, ignore_idx=pad_idx)

        train_step(dummy_batch)

    except (RuntimeError, ValueError) as e:
        if 'out of memory' in str(e):
            msg = 'Not enough memory, there might ' + \
                'be several out of memory error during ' + \
                'training. To avoid this lower ' + \
                'the `--batch_size` or `--max_len`'

            if not args.grad_ckpt:
                msg += ', use the `--checkpointed` flag'

            if not APEX_INSTALLED:
                msg += ' or install apex for fp16 precision'

            logger.info(msg + '.')

        if args.distributed:
            return

    # creating table of history with correctly
    # arranged values for each header
    if master_process:
        table = list(zip(*[history[h] for h in headers]))
        print(tabulate(table, headers, floatfmt='.3f'))

    for epoch in range(init_epoch, args.max_epochs):
        # running training loop
        loop = tqdm(train_dataset(),
                    desc='train {}'.format(epoch),
                    total=num_train_steps,
                    leave=False,
                    disable=not master_process)

        train_metrics = defaultdict(list)

        model.train()

        for batch in loop:
            with skip_error():
                results = train_step(batch)

                loss = results['loss']
                if master_process and loss is not None:
                    # adding the results to history
                    # and logging them to tensorboard
                    for metric, value in results.items():
                        train_metrics[metric].append(value)

                        if value == float('inf'):
                            value = 1e30

                        writer.add_scalar('train/' + metric, value, step)

                loop.set_postfix(OrderedDict(**results, skip=skip))

        train_metrics = {
            'train_' + metric: mean(values) if len(values) > 0 else 0.0
            for metric, values in train_metrics.items()
        }

        with torch.no_grad():
            # materializing the metrics here so the forward passes
            # actually run inside the no_grad context
            valid_metrics = list(
                zip(*evaluate(dataset=valid_dataset,
                              num_steps=num_valid_steps)))

        valid_loss, valid_acc, valid_ppl = [
            mean(values) if len(values) > 0 else 0.0
            for values in valid_metrics
        ]

        # switching back to training
        model.train()

        if master_process:
            results = {'epoch': epoch}

            results.update(train_metrics)

            results.update({
                'valid_loss': valid_loss,
                'valid_acc': valid_acc,
                'valid_ppl': valid_ppl
            })

            record_history(results)
            print_results(results)

            # converting ppl to a large number so tensorboard
            # will not throw any warnings during training
            if valid_ppl == float('inf'):
                valid_ppl = 1e30

            # logging to tensorboard
            writer.add_scalar('val/loss', valid_loss, step)
            writer.add_scalar('val/acc', valid_acc, step)
            writer.add_scalar('val/ppl', valid_ppl, step)

        if master_process:
            save_state(name='last')

        if valid_loss < best_valid_loss:
            patience = 0
            best_valid_loss = valid_loss

            if master_process:
                save_state(name='best')

        else:
            patience += 1
            if patience == args.patience:
                # terminate when max patience
                # level is hit
                break

        if step == args.total_steps:
            break

    if master_process:
        writer.close()

    with torch.no_grad():
        # materializing the metrics here so the forward passes
        # actually run inside the no_grad context
        test_metrics = list(
            zip(*evaluate(dataset=test_dataset, num_steps=num_test_steps)))

    test_loss, test_acc, test_ppl = [
        mean(values) if len(values) > 0 else 0.0 for values in test_metrics
    ]

    if master_process:
        logger.info('test loss: {:.4}'.format(test_loss))
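
Both Example #2 and Example #6 guard checkpointing with a retry loop so that a KeyboardInterrupt during torch.save cannot leave a truncated file on disk. A minimal sketch of that interrupt-safe saving pattern, with placeholder state and path:

import torch

state = {'step': 0}        # placeholder checkpoint contents
model_path = 'model.pt'    # placeholder path

while True:
    try:
        torch.save(state, model_path)
        break
    except KeyboardInterrupt:
        # ignore the interrupt and retry so the file is never half-written
        pass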