Example #1: `_save`, a checkpoint hook on the trainer class. It builds a descriptive file name under `models/` from the dataset name, model size, epoch count, and seed (or a fixed column ordering), saves the model's `state_dict`, and uploads the file to wandb. The snippet assumes module-level `import time`, `import torch`, `import wandb`, the project-local `utils` module, and trainer attributes such as `self.model`, `self.dataset`, and `self.epoch`.

    def _save(self, tmp_checkpoint_dir):
        if self.checkpoint_to_load or not self.save_checkpoint_at_end:
            return {}

        # NOTE: see comment at ReportModel() call site about model size.
        if self.fixed_ordering is None:
            if self.seed is not None:
                PATH = 'models/{}-{:.1f}MB-model{:.3f}-{}-{}epochs-seed{}.pt'.format(
                    self.dataset, self.mb, self.model.model_bits,
                    self.model.name(), self.epoch, self.seed)
            else:
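                # No fixed seed given: append a timestamp so repeated runs
                # do not overwrite each other's checkpoints.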
                PATH = 'models/{}-{:.1f}MB-model{:.3f}-{}-{}epochs-seed{}-{}.pt'.format(
                    self.dataset, self.mb, self.model.model_bits,
                    self.model.name(), self.epoch, self.seed, time.time())
        else:
            PATH = 'models/{}-{:.1f}MB-model{:.3f}-{}-{}epochs-seed{}-order{}.pt'.format(
                self.dataset, self.mb, self.model.model_bits,
                self.model.name(), self.epoch, self.seed,
                str(self.order_seed) if self.order_seed is not None else
                '_'.join(map(str, self.fixed_ordering))[:60])

        if self.dataset == 'tpcds':
            tuples_seen = self.bs * self.max_steps * self.epochs
            PATH = PATH.replace(
                '-seed', '-{}tups-seed'.format(utils.HumanFormat(tuples_seen)))

            if len(self.join_tables) == 1:
                PATH = PATH.replace('tpcds',
                                    'indep-{}'.format(self.join_tables[0]))

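        # Note: torch.save() does not create parent directories, so the
        # 'models/' directory is assumed to already exist.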
        torch.save(self.model.state_dict(), PATH)
        wandb.save(PATH)
        print('Saved to:', PATH)
        return {'path': PATH}
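
The checkpoint written above is a plain `state_dict` file, so it can be restored with `torch.load` / `load_state_dict`. A minimal reload sketch, assuming a compatible model instance constructed the same way as during training (the helper name `load_checkpoint` is illustrative, not part of the codebase):

import torch

def load_checkpoint(model, path, device='cpu'):
    # Restore weights saved by _save() above; strict=True surfaces any
    # architecture mismatch between the checkpoint and `model`.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model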
Example #2: `run_epoch`, which runs a single training or evaluation epoch over `train_data` / `val_data`, applies the learning-rate schedule, optionally averages the NLL over several column orderings at test time, and returns either the list of per-batch losses or their mean (in nats). The snippet relies on the module-level imports below; `train_utils` and `utils` are project-local helper modules, and `TotalGradNorm` is assumed to be defined elsewhere in the same module.

import time

import numpy as np
import torch
import wandb
from torch.utils import data

import train_utils  # project-local: AverageMeter, ProgressMeter, get_device
import utils        # project-local: HumanFormat

def run_epoch(split,
              model,
              opt,
              train_data,
              val_data=None,
              batch_size=100,
              upto=None,
              epoch_num=None,
              epochs=1,
              verbose=False,
              log_every=10,
              return_losses=False,
              table_bits=None,
              warmups=1000,
              loader=None,
              constant_lr=None,
              use_meters=True,
              summary_writer=None,
              lr_scheduler=None,
              custom_lr_lambda=None,
              label_smoothing=0.0):
    torch.set_grad_enabled(split == 'train')
    if split == 'train':
        model.train()
    else:
        model.eval()
    dataset = train_data if split == 'train' else val_data
    losses = []

    if loader is None:
        loader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=(split == 'train'))

    # How many orderings to run for the same batch?
    nsamples = 1
    if hasattr(model, 'orderings'):
        nsamples = len(model.orderings)
        if verbose:
            print('setting nsamples to', nsamples)

    dur_meter = train_utils.AverageMeter('dur',
                                         lambda v: '{:.0f}s'.format(v),
                                         display_average=False)
    lr_meter = train_utils.AverageMeter('lr', ':.5f', display_average=False)
    tups_meter = train_utils.AverageMeter('tups',
                                          utils.HumanFormat,
                                          display_average=False)
    loss_meter = train_utils.AverageMeter('loss (bits/tup)', ':.2f')
    train_throughput = train_utils.AverageMeter('tups/s',
                                                utils.HumanFormat,
                                                display_average=False)
    batch_time = train_utils.AverageMeter('sgd_ms', ':3.1f')
    data_time = train_utils.AverageMeter('data_ms', ':3.1f')
    progress = train_utils.ProgressMeter(upto, [
        batch_time,
        data_time,
        dur_meter,
        lr_meter,
        tups_meter,
        train_throughput,
        loss_meter,
    ])

    begin_time = t1 = time.time()

    for step, xb in enumerate(loader):
        data_time.update((time.time() - t1) * 1e3)

        if split == 'train':
            if isinstance(dataset, data.IterableDataset):
                # Can't call len(loader).
                global_steps = upto * epoch_num + step + 1
            else:
                global_steps = len(loader) * epoch_num + step + 1

            if constant_lr:
                lr = constant_lr
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            elif custom_lr_lambda:
                lr_scheduler = None
                lr = custom_lr_lambda(global_steps)
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            elif lr_scheduler is None:
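                # Default: Transformer-style inverse-square-root warmup
                # schedule (a standalone sketch follows after this function).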
                t = warmups
                if warmups < 1:  # A ratio.
                    t = int(warmups * upto * epochs)

                d_model = model.embed_size
                lr = (d_model**-0.5) * min(
                    (global_steps**-.5), global_steps * (t**-1.5))
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            else:
                # We'll call lr_scheduler.step() below.
                lr = opt.param_groups[0]['lr']

        if upto and step >= upto:
            break

        if isinstance(xb, list):
            # This happens if using data.TensorDataset.
            assert len(xb) == 1, xb
            xb = xb[0]

        xb = xb.float().to(train_utils.get_device(), non_blocking=True)

        # Forward pass, potentially through several orderings.
        xbhat = None
        model_logits = []
        num_orders_to_forward = 1
        if split == 'test' and nsamples > 1:
            # At test, we want to test the 'true' nll under all orderings.
            num_orders_to_forward = nsamples

        for i in range(num_orders_to_forward):
            if hasattr(model, 'update_masks'):
                # We want to update_masks even for first ever batch.
                model.update_masks()

            model_out = model(xb)
            model_logits.append(model_out)
            if xbhat is None:
                xbhat = torch.zeros_like(model_out)
            xbhat += model_out

        if num_orders_to_forward == 1:
            loss = model.nll(xbhat, xb, label_smoothing=label_smoothing).mean()
        else:
            # Average across orderings & then across minibatch.
            #
            #   p(x) = 1/N sum_i p_i(x)
            #   log(p(x)) = log(1/N) + log(sum_i p_i(x))
            #             = log(1/N) + logsumexp ( log p_i(x) )
            #             = log(1/N) + logsumexp ( - nll_i (x) )
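            #
            # Quick numeric check with N = 2: if p_1(x) = 0.5 and
            # p_2(x) = 0.25, then p(x) = 0.375, and indeed
            # log(0.375) = logsumexp([log 0.5, log 0.25]) + log(1/2)
            #            = log(0.75) - log(2).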
            #
            # Used only at test time.
            logps = []  # [batch size, num orders]
            assert len(model_logits) == num_orders_to_forward, len(
                model_logits)
            for logits in model_logits:
                # Note the minus.
                logps.append(
                    -model.nll(logits, xb, label_smoothing=label_smoothing))
            logps = torch.stack(logps, dim=1)
            logps = logps.logsumexp(dim=1) + torch.log(
                torch.tensor(1.0 / nsamples, device=logps.device))
            loss = (-logps).mean()

        losses.append(loss.detach().item())

        if split == 'train':
            opt.zero_grad()
            loss.backward()
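            # L2 norm over all parameter gradients; logged to wandb below
            # (when a summary_writer is set) to help monitor training.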
            l2_grad_norm = TotalGradNorm(model.parameters())

            opt.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            loss_bits = loss.item() / np.log(2)

            # Number of tuples processed in this epoch so far.
            ntuples = (step + 1) * batch_size
            # Elapsed time is computed unconditionally: the summary_writer
            # block below also uses `dur`, even when use_meters is False.
            dur = time.time() - begin_time
            if use_meters:
                lr_meter.update(lr)
                tups_meter.update(ntuples)
                loss_meter.update(loss_bits)
                dur_meter.update(dur)
                train_throughput.update(ntuples / dur)

            if summary_writer is not None:
                wandb.log({
                    'train/lr': lr,
                    'train/tups': ntuples,
                    'train/tups_per_sec': ntuples / dur,
                    'train/nll': loss_bits,
                    'train/global_step': global_steps,
                    'train/l2_grad_norm': l2_grad_norm,
                })
                summary_writer.add_scalar('train/lr',
                                          lr,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/tups',
                                          ntuples,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/tups_per_sec',
                                          ntuples / dur,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/nll',
                                          loss_bits,
                                          global_step=global_steps)

            if step % log_every == 0:
                if table_bits:
                    print(
                        'Epoch {} Iter {}, {} entropy gap {:.4f} bits (loss {:.3f}, data {:.3f}) {:.5f} lr, {} tuples seen ({} tup/s)'
                        .format(
                            epoch_num, step, split,
                            loss.item() / np.log(2) - table_bits,
                            loss.item() / np.log(2), table_bits, lr,
                            utils.HumanFormat(ntuples),
                            utils.HumanFormat(ntuples /
                                              (time.time() - begin_time))))
                elif not use_meters:
                    print(
                        'Epoch {} Iter {}, {} loss {:.3f} bits/tuple, {:.5f} lr'
                        .format(epoch_num, step, split,
                                loss.item() / np.log(2), lr))

        if verbose:
            print('%s epoch average loss: %f' % (split, np.mean(losses)))

        batch_time.update((time.time() - t1) * 1e3)
        t1 = time.time()
        if split == 'train' and step % log_every == 0 and use_meters:
            progress.display(step)

    if return_losses:
        return losses
    return np.mean(losses)
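
The `elif lr_scheduler is None` branch above implements the Transformer-style inverse-square-root warmup: lr = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5). A self-contained sketch of that formula, useful for sanity-checking warmup settings before a run (the function name is illustrative, not part of the codebase):

def inverse_sqrt_warmup_lr(global_step, d_model, warmup_steps):
    # Rises linearly for `warmup_steps` steps, peaks at step == warmup_steps,
    # then decays as 1 / sqrt(step).
    step = max(global_step, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

# Example: with d_model=64 and warmup_steps=1000, the peak learning rate is
# 64 ** -0.5 * 1000 ** -0.5 ≈ 3.95e-3.
print(inverse_sqrt_warmup_lr(1000, d_model=64, warmup_steps=1000))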