Ejemplo n.º 1
0
def train(c):
    """Train a Transformer while a distiller compression schedule is active.

    Builds the model, optimizer and learning-rate schedule from ``c``,
    prints validation/test metrics plus total weight sparsity before and
    after the compression scheduler is attached, then runs the training
    loop until ``s.step_max``. Weights that are exactly zero when the loop
    starts are reset to zero after every optimizer step.

    Args:
        c: Experiment configuration object. This function reads
            ``train_batch``, ``eval_batch``, ``debug``, ``compress``,
            ``step_eval``, ``device`` and ``main`` from it and calls its
            ``init_model``, ``on_train_start``, ``on_step_end``,
            ``on_train_end``, ``get`` and ``log`` methods.

    Returns:
        Tuple ``(net, step)``: the model and the last step that was reached.

    Note:
        Any ``Exception`` raised inside the loop is logged and swallowed so
        that ``c.on_train_end`` always runs and the function still returns.
    """
    import distiller
    net = Transformer(c)

    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)

    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c,
                             c.train_batch,
                             split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')
    data_test = SequentialIterator(c, c.eval_batch, split='test')

    def evaluate_with_sparsity():
        """Return val_*/test_* metrics plus the model's total sparsity."""
        _, sparsity = distiller.weights_sparsity_tbl_summary(
            net, return_total_sparsity=True)
        # pd.concat replaces Series.append, which was removed in pandas 2.0.
        metrics = pd.concat([
            pd.Series(evaluate(c, data_val, net)).add_prefix('val_'),
            pd.Series(evaluate(c, data_test, net)).add_prefix('test_'),
        ])
        metrics['sparsity'] = sparsity
        return metrics

    print('Before quantization')
    print(evaluate_with_sparsity())

    compression_scheduler = distiller.config.file_config(net, opt, c.compress)

    print('After initial quantization')
    s = Namespace(net=net, opt=opt, step=step)
    # ``s.results`` and ``s.step_max`` are read below but not set here;
    # presumably on_train_start populates them -- TODO confirm.
    c.on_train_start(s)

    print(evaluate_with_sparsity())

    # Remember which entries of the 2-D / 4-D weight and bias tensors are
    # exactly zero right now so they can be forced back to zero after every
    # optimizer update.
    zero_masks = [
        (param, param.abs() == 0)
        for name, param in net.named_parameters()
        if param.dim() in (2, 4) and ('weight' in name or 'bias' in name)
    ]

    # Lower loss is better, so the best previous value is the minimum.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        past_val_losses = s.results['val_loss'].dropna()
        if not past_val_losses.empty:
            best_val_loss = past_val_losses.min()
    try:
        steps_per_epoch = c.step_eval
        while step < s.step_max:
            epoch = step // steps_per_epoch
            batch = step % steps_per_epoch

            if batch == 0:
                compression_scheduler.on_epoch_begin(epoch)
            compression_scheduler.on_minibatch_begin(epoch, batch,
                                                     steps_per_epoch)

            step_lr(step)

            x = to_torch(next(iter_tr), c.device).t()

            t_s = time()
            # Next-token setup: targets are the inputs shifted by one
            # position along the first axis.
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']

            # NOTE(review): the scheduler's return value is discarded, so
            # any extra loss terms it produces are not back-propagated --
            # confirm this is intended.
            compression_scheduler.before_backward_pass(epoch, batch,
                                                       steps_per_epoch, loss,
                                                       False)

            opt.zero_grad()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           c.get('clip_grad', 0.5))

            compression_scheduler.before_parameter_optimization(
                epoch, batch, steps_per_epoch, opt)
            opt.step()
            # Re-apply the zero pattern captured before the loop.
            for param, mask in zero_masks:
                param.data[mask] = 0
            compression_scheduler.on_minibatch_end(epoch, batch,
                                                   steps_per_epoch)

            if (batch + 1) == steps_per_epoch:
                compression_scheduler.on_epoch_end(epoch)

            time_model = np.round(time() - t_s, 5)

            loss = from_torch(loss)
            # Perplexity is only reported for losses of at most 5.
            perplexity = np.nan if loss > 5 else np.e**loss
            step_result = pd.Series(
                dict(
                    loss=loss,
                    perplexity=perplexity,
                    time=time_model,
                )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']

            s.step = step = step + 1
            if step % c.step_eval == 0:
                step_result = pd.concat(
                    [step_result, evaluate_with_sparsity()])
                val_loss = step_result['val_loss']
                improved = val_loss < best_val_loss
                s.record_step = improved
                if improved:
                    # Keep the running best so later evaluations are only
                    # flagged when they actually improve on it.
                    best_val_loss = val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        # Best effort: report the failure but still run on_train_end below.
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
    return net, step
Ejemplo n.º 2
0
def train(c):
    """Train the model named by ``c.model``, optionally with distillation.

    Builds the model, optimizer and learning-rate schedule from ``c`` and
    runs the training loop until ``s.step_max``. When ``c.get('distill')``
    is truthy, batches come from a ``DistillationSampleIterator`` and the
    model receives hard labels, soft labels and soft probabilities;
    otherwise plain token batches from ``SampleIterator`` are used. The
    validation split is evaluated every ``c.step_eval`` steps.

    Args:
        c: Experiment configuration object. This function reads ``model``,
            ``train_batch``, ``eval_batch``, ``debug``, ``hebbian``,
            ``cutoffs``, ``n_vocab``, ``opt_level``, ``step_eval``,
            ``device`` and ``main`` from it and calls its ``setdefault``,
            ``get``, ``init_model``, ``log``, ``on_train_start``,
            ``on_step_end`` and ``on_train_end`` methods.

    Returns:
        None.

    Note:
        Any ``Exception`` raised inside the loop (including the NaN-loss
        ``RuntimeError``) is logged and swallowed so that
        ``c.on_train_end`` always runs.
    """
    c.setdefault(hebbian=False)
    # NOTE(security): eval() runs arbitrary code taken from the config.
    # Only use configs from trusted sources; a name -> class registry would
    # be a safer lookup.
    net = eval(c.model)(c)

    emb_params = count_params(net.embed) + count_params(
        net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)

    # Read the mode once: only one of the two iterators is created, so the
    # loop has to follow the same decision on every step.
    use_distill = c.get('distill')
    if use_distill:
        data_tr_distill = DistillationSampleIterator(c, c.train_batch)
        iter_tr_distill = iter(data_tr_distill)
    else:
        data_tr = SampleIterator(c,
                                 c.train_batch,
                                 split='valid' if c.debug else 'train')
        iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    # ``s.results`` and ``s.step_max`` are read below but not set here;
    # presumably on_train_start populates them -- TODO confirm.
    c.on_train_start(s)

    c.log('Embedding has %s parameters' % emb_params)

    if c.hebbian:
        # One counter tensor per vocabulary segment delimited by c.cutoffs.
        counters = [
            torch.ones(end - start, dtype=torch.long, device=c.device)
            for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])
        ]
        temp_counters = [torch.zeros_like(x) for x in counters]

    # Lower loss is better, so the best previous value is the minimum.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        past_val_losses = s.results['val_loss'].dropna()
        if not past_val_losses.empty:
            best_val_loss = past_val_losses.min()
    try:
        while step < s.step_max:
            step_lr(step)
            t_s = time()

            if use_distill:
                hard_labels, soft_labels, soft_probs = next(iter_tr_distill)
                hard_labels = to_torch(hard_labels, c.device).t()

                # Swap the first two axes and drop the first position so the
                # soft targets line up with the shifted hard labels.
                soft_labels = to_torch(soft_labels,
                                       c.device).permute(1, 0, 2)[1:]
                soft_probs = to_torch(soft_probs,
                                      c.device).permute(1, 0, 2)[1:]

                inputs, hard_labels = hard_labels[:-1], hard_labels[1:]
                preds = net(inputs=inputs,
                            labels=hard_labels,
                            soft_labels=soft_labels,
                            soft_probs=soft_probs,
                            current_step=step)
            else:
                x = to_torch(next(iter_tr), c.device).t()
                # Targets are the inputs shifted by one position.
                inputs, labels = x[:-1], x[1:]
                preds = net(inputs, labels)
            loss = preds['loss']

            opt.zero_grad()
            if torch.isnan(loss):
                raise RuntimeError('Encountered nan loss during training')
            if c.opt_level == 'O0':
                loss.backward()
            else:
                # Any other opt level goes through amp loss scaling.
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           c.get('clip_grad', 0.5))
            opt.step()

            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters,
                                      temp_counters)

            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            # Perplexity is only reported for losses of at most 5.
            perplexity = np.nan if loss > 5 else np.e**loss
            step_result = pd.Series(
                dict(
                    loss=loss,
                    perplexity=perplexity,
                    time=time_model,
                )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            if c.get('use_cache'):
                step_result['theta'] = from_torch(preds['theta'])
                step_result['lambda'] = from_torch(preds['lambda'])

            s.step = step = step + 1
            if step % c.step_eval == 0:
                # pd.concat replaces Series.append, which was removed in
                # pandas 2.0.
                step_result = pd.concat([
                    step_result,
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_'),
                ])
                val_loss = step_result['val_loss']
                improved = val_loss < best_val_loss
                s.record_step = improved
                if improved:
                    # Keep the running best so later evaluations are only
                    # flagged when they actually improve on it.
                    best_val_loss = val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        # Best effort: report the failure but still run on_train_end below.
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
Ejemplo n.º 3
0
def train(c, net, compression_scheduler=None):
    """Train ``net``, optionally driven by a distiller compression scheduler.

    Sets up loggers, optimizer, learning-rate schedule and data iterators,
    then runs the training loop until ``s.step_max``. When
    ``compression_scheduler`` is given, its epoch and minibatch callbacks
    are invoked around each step. Every ``c.step_eval`` steps the weight
    sparsity is logged and the validation split is evaluated.

    Args:
        c: Experiment configuration object. This function reads
            ``train_batch``, ``eval_batch``, ``debug``, ``hebbian``,
            ``cutoffs``, ``n_vocab``, ``train_chunk``, ``step_eval``,
            ``device`` and ``main`` from it and calls its ``setdefault``,
            ``get``, ``init_model``, ``log``, ``on_train_start``,
            ``on_step_end`` and ``on_train_end`` methods.
        net: Model to train; it must expose ``embed``, ``loss.projections``
            and ``loss.clusters`` for the parameter count.
        compression_scheduler: Optional scheduler object whose
            ``on_epoch_begin``, ``on_minibatch_begin``,
            ``before_backward_pass``, ``on_minibatch_end`` and
            ``on_epoch_end`` callbacks are called. Skipped when falsy.

    Returns:
        None.

    Note:
        Any ``Exception`` raised inside the loop (including the NaN-loss
        ``RuntimeError``) is logged and swallowed so that
        ``c.on_train_end`` always runs.
    """
    # ``import distiller.apputils as apputils`` binds only ``apputils``, so
    # the top-level package is imported explicitly for the distiller.* calls
    # made in the evaluation branch below.
    import distiller
    import distiller.apputils as apputils
    from distiller.data_loggers import TensorBoardLogger, PythonLogger
    msglogger = apputils.config_pylogger('logging.conf', None)
    tflogger = TensorBoardLogger(msglogger.logdir)
    tflogger.log_gradients = True
    pylogger = PythonLogger(msglogger)
    c.setdefault(hebbian=False)

    emb_params = (count_params(net.embed)
                  + count_params(net.loss.projections)
                  + count_params(net.loss.clusters))
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c, c.train_batch,
                             split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    # ``s.results`` and ``s.step_max`` are read below but not set here;
    # presumably on_train_start populates them -- TODO confirm.
    c.on_train_start(s)

    c.log('Embedding has %s parameters' % emb_params)

    # An explicit config value wins; otherwise derive the epoch length from
    # the token count, the iterator's batch size and the chunk length.
    if c.get("steps_per_epoch"):
        steps_per_epoch = c.steps_per_epoch
    else:
        steps_per_epoch = len(data_tr.tokens) // data_tr.bs // c.train_chunk
    print("#### steps per epoch %d ####" % steps_per_epoch)

    if c.hebbian:
        # One counter tensor per vocabulary segment delimited by c.cutoffs.
        counters = [
            torch.ones(end - start, dtype=torch.long, device=c.device)
            for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])
        ]
        temp_counters = [torch.zeros_like(x) for x in counters]

    # Lower loss is better, so the best previous value is the minimum.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        past_val_losses = s.results['val_loss'].dropna()
        if not past_val_losses.empty:
            best_val_loss = past_val_losses.min()
    try:
        while step < s.step_max:
            batch = step % steps_per_epoch
            epoch = step // steps_per_epoch
            if batch == 0:
                c.log("====> batch=%d, epoch=%d, step=%d" % (batch, epoch, step))
                if compression_scheduler:
                    compression_scheduler.on_epoch_begin(epoch)

            if compression_scheduler:
                compression_scheduler.on_minibatch_begin(
                    epoch, minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch)

            step_lr(step)

            x = to_torch(next(iter_tr), c.device).t()

            t_s = time()
            # Targets are the inputs shifted by one position.
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']

            if compression_scheduler:
                # The return value is deliberately discarded, as in the
                # original code; the unmodified loss is back-propagated.
                _ = compression_scheduler.before_backward_pass(
                    epoch, minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch,
                    loss=loss, return_loss_components=False)

            opt.zero_grad()
            if torch.isnan(loss):
                raise RuntimeError('Encountered nan loss during training')
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           c.get('clip_grad', 0.5))
            opt.step()

            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters,
                                      temp_counters)

            time_model = np.round(time() - t_s, 5)

            loss = from_torch(loss)
            # Perplexity is only reported for losses of at most 5.
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            # .get() tolerates configs that never define use_cache, matching
            # how the other optional keys are read in this function.
            if c.get('use_cache'):
                step_result['theta'] = preds['theta']
                step_result['lambda'] = preds['lambda'].item()

            if compression_scheduler:
                compression_scheduler.on_minibatch_end(
                    epoch, minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch)
                # Close the epoch after its LAST minibatch. The previous
                # ``step % steps_per_epoch == 0`` test fired on the first
                # minibatch, right after on_epoch_begin.
                if batch + 1 == steps_per_epoch:
                    compression_scheduler.on_epoch_end(epoch)

            s.step = step = step + 1
            if step % c.step_eval == 0:
                distiller.log_weights_sparsity(net, epoch,
                                               loggers=[tflogger, pylogger])
                _, total = distiller.weights_sparsity_tbl_summary(
                    net, return_total_sparsity=True)
                c.log("total sparsity: %.3lf" % total)

                # pd.concat replaces Series.append, which was removed in
                # pandas 2.0.
                step_result = pd.concat([
                    step_result,
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_'),
                ])
                val_loss = step_result['val_loss']
                improved = val_loss < best_val_loss
                s.record_step = improved
                if improved:
                    # Keep the running best so later evaluations are only
                    # flagged when they actually improve on it.
                    best_val_loss = val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        # Best effort: report the failure but still run on_train_end below.
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)