예제 #1
0
def evaluate(c, data, net):
    """Evaluate ``net`` on ``data`` and return loss, perplexity and elapsed time.

    Batches are processed in iteration order, and the ``'state'`` returned by
    each forward pass is passed as ``prevs`` to the next one, so the order of
    ``data`` matters. Per-batch losses are averaged with weights proportional
    to ``labels.size(0)``. Under ``c.distributed`` the per-rank losses are
    all-gathered and averaged with equal weight per rank.

    The model is put in eval mode for the pass. If it was in training mode
    beforehand, it is switched back even when evaluation raises.

    Args:
        c: run config; reads ``c.device``, ``c.distributed`` and, when
            distributed, ``c.world_size``.
        data: iterable of batches accepted by ``to_torch``.
        net: model whose ``forward(inputs, labels, prevs=...)`` returns a
            mapping with ``'loss'`` and ``'state'`` entries.

    Returns:
        dict with ``loss`` (converted with ``from_torch``), ``perplexity``
        (``np.nan`` when loss > 5, else ``e ** loss``) and ``time`` (elapsed
        wall-clock seconds, rounded to a whole number).
    """
    clear_gpu_memory()
    was_training = net.training
    net.eval()

    t_s = time()
    try:
        with torch.no_grad():
            weights = []
            losses = []
            prevs = None

            for batch in data:
                # .t() transposes the 2-D batch tensor; the model is fed
                # x[:-1] and asked to predict x[1:] along dim 0.
                x = to_torch(batch, c.device).t()
                inputs, labels = x[:-1], x[1:]

                preds = net.forward(inputs, labels, prevs=prevs)
                # Carry the returned state into the next batch.
                prevs = preds['state']
                losses.append(preds['loss'])
                weights.append(labels.size(0))
            # NOTE(review): if ``data`` yields no batches, ``loss`` below is
            # the int 0 and the later torch/from_torch calls will likely fail.
            weights = np.array(weights)
            weights = weights / weights.sum()
            loss = sum(batch_loss * w for batch_loss, w in zip(losses, weights))

        if c.distributed:
            # Equal-weight mean of every rank's loss.
            gathered_losses = [torch.zeros_like(loss) for _ in range(c.world_size)]
            torch.distributed.all_gather(gathered_losses, loss)
            loss = sum(gathered_losses) / len(gathered_losses)
    finally:
        # Restore the caller's mode even if evaluation failed part-way.
        if was_training:
            net.train()

    loss = from_torch(loss)
    perplexity = np.nan if loss > 5 else np.e ** loss
    return dict(loss=loss, perplexity=perplexity, time=np.round(time() - t_s))
예제 #2
0
def profile(c):
    """Print FLOPS and parameter-count figures for the latest checkpoint, then exit.

    Loads the model at step ``"max"`` via ``config.var(device="cpu")``.
    Runs ``profile_macs`` on the first batch of the test split. Prints
    ``macs / config.eval_chunk * 2`` as FLOPS and ``count_params(net)`` as
    the model size, then terminates the interpreter with ``exit(0)``. If the
    test iterator yields nothing, it returns ``None`` without printing either
    figure.

    NOTE(review): this mixes the ``c`` argument with a ``config`` name that
    is resolved outside the function. It also moves inputs to ``c.device``
    while the model is loaded with ``device="cpu"``. Confirm both refer to
    the same configuration and device.
    """
    c.setdefault(hebbian=False, distributed=False)
    net, _ = config.var(device="cpu").load_model("max")
    print('profiling')
    net.eval()

    test_batches = SequentialIterator(config, config.eval_batch, split="test")
    no_batch = object()
    with torch.no_grad():
        first_batch = next(iter(test_batches), no_batch)
        if first_batch is no_batch:
            return

        seq = to_torch(first_batch, c.device).t()
        # Next-token setup: inputs are seq[:-1], targets are seq[1:].
        n_macs = profile_macs(net, (seq[:-1], seq[1:]))
        print("==> FLOPS: ", n_macs / config.eval_chunk * 2)

        print("==> Models size: ", count_params(net))

        exit(0)
예제 #3
0
def gen_soft_labels(c):
    """Save the model's top-k predictions on the training split as ``.npy`` files.

    Loads the checkpoint at step ``'max'`` through ``c.init_model``. Runs the
    model over ``SequentialIteratorGenSoft(c, c.get('gen_soft_batch'),
    split='train')`` and keeps the ``c.get('topk')`` largest entries of
    ``probs`` along dim 1 for every row. The stacked indices and values are
    written to ``<file_out_path>all_soft_indices<worker>.npy`` and
    ``<file_out_path>all_soft_values<worker>.npy``.

    As a sanity check, the index file is read back and the fraction of
    positions ``k`` whose token ``data_gen_tr.tokens[k]`` appears in row ``k``
    is printed.

    Raises:
        ValueError: if the training iterator yields no batches.
    """
    c.setdefault(hebbian=False, distributed=False)
    net = get_net(c)
    opt = get_opt(c, net)
    net, opt, _step = c.init_model(net, opt=opt, step='max', train=True)

    print('generating soft labels...')
    data_gen_tr = SequentialIteratorGenSoft(c,
                                            c.get('gen_soft_batch'),
                                            split='train')
    topk = c.get('topk')
    out_prefix = c.get('file_out_path')
    worker = str(c.get('worker'))
    indices_path = out_prefix + 'all_soft_indices' + worker + '.npy'
    values_path = out_prefix + 'all_soft_values' + worker + '.npy'

    clear_gpu_memory()
    net.eval()
    with torch.no_grad():
        # Collect per-batch arrays and concatenate once at the end.
        # Concatenating inside the loop re-copies everything accumulated so
        # far on every batch, which makes the pass quadratic.
        index_chunks = []
        value_chunks = []
        for batch in tqdm(data_gen_tr):
            x = to_torch(batch, c.device).t()
            inputs, labels = x[:-1], x[1:]
            probs, _ = net(inputs, labels)

            values, indices = torch.topk(probs, topk, dim=1)
            indices_ = indices.cpu().numpy()
            values_ = values.cpu().numpy()

            # When the model returns more rows than there are input
            # positions, keep only the trailing rows (presumably the extra
            # rows score carried-over context; confirm).
            if probs.size(0) != inputs.size(0):
                n_rows = inputs.size(0)
                indices_ = indices_[-n_rows:, :]
                values_ = values_[-n_rows:, :]

            index_chunks.append(indices_)
            value_chunks.append(values_)

        if not index_chunks:
            raise ValueError('gen_soft_labels: training iterator yielded no batches')

        all_soft_indices = np.concatenate(index_chunks, axis=0)
        all_soft_values = np.concatenate(value_chunks, axis=0)

        # Duplicate row 0 at the front. The check below compares tokens[k]
        # with row k, and each row predicts labels = x[1:]. So this appears
        # to shift predictions into line with the tokens they predict; the
        # first token has no prediction of its own. Confirm against the
        # consumer of these files.
        all_soft_indices = np.concatenate(
            (all_soft_indices[0:1, :], all_soft_indices), axis=0)
        all_soft_values = np.concatenate(
            (all_soft_values[0:1, :], all_soft_values), axis=0)
        np.save(indices_path, all_soft_indices)
        np.save(values_path, all_soft_values)

        # Read the saved file back and report how often the true token is
        # among the stored top-k indices at its position.
        in_indices = np.load(indices_path)
        tokens = data_gen_tr.tokens
        cnt = 0.
        for k in range(len(tokens)):
            if tokens[k] in in_indices[k]:
                cnt += 1
        print(cnt / len(tokens))
예제 #4
0
def train(c):
    """Train the model described by ``c`` until ``s.step_max`` steps.

    Each step does the following:
    - draws a batch and runs forward and backward through ``amp.scale_loss``;
    - clips gradients to ``c.get('clip_grad', 0.5)`` and steps the optimizer;
    - records train loss, perplexity, time, learning rate, ``theta`` and
      ``lambda`` in ``s.step_result``.

    When ``c.get('distillation_teacher') == 'file'``, batches come from a
    ``DistillationSampleIterator`` that also yields soft labels and soft
    probabilities. These are forwarded to the model with
    ``is_distilling=True``.

    After every step that makes the step counter a multiple of
    ``c.step_eval``, the model is evaluated on the validation split.
    ``s.record_step`` is then set to whether the validation loss is below
    the lowest seen so far. That baseline is seeded from
    ``s.results['val_loss']`` when resuming.

    Any ``Exception`` raised inside the loop is logged (``c.log`` on the main
    process, ``print`` otherwise) rather than propagated.
    ``c.on_train_end(s)`` is always called.

    Args:
        c: run config / experiment object. Provides the data, model and
            optimizer settings and the ``on_train_start`` / ``on_step_end`` /
            ``on_train_end`` / ``log`` hooks.
    """
    c.setdefault(hebbian=False)
    net = get_net(c)

    emb_params = count_params(net.embed) + count_params(net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)

    distill_from_file = c.get('distillation_teacher') == 'file'
    if distill_from_file:
        data_tr_distill = DistillationSampleIterator(c, c.train_batch, split='train')
        iter_tr_distill = iter(data_tr_distill)

    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)

    c.log('Embedding has %s parameters' % emb_params)

    if c.hebbian:
        counters = [torch.ones(end - start, dtype=torch.long, device=c.device) for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])]
        temp_counters = [torch.zeros_like(x) for x in counters]

    # Lowest validation loss seen so far. When resuming, seed it from the
    # recorded history; the best loss is the minimum, not the maximum.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        prior_val_losses = s.results['val_loss'].dropna()
        if len(prior_val_losses):
            best_val_loss = prior_val_losses.min()
    try:
        while step < s.step_max:
            step_lr(step)

            if distill_from_file:
                x_hard_labels, x_soft_labels, x_soft_probs = next(iter_tr_distill)

                x_hard_labels = to_torch(x_hard_labels, c.device).t()

                x_soft_labels = to_torch(x_soft_labels, c.device)
                x_soft_labels = x_soft_labels.permute(1, 0, 2)

                x_soft_probs = to_torch(x_soft_probs, c.device)
                x_soft_probs = x_soft_probs.permute(1, 0, 2)

                inputs, hard_labels = x_hard_labels[:-1], x_hard_labels[1:]
                soft_labels = x_soft_labels[1:]
                soft_probs = x_soft_probs[1:]

                t_s = time()

                preds = net(inputs=inputs, labels=hard_labels, soft_labels=soft_labels, soft_probs=soft_probs,
                            is_distilling=True, current_step=step)
                loss = preds['loss']
                total_loss = loss
                extras = {}

            else:
                x = to_torch(next(iter_tr), c.device).t()

                t_s = time()
                inputs, labels = x[:-1], x[1:]
                preds = net(inputs, labels)
                loss = preds['loss']
                if c.model_class == 'UniversalTransformer':
                    act_loss = preds['act_loss']
                    total_loss = act_loss + loss
                    extras = dict(act_loss=from_torch(act_loss), n_updates=from_torch(preds['n_updates'].mean()))
                else:
                    total_loss = loss
                    extras = {}

            opt.zero_grad()
            if torch.isnan(total_loss):
                raise RuntimeError('Encountered nan loss during training')
            with amp.scale_loss(total_loss, opt) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            opt.step()

            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters, temp_counters)

            time_model = np.round(time() - t_s, 5)

            loss = from_torch(loss)
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(Dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model,
                **extras
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            step_result['theta'] = from_torch(preds['theta'])
            step_result['lambda'] = from_torch(preds['lambda'])

            s.step = step = step + 1
            if step % c.step_eval == 0:
                # pd.concat replaces Series.append, which pandas 2.0 removed.
                val_result = pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
                step_result = pd.concat([step_result, val_result])
                val_loss = step_result['val_loss']
                s.record_step = val_loss < best_val_loss
                if s.record_step:
                    # Track the new best so later evals are compared to it.
                    best_val_loss = val_loss
                clear_gpu_memory()
            # NOTE(review): s.record_step is only assigned on eval steps, so
            # it keeps its last value on the steps in between. Confirm that
            # on_step_end handles that.
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
예제 #5
0
def train(c, net, compression_scheduler=None):
    """Train ``net`` for up to ``s.step_max`` steps, driving an optional compression scheduler.

    ``compression_scheduler`` is presumably a Distiller ``CompressionScheduler``;
    the hook names used here match its API. It receives these calls:
    - ``on_epoch_begin`` at the first minibatch of each epoch;
    - ``on_minibatch_begin``, ``before_backward_pass`` and
      ``on_minibatch_end`` around every step;
    - ``on_epoch_end`` after the last minibatch of each epoch.

    The loss returned by ``before_backward_pass`` is the one that is
    back-propagated, so policy regularization terms take effect.

    An epoch is ``c.steps_per_epoch`` steps when that is set; otherwise it is
    ``len(data_tr.tokens) // data_tr.bs // c.train_chunk``.

    Whenever the current step index is a multiple of ``c.step_eval`` (checked
    before the counter is incremented), two things happen:
    - weight sparsity is logged;
    - the model is evaluated on the validation split, and
      ``s.record_step`` is set to whether the validation loss is below the
      lowest seen so far.

    Any ``Exception`` raised inside the loop is logged (``c.log`` on the main
    process, ``print`` otherwise) rather than propagated.
    ``c.on_train_end(s)`` is always called.

    Args:
        c: run config / experiment object; must not be distributed or
            parallel.
        net: model to train.
        compression_scheduler: optional scheduler; ``None`` disables all
            compression hooks.
    """
    c.setdefault(hebbian=False)
    assert not c.distributed and not c.parallel

    emb_params = count_params(net.embed) + count_params(net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)

    c.log('Embedding has %s parameters' % emb_params)

    if c.get("steps_per_epoch"):
        steps_per_epoch = c.steps_per_epoch
    else:
        steps_per_epoch = len(data_tr.tokens) // data_tr.bs // c.train_chunk
    print("#### steps per epoch %d ####" % steps_per_epoch)

    if c.hebbian:
        counters = [torch.ones(end - start, dtype=torch.long, device=c.device) for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])]
        temp_counters = [torch.zeros_like(x) for x in counters]

    # Lowest validation loss seen so far. When resuming, seed it from the
    # recorded history; the best loss is the minimum, not the maximum.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        prior_val_losses = s.results['val_loss'].dropna()
        if len(prior_val_losses):
            best_val_loss = prior_val_losses.min()
    try:
        while step < s.step_max:
            batch = step % steps_per_epoch
            epoch = step // steps_per_epoch
            if batch == 0:
                c.log("====> batch=%d, epoch=%d, step=%d" % (batch, epoch, step))
                if compression_scheduler:
                    compression_scheduler.on_epoch_begin(epoch)

            if compression_scheduler:
                compression_scheduler.on_minibatch_begin(epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)

            step_lr(step)

            x = to_torch(next(iter_tr), c.device).t()

            t_s = time()
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']
            if c.model_class == 'UniversalTransformer':
                act_loss = preds['act_loss']
                total_loss = act_loss + loss
                extras = dict(act_loss=from_torch(act_loss), n_updates=from_torch(preds['n_updates'].mean()))
            else:
                total_loss = loss
                extras = {}

            if compression_scheduler:
                # Policies may add regularization terms to the loss. Use the
                # returned overall loss instead of discarding it.
                total_loss = compression_scheduler.before_backward_pass(
                    epoch, minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch,
                    loss=total_loss, return_loss_components=False)

            opt.zero_grad()
            if torch.isnan(total_loss):
                raise RuntimeError('Encountered nan loss during training')
            with amp.scale_loss(total_loss, opt) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            opt.step()

            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters, temp_counters)

            time_model = np.round(time() - t_s, 5)

            loss = from_torch(loss)
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(Dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model,
                **extras
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            step_result['theta'] = preds['theta']
            step_result['lambda'] = preds['lambda'].item()

            if compression_scheduler:
                compression_scheduler.on_minibatch_end(epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)
                # Close the epoch after its LAST minibatch. Firing on the
                # first one would end the epoch before it had run.
                if batch == steps_per_epoch - 1:
                    compression_scheduler.on_epoch_end(epoch)

            if step % c.step_eval == 0:
                distiller.log_weights_sparsity(net, epoch, loggers=[tflogger, pylogger])
                _, total = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
                c.log("total sparsity: %.3lf" % total)

                # pd.concat replaces Series.append, which pandas 2.0 removed.
                val_result = pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
                step_result = pd.concat([step_result, val_result])
                val_loss = step_result['val_loss']
                s.record_step = val_loss < best_val_loss
                if s.record_step:
                    # Track the new best so later evals are compared to it.
                    best_val_loss = val_loss
                clear_gpu_memory()
            # NOTE(review): s.record_step is only assigned on eval steps, so
            # it keeps its last value on the steps in between. Confirm that
            # on_step_end handles that.
            s.step_result = step_result

            c.on_step_end(s)

            c.log("@@@@ step %d end @@@@" % step)

            s.step = step = step + 1
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)