Example #1
0
def main(args):
    """Evaluate a (possibly pruned) character-level language model.

    Loads the corpus, inserts mask modules into the SRU so a pruned
    checkpoint's state dict lines up with the model's parameters,
    optionally restores weights from ``args.load``, then reports
    dev/test bits-per-character on stdout.
    """
    train, dev, test, words = read_corpus(args.data)
    dev_, test_ = dev, test  # keep unbatched copies for batch-size-1 eval
    dev = create_batches(dev, args.batch_size)
    test = create_batches(test, args.batch_size)

    model = Model(words, args)
    # Insert mask modules BEFORE loading so checkpoint keys match.
    flop.make_projected_linear_with_mask(model.rnn, in_place=True)
    if args.load:
        model.load_state_dict(torch.load(args.load))

    # Single .cuda() after mask insertion moves everything (incl. the
    # freshly inserted masks) to GPU; the earlier redundant call is gone.
    model.cuda()
    dev = create_batches(dev_, 1)
    test = create_batches(test_, 1)
    dev_ppl, dev_loss = eval_model(model, dev)
    test_ppl, test_loss = eval_model(model, test)
    sys.stdout.write("dev_bpc={:.3f}  test_bpc={:.3f}\n".format(
        np.log2(dev_ppl), np.log2(test_ppl)))
Example #2
0
        sample_softmax=args.sample_softmax,
    )
    # in place substituion of linear ops into projectedlinear
    flop.make_projected_linear(model.layers, in_place=True)
    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

if args.prune:
    # load pre-trained model and insert HardConcrete modules
    model.apply(update_dropout)
    model.apply(update_dropatt)
    state_dict = torch.load(args.prune_load, map_location='cpu')
    model.load_state_dict(state_dict)
    flop.make_projected_linear_with_mask(model.layers, in_place=True)
    print(model.layers)

args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
print("total params: {}".format(args.n_all_param))
print("total non-emb params: {}".format(args.n_nonemb_param))

if args.fp16:
    model = model.half()

if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model,
def main(args):
    """Train a character-level language model, optionally pruning the SRU
    layers with learned masks on an AGP (automated gradual pruning) schedule.

    Side effects: writes TensorBoard logs under ``args.log``, prints
    progress to stdout, saves a checkpoint every 10 epochs when
    ``args.save`` is set, and finishes by evaluating the last logged
    checkpoint on dev/test with batch size 1.
    """
    log_path = "{}_{}".format(args.log, random.randint(1, 100))
    train_writer = SummaryWriter(log_dir=log_path + "/train")
    dev_writer = SummaryWriter(log_dir=log_path + "/dev")

    train, dev, test, words = read_corpus(args.data)
    dev_, test_ = dev, test  # keep unbatched copies for the final eval
    train = create_batches(train, args.batch_size)
    dev = create_batches(dev, args.batch_size)
    test = create_batches(test, args.batch_size)

    model = Model(words, args)
    if args.load:
        model.load_state_dict(torch.load(args.load))
    model.cuda()
    print(model)
    print("vocab size: {}".format(model.n_V))

    # Base factor of the Noam (warmup + inverse-sqrt) LR schedule.
    lr = 1.0 if not args.noam else 1.0 / (args.n_d**0.5) / (args.warmup_steps**
                                                            1.5)
    pruner = None  # fix: must exist even when pruning is disabled
    if args.prune:
        # in place substituion of linear ops in SRU
        flop.make_projected_linear_with_mask(model.rnn,
                                             in_place=True,
                                             init_zero=True)
        model.cuda()
        print("model after inserting masks:")
        print(model)
        mask_params = list(flop.get_projected_linear_masks(model))
        optimizer_pm = Adam(mask_params, lr=0.001, weight_decay=0)
        num_masks_params = sum(x.numel() for x in mask_params)
        print("num of mask paramters: {}".format(num_masks_params))
        pm_linear_modules = flop.get_projected_linear_with_mask_modules(model)
        num_prunable_params = sum(m.num_prunable_parameters()
                                  for m in pm_linear_modules)
        print("num of prunable paramters: {}".format(num_prunable_params))
        mask_param_names = [
            i[0] for i in model.named_parameters()
            if i[1].requires_grad and "mask" in i[0]
        ]
        pruner = flop.NervanaPruner(
            model,
            subpruners={
                "agppruner": {
                    "class": "AutomatedGradualPruner",
                    "initial_sparsity": 0.05,
                    "weights": mask_param_names,
                    "final_sparsity": args.prune_sparsity,
                    "starting_step": args.prune_start_epoch,
                    "ending_step": args.prune_end_epoch,
                    "frequency": 1,
                }
            },
        )
    else:
        # Pushing the start epoch past max_epoch disables all pruning branches.
        args.prune_start_epoch = args.max_epoch

    all_non_mask_params = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and "mask" not in i[0]
    ]
    num_params = sum(x.numel() for x in all_non_mask_params if x.requires_grad)
    print("num of parameters: {}".format(num_params))

    nbatch = 1
    niter = 1
    unroll_size = args.unroll_size
    batch_size = args.batch_size
    N = (len(train[0]) - 1) // unroll_size + 1  # mini-batches per epoch
    criterion = nn.CrossEntropyLoss()
    checkpoint = None  # fix: defined even if the logging branch never runs
    sparsity = 0

    model.zero_grad()
    if args.prune:
        optimizer_pm.zero_grad()

    # Embedding/output layers are trained with their own Adam instance.
    emb_parameters = list(model.embedding_layer.parameters()) + list(
        model.output_layer.parameters())
    emb_optimizer = Adam(emb_parameters,
                         lr=lr * args.lr,
                         weight_decay=args.weight_decay)
    emb_optimizer.zero_grad()
    # Deactivate all parameters in the RNN
    m_parameters = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and "mask" not in i[0]
    ]
    optimizer = None
    if args.freeze_period:
        # RNN weights stay frozen until pruning starts (see epoch loop).
        for p in m_parameters:
            p.requires_grad = False
    else:
        optimizer = Adam(m_parameters,
                         lr=lr * args.lr,
                         weight_decay=args.weight_decay)

    for epoch in range(args.max_epoch):
        start_prune = epoch >= args.prune_start_epoch
        if args.freeze_period and optimizer is None and start_prune:
            # Freeze masks, thaw the RNN weights, and start optimizing them.
            for p in mask_params:
                p.requires_grad = False
            for p in m_parameters:
                p.requires_grad = True
            optimizer = Adam(m_parameters,
                             lr=lr * args.lr,
                             weight_decay=args.weight_decay)

        start_time = time.time()
        model.train()
        hidden = model.init_hidden(batch_size)
        if pruner is not None:  # fix: original crashed here without --prune
            pruner.begin_step(epoch)

        for i in range(N):
            # start iter on the first batch
            if pruner is not None and nbatch % args.update_param_freq == 1:
                pruner.begin_iter(epoch, niter, N // args.update_param_freq)

            x = train[0][i * unroll_size:(i + 1) * unroll_size]
            y = train[1][i * unroll_size:(i + 1) * unroll_size].view(-1)
            hidden.detach_()

            # language model forward and backward
            output, hidden = model(x, hidden)
            model_loss = criterion(output, y)
            expected_sparsity = 0
            l1_loss = 0

            # add lagrangian loss (regularization) when pruning
            if start_prune:
                # compute expected model size and sparsity
                expected_size = sum(
                    m.num_parameters(train=True) for m in pm_linear_modules)
                expected_sparsity = 1.0 - expected_size / num_prunable_params
                expected_sparsity = expected_sparsity.item()

                # L1 pressure on the masks until target sparsity is reached.
                l1_loss_aggr = 0
                if args.l1_lambda > 0 and expected_sparsity < args.prune_sparsity:
                    for p in mask_params:
                        l1_loss_aggr += torch.sum(torch.abs(p))

                l1_loss = args.l1_lambda * l1_loss_aggr

            if args.l1_lambda > 0:
                loss = model_loss + l1_loss
            else:
                loss = model_loss

            # Accumulate gradients over update_param_freq batches.
            (loss / args.update_param_freq).backward()
            model_loss = model_loss.item()
            l1_loss = l1_loss.item() if isinstance(l1_loss,
                                                   torch.Tensor) else l1_loss

            #  log training stats
            if (niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0:
                if args.prune:
                    train_writer.add_scalar("sparsity/expected_sparsity",
                                            expected_sparsity, niter)
                    if (niter - 1) % 3000 == 0:
                        for index, layer in enumerate(mask_params):
                            train_writer.add_histogram(
                                "log_alpha/{}".format(index),
                                layer,
                                niter,
                                bins="sqrt",
                            )
                train_writer.add_scalar("loss/lm_loss", model_loss, niter)
                train_writer.add_scalar("loss/l1_loss", l1_loss, niter)
                train_writer.add_scalar("loss/total_loss",
                                        model_loss + l1_loss, niter)
                train_writer.add_scalar(
                    "parameter_norm",
                    calc_norm([x.data for x in m_parameters]), niter)
                train_writer.add_scalar(
                    "gradient_norm",
                    calc_norm(
                        [x.grad for x in m_parameters if x.grad is not None]),
                    niter,
                )

            #  perform gradient decent every few number of backward()
            if nbatch % args.update_param_freq == 0:
                if args.clip_grad > 0:
                    # fix: clip_grad_norm is deprecated/removed in torch;
                    # use the in-place clip_grad_norm_.
                    torch.nn.utils.clip_grad_norm_(m_parameters,
                                                   args.clip_grad)
                if emb_optimizer is not None:
                    emb_optimizer.step()
                if optimizer is not None:
                    optimizer.step()
                # fix: optimizer_pm only exists when pruning is enabled
                if args.prune and (start_prune or args.freeze_period):
                    optimizer_pm.step()
                #  clear gradient
                model.zero_grad()
                if args.prune:
                    optimizer_pm.zero_grad()

                # End iter on the last batch
                if pruner is not None:
                    pruner.end_iter(epoch, niter, N // args.update_param_freq)
                niter += 1

            if nbatch % args.log_period == 0 or i == N - 1:
                elapsed_time = (time.time() - start_time) / 60.0
                dev_ppl, dev_loss = eval_model(model, dev)
                dev_writer.add_scalar("loss/lm_loss", dev_loss, niter)
                dev_writer.add_scalar("bpc", np.log2(dev_ppl), niter)
                sparsity = 0
                if args.prune:
                    pruned_size = sum(
                        m.num_parameters(train=False)
                        for m in pm_linear_modules)
                    sparsity = 1.0 - pruned_size / num_prunable_params
                    dev_writer.add_scalar("sparsity/hard_sparsity", sparsity,
                                          niter)
                    dev_writer.add_scalar("model_size/total_prunable",
                                          num_prunable_params, niter)
                    dev_writer.add_scalar("model_size/current_prunable",
                                          pruned_size, niter)
                    dev_writer.add_scalar("model_size/total", num_params,
                                          niter)
                    dev_writer.add_scalar(
                        "model_size/current",
                        num_params - num_prunable_params + pruned_size,
                        niter,
                    )
                sys.stdout.write(
                    "\rIter={}  train_loss={:.4f}  dev_loss={:.4f}"
                    "  dev_bpc={:.2f}  sparsity={:.2f}\teta={:.1f}m\t[{:.1f}m]\n"
                    .format(
                        niter,
                        loss,
                        dev_loss,
                        np.log2(dev_ppl),
                        sparsity,
                        elapsed_time * N / (i + 1),
                        elapsed_time,
                    ))
                checkpoint = copy_model(model)
                sys.stdout.write("\n")
                sys.stdout.flush()

            nbatch += 1
            if args.noam:
                # Noam LR schedule for the embedding optimizer.
                niter_ = niter
                lr = min(1.0 / (niter_**0.5),
                         niter_ / (args.warmup_steps**1.5))
                emb_optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d
                                                                      **0.5)
            if args.noam and optimizer is not None:
                # With freeze_period the schedule restarts at prune start.
                niter_ = niter - args.prune_start_epoch * N if args.freeze_period else niter
                lr = min(1.0 / (niter_**0.5),
                         niter_ / (args.warmup_steps**1.5))
                optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d**
                                                                  0.5)

        if pruner is not None:
            pruner.end_step(epoch)
        if args.save and (epoch + 1) % 10 == 0:
            torch.save(
                copy_model(model),
                "{}.{}.{:.3f}.pt".format(args.save, epoch + 1, sparsity))

    train_writer.close()
    dev_writer.close()

    # Final evaluation of the last logged checkpoint with batch size 1.
    if checkpoint is not None:
        model.load_state_dict(checkpoint)
    model.cuda()
    dev = create_batches(dev_, 1)
    test = create_batches(test_, 1)
    dev_ppl, dev_loss = eval_model(model, dev)
    test_ppl, test_loss = eval_model(model, test)
    sys.stdout.write("dev_bpc={:.3f}  test_bpc={:.3f}\n".format(
        np.log2(dev_ppl), np.log2(test_ppl)))
Example #4
0
def main(args):
    """Train a word-level LM (wt103 corpus), optionally pruning the SRU
    layers and the adaptive embedding/softmax via learned masks on an AGP
    schedule.

    Only rank 0 writes TensorBoard logs, prints progress, saves
    checkpoints, and runs the final dev/test evaluation.  Distributed
    setup is disabled here; the model runs under ``nn.DataParallel``.
    """

    if args.local_rank == 0:
        log_path = "{}_{}".format(args.log, random.randint(1, 100))
        train_writer = SummaryWriter(log_dir=log_path + "/train")
        dev_writer = SummaryWriter(log_dir=log_path + "/dev")

    device = 'cuda'
    set_seed(1234)
    args.n_gpu = 1
    args.device = device
    local_rank = args.local_rank

    corpus = get_lm_corpus(args.data, "wt103")
    n_token = args.n_token = len(corpus.vocab)
    args.eval_batch_size = args.eval_batch_size or args.batch_size
    args.eval_unroll_size = args.eval_unroll_size or args.unroll_size
    unroll_size = args.unroll_size
    eval_unroll_size = args.eval_unroll_size
    batch_size = args.batch_size
    eval_batch_size = args.eval_batch_size
    train = corpus.get_iterator("train",
                                batch_size,
                                unroll_size,
                                device=device)
    dev = corpus.get_iterator("test",
                              eval_batch_size,
                              eval_unroll_size,
                              device=device)
    if local_rank == 0:
        print("vocab size: {}".format(n_token))

    model = Model(args)
    if args.load:
        model.load_state_dict(torch.load(args.load))
    # Base factor of the Noam (warmup + inverse-sqrt) LR schedule.
    lr = 1.0 if not args.noam else 1.0 / (args.n_d**0.5) / (args.warmup_steps**
                                                            1.5)
    pruner = None  # fix: must exist even when pruning is disabled
    if args.prune:
        # in place substituion of linear ops in SRU
        flop.make_projected_linear_with_mask(model.rnn, in_place=True)
        model.embedding_layer = AdaptiveEmbeddingWithMask.from_module(
            model.embedding_layer)
        model.output_layer = AdaptiveLogSoftmaxWithMask.from_module(
            model.output_layer)
        # tie weights again
        model.tie_weights()
        model.to(device)
        mask_params = list(flop.get_projected_linear_masks(model.rnn))
        for param in model.embedding_layer.masks:
            mask_params.append(param)

        optimizer_pm = torch.optim.Adam(mask_params,
                                        lr=lr * args.prune_lr,
                                        weight_decay=0)

        mask_modules = flop.get_projected_linear_with_mask_modules(model) + [
            model.embedding_layer
        ]

        num_mask_params = sum(x.numel() for x in mask_params)
        num_prunable_params = sum(m.num_prunable_parameters()
                                  for m in mask_modules)
        if local_rank == 0:
            print("num of mask parameters: {}".format(num_mask_params))
            print("num of prunable parameters: {}".format(num_prunable_params))

        mask_param_names = [
            i[0] for i in model.named_parameters()
            if i[1].requires_grad and "mask" in i[0]
        ]
        pruner = flop.NervanaPruner(
            model,
            subpruners={
                "agppruner": {
                    "class": "AutomatedGradualPruner",
                    "initial_sparsity": 0.05,
                    "weights": mask_param_names,
                    "final_sparsity": args.prune_sparsity,
                    "starting_step": args.prune_start_epoch + 1,
                    "ending_step": args.prune_end_epoch,
                    "frequency": 1,
                }
            },
        )
    else:
        model.to(device)
        # Pushing the start epoch past max_epoch disables all pruning branches.
        args.prune_start_epoch = args.max_epoch

    m_parameters = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and "mask" not in i[0]
    ]
    optimizer = torch.optim.Adam(m_parameters,
                                 lr=lr * args.lr,
                                 weight_decay=args.weight_decay)
    num_params = sum(x.numel() for x in m_parameters if x.requires_grad)

    model_ = model  # keep a handle to the unwrapped model
    model = nn.DataParallel(model, dim=1).to('cuda')

    nbatch = 1
    niter = 1
    unroll_size = args.unroll_size
    batch_size = args.batch_size
    N = train.n_batch
    checkpoint = None
    if local_rank == 0:
        print(model)
        print("num of parameters: {}".format(num_params))
        print("num of mini-batches: {}".format(N))

    model.zero_grad()
    if args.prune:
        optimizer_pm.zero_grad()

    for epoch in range(args.max_epoch):
        start_time = time.time()
        model.train()
        hidden = model_.init_hidden(batch_size)
        start_prune = epoch >= args.prune_start_epoch
        i = 0

        if pruner is not None:  # fix: original crashed here without --prune
            pruner.begin_step(epoch)
        for x, y, seq_len in train:
            # start iter on the first batch
            if pruner is not None and nbatch % args.update_param_freq == 1:
                pruner.begin_iter(epoch, niter, N // args.update_param_freq)

            i += 1
            hidden.detach_()

            # language model forward and backward
            model_loss, hidden = model(x, y, hidden)
            l1_loss = 0
            expected_sparsity = 0

            # add lagrangian loss (regularization) when pruning
            if start_prune:
                # compute expected model size and sparsity
                expected_size = sum(
                    m.num_parameters(train=True) for m in mask_modules)
                expected_sparsity = 1.0 - expected_size / num_prunable_params
                expected_sparsity = expected_sparsity.item()

                # L1 pressure on the masks until target sparsity is reached.
                l1_loss_aggr = 0
                if args.l1_lambda > 0 and expected_sparsity < args.prune_sparsity:
                    for p in mask_params:
                        l1_loss_aggr += torch.sum(torch.abs(p))

                l1_loss = args.l1_lambda * l1_loss_aggr

            model_loss = model_loss.mean()  # average over DataParallel shards
            if args.l1_lambda > 0:
                loss = model_loss + l1_loss
            else:
                loss = model_loss

            # Accumulate gradients over update_param_freq batches.
            (loss / args.update_param_freq).backward()
            model_loss = model_loss.item()
            l1_loss = l1_loss.item() if isinstance(l1_loss,
                                                   torch.Tensor) else l1_loss

            #  log training stats
            if (local_rank == 0 and (niter - 1) % 100 == 0
                    and nbatch % args.update_param_freq == 0):
                if args.prune:
                    train_writer.add_scalar("sparsity/expected_sparsity",
                                            expected_sparsity, niter)
                    train_writer.add_scalar("loss/l1_loss", l1_loss, niter)
                sys.stderr.write("\r{:.4f} {:.2f} {:.2f} eta={:.1f}m".format(
                    math.exp(model_loss),
                    l1_loss,
                    expected_sparsity,
                    (time.time() - start_time) / 60.0 / (i + 1) * (N - i - 1),
                ))
                train_writer.add_scalar("loss/ppl", math.exp(model_loss),
                                        niter)
                train_writer.add_scalar("loss/lm_loss", model_loss, niter)
                train_writer.add_scalar("loss/total_loss",
                                        model_loss + l1_loss, niter)
                train_writer.add_scalar(
                    "parameter_norm",
                    calc_norm([x.data for x in m_parameters]), niter)
                train_writer.add_scalar(
                    "gradient_norm",
                    calc_norm(
                        [x.grad for x in m_parameters if x.grad is not None]),
                    niter,
                )

            #  perform gradient decent every few number of backward()
            if nbatch % args.update_param_freq == 0:
                if args.clip_grad > 0:
                    # fix: clip_grad_norm is deprecated/removed in torch;
                    # use the in-place clip_grad_norm_.
                    torch.nn.utils.clip_grad_norm_(m_parameters,
                                                   args.clip_grad)
                optimizer.step()
                if start_prune:
                    optimizer_pm.step()
                #  clear gradient
                model.zero_grad()
                if args.prune:
                    optimizer_pm.zero_grad()

                # End iter on the last batch
                if pruner is not None:
                    pruner.end_iter(epoch, niter, N // args.update_param_freq)
                niter += 1

            if local_rank == 0 and (nbatch % args.log_period == 0 or i == N):
                elapsed_time = (time.time() - start_time) / 60.0
                dev_ppl, dev_loss = eval_model(model_, dev)
                dev_writer.add_scalar("loss/lm_loss", dev_loss, niter)
                dev_writer.add_scalar("loss/ppl", dev_ppl, niter)
                dev_writer.add_scalar("ppl", dev_ppl, niter)
                sparsity = 0
                if args.prune:
                    pruned_size = sum(
                        m.num_parameters(train=False) for m in mask_modules)
                    sparsity = 1.0 - pruned_size / num_prunable_params
                    dev_writer.add_scalar("sparsity/hard_sparsity", sparsity,
                                          niter)
                    dev_writer.add_scalar("model_size/total_prunable",
                                          num_prunable_params, niter)
                    dev_writer.add_scalar("model_size/current_prunable",
                                          pruned_size, niter)
                    dev_writer.add_scalar("model_size/total", num_params,
                                          niter)
                    dev_writer.add_scalar(
                        "model_size/current",
                        num_params - num_prunable_params + pruned_size,
                        niter,
                    )
                    dev_writer.add_scalar(
                        "model_size/current_embedding",
                        model_.embedding_layer.num_parameters(train=False),
                        niter,
                    )
                    dev_writer.add_scalar(
                        "model_size/current_output_layer",
                        model_.output_layer.num_parameters(train=False),
                        niter,
                    )
                sys.stdout.write(
                    "\rnum_batches={}  lr={:.5f}  train_loss={:.4f}  dev_loss={:.4f}"
                    "  dev_bpc={:.2f}  sparsity={:.2f}\t[{:.1f}m]\n".format(
                        nbatch,
                        optimizer.param_groups[0]["lr"],
                        loss,
                        dev_loss,
                        dev_ppl,
                        sparsity,
                        elapsed_time,
                    ))
                checkpoint = copy_model(model_)
                sys.stdout.write("\n")
                sys.stdout.flush()

            nbatch += 1
            if args.noam:
                # Noam LR schedule for the main optimizer.
                lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5))
                optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d**
                                                                  0.5)
            if args.noam and start_prune:
                # Mask optimizer gets its own schedule, restarted at the
                # pruning start epoch.
                niter_ = niter - args.prune_start_epoch * N
                lr = min(1.0 / (niter_**0.5),
                         niter_ / (args.warmup_steps**1.5))
                optimizer_pm.param_groups[0]["lr"] = lr * args.lr / (args.n_d**
                                                                     0.5)

        if pruner is not None:
            pruner.end_step(epoch)
        if local_rank == 0 and args.save and (epoch + 1) % 10 == 0:
            torch.save(
                checkpoint,
                "{}.{}.{}.pt".format(args.save, epoch + 1,
                                     int(dev_ppl)
                                     ),
            )

    if local_rank == 0:
        train_writer.close()
        dev_writer.close()

        # Restore the last logged checkpoint before the final eval.
        if checkpoint is not None:
            model_.load_state_dict(checkpoint)
            model_.to(device)
        test = corpus.get_iterator("test",
                                   eval_batch_size,
                                   eval_unroll_size,
                                   device=device)
        dev_ppl, dev_loss = eval_model(model_, dev)
        test_ppl, test_loss = eval_model(model_, test)
        sys.stdout.write("dev_ppl={:.3f}  test_ppl={:.3f}\n".format(
            dev_ppl, test_ppl))