Example #1
File: train.py Project: taoleicn/flop

if args.restart:
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        model = torch.load(f)
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
elif args.prune:
    # load pre-trained model and insert HardConcrete modules
    model = torch.load(args.prune_load)
    model.apply(update_dropout)
    model.apply(update_dropatt)
    flop.make_hard_concrete(model.layers,
                            in_place=True,
                            init_mean=args.prune_init_mean)
    print(model.layers)
else:
    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
Example #2
def main(args):
    log_path = "{}_{}".format(args.log, random.randint(1, 100))
    train_writer = SummaryWriter(log_dir=log_path + "/train")
    dev_writer = SummaryWriter(log_dir=log_path + "/dev")

    train, dev, test, words = read_corpus(args.data)
    dev_, test_ = dev, test
    train = create_batches(train, args.batch_size)
    dev = create_batches(dev, args.batch_size)
    test = create_batches(test, args.batch_size)

    model = Model(words, args)
    if args.load:
        model.load_state_dict(torch.load(args.load))
    model.cuda()
    print(model)
    print("vocab size: {}".format(model.n_V))

    lr = 1.0 if not args.noam else 1.0 / (args.n_d**0.5) / (args.warmup_steps**1.5)
    if args.prune:
        # in-place substitution of linear ops in SRU
        flop.make_hard_concrete(model.rnn, in_place=True)
        model.cuda()
        print("model after inserting hardconcrete:")
        print(model)
        hc_modules = flop.get_hardconcrete_modules(model)
        hc_parameters = [
            p for m in hc_modules for p in m.parameters() if p.requires_grad
        ]
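        # a separate optimizer for the HardConcrete gate parameters, so pruning
        # can use its own learning rate (lr * prune_lr)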
        optimizer_hc = Adam(hc_parameters,
                            lr=lr * args.prune_lr,
                            weight_decay=0)
        num_hardconcrete_params = sum(x.numel() for x in hc_parameters)
        print("num of hardconcrete paramters: {}".format(
            num_hardconcrete_params))
        lambda_1 = nn.Parameter(torch.tensor(0.).cuda())
        lambda_2 = nn.Parameter(torch.tensor(0.).cuda())
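        # the Lagrange multipliers are maximized: a negative learning rate
        # turns Adam's update into gradient ascent on lambda_1 and lambda_2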
        optimizer_max = Adam([lambda_1, lambda_2], lr=lr, weight_decay=0)
        optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr
        hc_linear_modules = flop.get_hardconcrete_linear_modules(model)
        num_prunable_params = sum(m.num_prunable_parameters()
                                  for m in hc_linear_modules)
        print("num of prunable paramters: {}".format(num_prunable_params))
    else:
        args.prune_start_epoch = args.max_epoch

    m_parameters = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and 'log_alpha' not in i[0]
    ]
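    # the main optimizer covers all trainable weights except the HardConcrete log_alpha gates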
    optimizer = Adam(m_parameters,
                     lr=lr * args.lr,
                     weight_decay=args.weight_decay)
    num_params = sum(x.numel() for x in m_parameters if x.requires_grad)
    print("num of parameters: {}".format(num_params))

    nbatch = 1
    niter = 1
    best_dev = 1e+8
    unroll_size = args.unroll_size
    batch_size = args.batch_size
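    # number of unroll_size segments (BPTT chunks) per epoch, via ceiling division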
    N = (len(train[0]) - 1) // unroll_size + 1
    criterion = nn.CrossEntropyLoss()

    model.zero_grad()
    if args.prune:
        optimizer_max.zero_grad()
        optimizer_hc.zero_grad()

    for epoch in range(args.max_epoch):
        start_time = time.time()
        model.train()
        total_loss = 0.0
        hidden = model.init_hidden(batch_size)
        start_prune = epoch >= args.prune_start_epoch

        for i in range(N):
            x = train[0][i * unroll_size:(i + 1) * unroll_size]
            y = train[1][i * unroll_size:(i + 1) * unroll_size].view(-1)
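            # truncated BPTT: detach the hidden state so gradients do not flow into previous segments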
            hidden.detach_()

            # language model forward and backward
            output, hidden = model(x, hidden)
            loss = criterion(output, y)
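            # scale by update_param_freq so accumulated gradients average over the accumulation window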
            (loss / args.update_param_freq).backward()
            loss = loss.item()
            lagrangian_loss = 0
            target_sparsity = 0
            expected_sparsity = 0

            # add lagrangian loss (regularization) when pruning
            if start_prune:
                # compute target sparsity with (optionally) linear warmup
                target_sparsity = args.prune_sparsity
                if args.prune_warmup > 0:
                    niter_ = niter - args.prune_start_epoch * N
                    target_sparsity *= min(1.0, niter_ / args.prune_warmup)

                # compute expected model size and sparsity
                expected_size = sum(
                    m.num_parameters(train=True) for m in hc_linear_modules)
                expected_sparsity = 1.0 - expected_size / num_prunable_params

                # compute lagrangian loss
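                # lambda_1 and lambda_2 are trained to maximize this penalty,
                # pushing expected_sparsity toward target_sparsity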
                lagrangian_loss = lambda_1 * (expected_sparsity - target_sparsity) + \
                                  lambda_2 * (expected_sparsity - target_sparsity)**2
                (lagrangian_loss / args.update_param_freq).backward()
                expected_sparsity = expected_sparsity.item()
                lagrangian_loss = lagrangian_loss.item()

            #  log training stats
            if (niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0:
                if args.prune:
                    train_writer.add_scalar('sparsity/expected_sparsity',
                                            expected_sparsity, niter)
                    train_writer.add_scalar('sparsity/target_sparsity',
                                            target_sparsity, niter)
                    train_writer.add_scalar('loss/lagrangian_loss',
                                            lagrangian_loss, niter)
                    train_writer.add_scalar('lambda/1', lambda_1.item(), niter)
                    train_writer.add_scalar('lambda/2', lambda_2.item(), niter)
                    if (niter - 1) % 3000 == 0:
                        for index, layer in enumerate(hc_modules):
                            train_writer.add_histogram(
                                'log_alpha/{}'.format(index),
                                layer.log_alpha,
                                niter,
                                bins='sqrt',
                            )
                sys.stderr.write("\r{:.4f} {:.2f} {:.2f}".format(
                    loss,
                    lagrangian_loss,
                    expected_sparsity,
                ))
                train_writer.add_scalar('loss/lm_loss', loss, niter)
                train_writer.add_scalar('loss/total_loss',
                                        loss + lagrangian_loss, niter)
                train_writer.add_scalar(
                    'parameter_norm',
                    calc_norm([x.data for x in m_parameters]), niter)
                train_writer.add_scalar(
                    'gradient_norm',
                    calc_norm(
                        [x.grad for x in m_parameters if x.grad is not None]),
                    niter)

            #  perform a gradient descent step every few backward() calls
            if nbatch % args.update_param_freq == 0:
                if args.clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(m_parameters, args.clip_grad)
                optimizer.step()
                if start_prune:
                    optimizer_max.step()
                    optimizer_hc.step()
                #  clear gradient
                model.zero_grad()
                if args.prune:
                    optimizer_max.zero_grad()
                    optimizer_hc.zero_grad()
                niter += 1

            if nbatch % args.log_period == 0 or i == N - 1:
                elapsed_time = (time.time() - start_time) / 60.0
                dev_ppl, dev_loss = eval_model(model, dev)
                dev_writer.add_scalar('loss/lm_loss', dev_loss, niter)
                dev_writer.add_scalar('bpc', np.log2(dev_ppl), niter)
                sparsity = 0
                if args.prune:
                    pruned_size = sum(
                        m.num_parameters(train=False)
                        for m in hc_linear_modules)
                    sparsity = 1.0 - pruned_size / num_prunable_params
                    dev_writer.add_scalar('sparsity/hard_sparsity', sparsity,
                                          niter)
                    dev_writer.add_scalar('model_size/total_prunable',
                                          num_prunable_params, niter)
                    dev_writer.add_scalar('model_size/current_prunable',
                                          pruned_size, niter)
                    dev_writer.add_scalar('model_size/total', num_params,
                                          niter)
                    dev_writer.add_scalar(
                        'model_size/current',
                        num_params - num_prunable_params + pruned_size, niter)
                sys.stdout.write(
                    "\rIter={}  lr={:.5f}  train_loss={:.4f}  dev_loss={:.4f}"
                    "  dev_bpc={:.2f}  sparsity={:.2f}\teta={:.1f}m\t[{:.1f}m]\n"
                    .format(niter, optimizer.param_groups[0]['lr'], loss,
                            dev_loss, np.log2(dev_ppl), sparsity,
                            elapsed_time * N / (i + 1), elapsed_time))
                if dev_ppl < best_dev:
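                    # only checkpoint when not pruning, or once hard sparsity is within 0.02 of the target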
                    if (not args.prune) or sparsity > args.prune_sparsity - 0.02:
                        best_dev = dev_ppl
                        checkpoint = copy_model(model)
                sys.stdout.write("\n")
                sys.stdout.flush()

            nbatch += 1
            if args.noam:
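                # Noam schedule: linear warmup for warmup_steps, then inverse square-root decay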
                lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5))
                optimizer.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5)
            if args.noam and start_prune:
                niter_ = niter - args.prune_start_epoch * N
                lr = min(1.0 / (niter_**0.5), niter_ / (args.warmup_steps**1.5))
                optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr / (args.n_d**0.5)
                optimizer_hc.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5)

        if args.save and (epoch + 1) % 10 == 0:
            torch.save(
                copy_model(model),
                "{}.{}.{:.3f}.pt".format(args.save, epoch + 1, sparsity))

    train_writer.close()
    dev_writer.close()

    model.load_state_dict(checkpoint)
    model.cuda()
    dev = create_batches(dev_, 1)
    test = create_batches(test_, 1)
    dev_ppl, dev_loss = eval_model(model, dev)
    test_ppl, test_loss = eval_model(model, test)
    sys.stdout.write("dev_bpc={:.3f}  test_bpc={:.3f}\n".format(
        np.log2(dev_ppl), np.log2(test_ppl)))
Example #3
def main(args):

    if args.local_rank == 0:
        log_path = "{}_{}".format(args.log, random.randint(1, 100))
        train_writer = SummaryWriter(log_dir=log_path + "/train")
        dev_writer = SummaryWriter(log_dir=log_path + "/dev")

    # set up distributed training
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl")
    set_seed(1234)
    args.n_gpu = 1
    args.device = device
    local_rank = args.local_rank

    corpus = get_lm_corpus(args.data, "wt103")
    n_token = args.n_token = len(corpus.vocab)
    args.eval_batch_size = args.eval_batch_size or args.batch_size
    args.eval_unroll_size = args.eval_unroll_size or args.unroll_size
    unroll_size = args.unroll_size
    eval_unroll_size = args.eval_unroll_size
    batch_size = args.batch_size
    eval_batch_size = args.eval_batch_size
    n_nodes = torch.cuda.device_count()
    train = corpus.get_distributed_iterator(
        "train",
        batch_size,
        unroll_size,
        n_nodes=n_nodes,
        rank=local_rank,
        device=device,
    )
    dev = corpus.get_iterator("valid",
                              eval_batch_size,
                              eval_unroll_size,
                              device=device)
    if local_rank == 0:
        print("vocab size: {}".format(n_token))

    model = Model(args)
    if args.load:
        model.load_state_dict(torch.load(args.load))
    lr = 1.0 if not args.noam else 1.0 / (args.n_d**0.5) / (args.warmup_steps**1.5)
    if args.prune:
        # in-place substitution of linear ops in SRU
        flop.make_hard_concrete(model.rnn,
                                in_place=True,
                                init_mean=args.prune_init_mean)
        model.embedding_layer = HardConcreteAdaptiveEmbedding.from_module(
            model.embedding_layer, init_mean=args.prune_init_mean)
        model.output_layer = HardConcreteAdaptiveLogSoftmax.from_module(
            model.output_layer, init_mean=args.prune_init_mean)
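        # replace the adaptive embedding and softmax with HardConcrete-gated versions
        # so their weights are prunable as well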
        # tie weights again
        model.tie_weights()
        model.to(device)
        hc_modules = flop.get_hardconcrete_modules(
            model.rnn) + flop.get_hardconcrete_modules(model.embedding_layer)
        # print(len(flop.get_hardconcrete_modules(model)))
        # print(len(hc_modules))
        hc_parameters = [
            p for m in hc_modules for p in m.parameters() if p.requires_grad
        ]
        optimizer_hc = torch.optim.Adam(hc_parameters,
                                        lr=lr * args.prune_lr,
                                        weight_decay=0)

        lambda_1 = nn.Parameter(torch.tensor(0.0).cuda())
        lambda_2 = nn.Parameter(torch.tensor(0.0).cuda())
        optimizer_max = torch.optim.Adam([lambda_1, lambda_2],
                                         lr=lr,
                                         weight_decay=0)
        optimizer_max.param_groups[0]["lr"] = -lr * args.prune_lr
        hc_linear_modules = flop.get_hardconcrete_linear_modules(model) + [
            model.embedding_layer
        ]

        num_hardconcrete_params = sum(x.numel() for x in hc_parameters)
        num_prunable_params = sum(m.num_prunable_parameters()
                                  for m in hc_linear_modules)
        if local_rank == 0:
            print("num of hardconcrete paramters: {}".format(
                num_hardconcrete_params))
            print("num of prunable paramters: {}".format(num_prunable_params))
    else:
        model.to(device)
        args.prune_start_epoch = args.max_epoch

    m_parameters = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and "log_alpha" not in i[0]
    ]
    optimizer = torch.optim.Adam(m_parameters,
                                 lr=lr * args.lr,
                                 weight_decay=args.weight_decay)
    num_params = sum(x.numel() for x in m_parameters if x.requires_grad)

    model_ = model
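    # keep a reference to the unwrapped model for evaluation and checkpointing;
    # the DDP wrapper below uses dim=1 since batches are time-major (unroll_size, batch_size)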
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        dim=1,
        device_ids=[local_rank],
        output_device=local_rank,
    )

    nbatch = 1
    niter = 1
    best_dev = 1e8
    unroll_size = args.unroll_size
    batch_size = args.batch_size
    N = train.n_batch
    checkpoint = None
    if local_rank == 0:
        print(model)
        print("num of parameters: {}".format(num_params))
        print("num of mini-batches: {}".format(N))

    model.zero_grad()
    if args.prune:
        optimizer_max.zero_grad()
        optimizer_hc.zero_grad()

    for epoch in range(args.max_epoch):
        start_time = time.time()
        model.train()
        total_loss = 0.0
        hidden = model_.init_hidden(batch_size)
        start_prune = epoch >= args.prune_start_epoch
        i = 0

        for x, y, seq_len in train:
            i += 1
            hidden.detach_()

            # language model forward and backward
            loss, hidden = model(x, y, hidden)
            loss = loss.mean()
            (loss / args.update_param_freq).backward()
            loss = loss.item()
            lagrangian_loss = 0
            target_sparsity = 0
            expected_sparsity = 0

            # add lagrangian loss (regularization) when pruning
            if start_prune:
                # compute target sparsity with (optionally) linear warmup
                target_sparsity = args.prune_sparsity
                if args.prune_warmup > 0:
                    niter_ = niter - args.prune_start_epoch * N
                    target_sparsity *= min(1.0, niter_ / args.prune_warmup)

                # compute expected model size and sparsity
                expected_size = sum(
                    m.num_parameters(train=True) for m in hc_linear_modules)
                expected_sparsity = 1.0 - expected_size / num_prunable_params

                # compute lagrangian loss
                lagrangian_loss = (lambda_1 *
                                   (expected_sparsity - target_sparsity) +
                                   lambda_2 *
                                   (expected_sparsity - target_sparsity)**2)
                (lagrangian_loss / args.update_param_freq).backward()
                expected_sparsity = expected_sparsity.item()
                lagrangian_loss = lagrangian_loss.item()

            #  log training stats
            if (local_rank == 0 and (niter - 1) % 100 == 0
                    and nbatch % args.update_param_freq == 0):
                if args.prune:
                    train_writer.add_scalar("sparsity/expected_sparsity",
                                            expected_sparsity, niter)
                    train_writer.add_scalar("sparsity/target_sparsity",
                                            target_sparsity, niter)
                    train_writer.add_scalar("loss/lagrangian_loss",
                                            lagrangian_loss, niter)
                    train_writer.add_scalar("lambda/1", lambda_1.item(), niter)
                    train_writer.add_scalar("lambda/2", lambda_2.item(), niter)
                    if (nbatch - 1) % 3000 == 0:
                        for index, layer in enumerate(hc_modules):
                            train_writer.add_histogram(
                                "log_alpha/{}".format(index),
                                layer.log_alpha,
                                niter,
                                bins="sqrt",
                            )
                sys.stderr.write("\r{:.4f} {:.2f} {:.2f} eta={:.1f}m".format(
                    math.exp(loss),
                    lagrangian_loss,
                    expected_sparsity,
                    (time.time() - start_time) / 60.0 / (i + 1) * (N - i - 1),
                ))
                train_writer.add_scalar("loss/ppl", math.exp(loss), niter)
                train_writer.add_scalar("loss/lm_loss", loss, niter)
                train_writer.add_scalar("loss/total_loss",
                                        loss + lagrangian_loss, niter)
                train_writer.add_scalar(
                    "parameter_norm",
                    calc_norm([x.data for x in m_parameters]), niter)
                train_writer.add_scalar(
                    "gradient_norm",
                    calc_norm(
                        [x.grad for x in m_parameters if x.grad is not None]),
                    niter,
                )

            #  perform a gradient descent step every few backward() calls
            if nbatch % args.update_param_freq == 0:
                if args.clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(m_parameters, args.clip_grad)
                optimizer.step()
                if start_prune:
                    optimizer_max.step()
                    optimizer_hc.step()
                #  clear gradient
                model.zero_grad()
                if args.prune:
                    optimizer_max.zero_grad()
                    optimizer_hc.zero_grad()
                niter += 1

            if local_rank == 0 and (nbatch % args.log_period == 0 or i == N):
                elapsed_time = (time.time() - start_time) / 60.0
                dev_ppl, dev_loss = eval_model(model_, dev)
                dev_writer.add_scalar("loss/lm_loss", dev_loss, niter)
                dev_writer.add_scalar("loss/ppl", dev_ppl, niter)
                dev_writer.add_scalar("ppl", dev_ppl, niter)
                sparsity = 0
                if args.prune:
                    pruned_size = sum(
                        m.num_parameters(train=False)
                        for m in hc_linear_modules)
                    sparsity = 1.0 - pruned_size / num_prunable_params
                    dev_writer.add_scalar("sparsity/hard_sparsity", sparsity,
                                          niter)
                    dev_writer.add_scalar("model_size/total_prunable",
                                          num_prunable_params, niter)
                    dev_writer.add_scalar("model_size/current_prunable",
                                          pruned_size, niter)
                    dev_writer.add_scalar("model_size/total", num_params,
                                          niter)
                    dev_writer.add_scalar(
                        "model_size/current",
                        num_params - num_prunable_params + pruned_size,
                        niter,
                    )
                    dev_writer.add_scalar(
                        "model_size/current_embedding",
                        model_.embedding_layer.num_parameters(train=False),
                        niter,
                    )
                    dev_writer.add_scalar(
                        "model_size/current_output_layer",
                        model_.output_layer.num_parameters(train=False),
                        niter,
                    )
                sys.stdout.write(
                    "\rnum_batches={}  lr={:.5f}  train_loss={:.4f}  dev_loss={:.4f}"
                    "  dev_bpc={:.2f}  sparsity={:.2f}\t[{:.1f}m]\n".format(
                        nbatch,
                        optimizer.param_groups[0]["lr"],
                        loss,
                        dev_loss,
                        dev_ppl,
                        sparsity,
                        elapsed_time,
                    ))
                if dev_ppl < best_dev:
                    if (not args.prune) or sparsity > args.prune_sparsity - 0.02:
                        best_dev = dev_ppl
                        checkpoint = copy_model(model_)
                sys.stdout.write("\n")
                sys.stdout.flush()

            nbatch += 1
            if args.noam:
                lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5))
                optimizer.param_groups[0]["lr"] = lr * args.lr / (args.n_d**
                                                                  0.5)
            if args.noam and start_prune:
                niter_ = niter - args.prune_start_epoch * N
                lr = min(1.0 / (niter_**0.5),
                         niter_ / (args.warmup_steps**1.5))
                optimizer_max.param_groups[0]["lr"] = (-lr * args.prune_lr /
                                                       (args.n_d**0.5))
                optimizer_hc.param_groups[0]["lr"] = lr * args.lr / (args.n_d**
                                                                     0.5)

        if local_rank == 0 and args.save and checkpoint is not None:
            torch.save(checkpoint, "{}.pt".format(args.save))

    if local_rank == 0:
        train_writer.close()
        dev_writer.close()

        if checkpoint is not None:
            model_.load_state_dict(checkpoint)
            model_.to(device)
        # dev = create_batches(dev_, 1)
        # test = create_batches(test_, 1)
        test = corpus.get_iterator("test",
                                   eval_batch_size,
                                   eval_unroll_size,
                                   device=device)
        dev_ppl, dev_loss = eval_model(model_, dev)
        test_ppl, test_loss = eval_model(model_, test)
        sys.stdout.write("dev_ppl={:.3f}  test_ppl={:.3f}\n".format(
            dev_ppl, test_ppl))
Example #4
File: train.py Project: sasaadi/flop
def main(args):
    log_path = "{}_{}".format(args.log, random.randint(1, 100))
    train_writer = SummaryWriter(log_dir=log_path + "/train")
    dev_writer = SummaryWriter(log_dir=log_path + "/dev")

    device = torch.device('cuda')
    corpus = get_lm_corpus(args.data, 'wt103')
    n_token = args.n_token = len(corpus.vocab)
    args.eval_batch_size = args.eval_batch_size or args.batch_size
    args.eval_unroll_size = args.eval_unroll_size or args.unroll_size
    train = corpus.get_iterator('train',
                                args.batch_size,
                                args.unroll_size,
                                device=device)
    dev = corpus.get_iterator('valid',
                              args.eval_batch_size,
                              args.eval_unroll_size,
                              device=device)
    test = corpus.get_iterator('test',
                               args.eval_batch_size,
                               args.eval_unroll_size,
                               device=device)
    print("vocab size: {}".format(n_token))

    model = Model(args)
    if args.load:
        model.load_state_dict(torch.load(args.load))
    model.cuda()
    print(model)
    if torch.cuda.device_count() > 1:
        para_model = torch.nn.DataParallel(model, dim=1)  #, output_device=1)
    else:
        para_model = model
    lr = 1.0 if not args.noam else 1.0 / (args.n_d**0.5) / (args.warmup_steps**1.5)
    if args.prune:
        # in-place substitution of linear ops in SRU
        flop.make_hard_concrete(model.rnn,
                                in_place=True,
                                init_mean=args.prune_init_mean)
        model.embedding_layer = HardConcreteAdaptiveEmbedding.from_module(
            model.embedding_layer, init_mean=args.prune_init_mean)
        model.output_layer = HardConcreteAdaptiveLogSoftmax.from_module(
            model.output_layer, init_mean=args.prune_init_mean)
        # tie weights again
        model.tie_weights()
        model.cuda()
        print("model after inserting hardconcrete:")
        print(model)
        hc_modules = flop.get_hardconcrete_modules(
            model.rnn) + flop.get_hardconcrete_modules(model.embedding_layer)
        print(len(flop.get_hardconcrete_modules(model)))
        print(len(hc_modules))
        hc_parameters = [
            p for m in hc_modules for p in m.parameters() if p.requires_grad
        ]
        optimizer_hc = RAdam(hc_parameters,
                             lr=lr * args.prune_lr,
                             weight_decay=0)
        num_hardconcrete_params = sum(x.numel() for x in hc_parameters)
        print("num of hardconcrete paramters: {}".format(
            num_hardconcrete_params))
        lambda_1 = nn.Parameter(torch.tensor(0.).cuda())
        lambda_2 = nn.Parameter(torch.tensor(0.).cuda())
        optimizer_max = RAdam([lambda_1, lambda_2], lr=lr, weight_decay=0)
        optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr
        hc_linear_modules = flop.get_hardconcrete_linear_modules(model) + \
                [model.embedding_layer]
        num_prunable_params = sum(m.num_prunable_parameters()
                                  for m in hc_linear_modules)
        print("num of prunable paramters: {}".format(num_prunable_params))
    else:
        args.prune_start_epoch = args.max_epoch

    m_parameters = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and 'log_alpha' not in i[0]
    ]
    optimizer = RAdam(m_parameters,
                      lr=lr * args.lr,
                      weight_decay=args.weight_decay)
    num_params = sum(x.numel() for x in m_parameters if x.requires_grad)
    print("num of parameters: {}".format(num_params))

    nbatch = 1
    niter = 1
    best_dev = 1e+8
    unroll_size = args.unroll_size
    batch_size = args.batch_size
    N = train.n_batch
    checkpoint = None
    print("num of mini-batches: {}".format(N))

    model.zero_grad()
    if args.prune:
        optimizer_max.zero_grad()
        optimizer_hc.zero_grad()

    for epoch in range(args.max_epoch):
        start_time = time.time()
        model.train()
        total_loss = 0.0
        hidden = model.init_hidden(batch_size)
        start_prune = epoch >= args.prune_start_epoch
        i = 0

        for x, y, seq_len in train:
            i += 1
            hidden.detach_()

            # language model forward and backward
            loss, hidden = para_model(x, y, hidden)
            loss = loss.mean()
            (loss / args.update_param_freq).backward()
            loss = loss.item()
            lagrangian_loss = 0
            target_sparsity = 0
            expected_sparsity = 0

            # add lagrangian loss (regularization) when pruning
            if start_prune:
                # compute target sparsity with (optionally) linear warmup
                target_sparsity = args.prune_sparsity
                if args.prune_warmup > 0:
                    niter_ = niter - args.prune_start_epoch * N
                    target_sparsity *= min(1.0, niter_ / args.prune_warmup)

                # compute expected model size and sparsity
                expected_size = sum(
                    m.num_parameters(train=True) for m in hc_linear_modules)
                expected_sparsity = 1.0 - expected_size / num_prunable_params

                # compute lagrangian loss
                lagrangian_loss = lambda_1 * (expected_sparsity - target_sparsity) + \
                                  lambda_2 * (expected_sparsity - target_sparsity)**2
                (lagrangian_loss / args.update_param_freq).backward()
                expected_sparsity = expected_sparsity.item()
                lagrangian_loss = lagrangian_loss.item()

            #  log training stats
            if (niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0:
                if args.prune:
                    train_writer.add_scalar('sparsity/expected_sparsity',
                                            expected_sparsity, niter)
                    train_writer.add_scalar('sparsity/target_sparsity',
                                            target_sparsity, niter)
                    train_writer.add_scalar('loss/lagrangian_loss',
                                            lagrangian_loss, niter)
                    train_writer.add_scalar('lambda/1', lambda_1.item(), niter)
                    train_writer.add_scalar('lambda/2', lambda_2.item(), niter)
                    if (nbatch - 1) % 3000 == 0:
                        for index, layer in enumerate(hc_modules):
                            train_writer.add_histogram(
                                'log_alpha/{}'.format(index),
                                layer.log_alpha,
                                niter,
                                bins='sqrt',
                            )
                sys.stderr.write("\r{:.4f} {:.2f} {:.2f} eta={:.1f}m".format(
                    math.exp(loss),
                    lagrangian_loss,
                    expected_sparsity,
                    (time.time() - start_time) / 60.0 / (i + 1) * (N - i - 1),
                ))
                train_writer.add_scalar('loss/ppl', math.exp(loss), niter)
                train_writer.add_scalar('loss/lm_loss', loss, niter)
                train_writer.add_scalar('loss/total_loss',
                                        loss + lagrangian_loss, niter)
                train_writer.add_scalar(
                    'parameter_norm',
                    calc_norm([x.data for x in m_parameters]), niter)
                train_writer.add_scalar(
                    'gradient_norm',
                    calc_norm(
                        [x.grad for x in m_parameters if x.grad is not None]),
                    niter)

            #  perform a gradient descent step every few backward() calls
            if nbatch % args.update_param_freq == 0:
                if args.clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(m_parameters, args.clip_grad)
                optimizer.step()
                if start_prune:
                    optimizer_max.step()
                    optimizer_hc.step()
                #  clear gradient
                model.zero_grad()
                if args.prune:
                    optimizer_max.zero_grad()
                    optimizer_hc.zero_grad()
                niter += 1

            if nbatch % args.log_period == 0 or i == N:
                elapsed_time = (time.time() - start_time) / 60.0
                dev_ppl, dev_loss = eval_model(model, dev)
                dev_writer.add_scalar('loss/lm_loss', dev_loss, niter)
                dev_writer.add_scalar('loss/ppl', dev_ppl, niter)
                dev_writer.add_scalar('ppl', dev_ppl, niter)
                sparsity = 0
                if args.prune:
                    pruned_size = sum(
                        m.num_parameters(train=False)
                        for m in hc_linear_modules)
                    sparsity = 1.0 - pruned_size / num_prunable_params
                    dev_writer.add_scalar('sparsity/hard_sparsity', sparsity,
                                          niter)
                    dev_writer.add_scalar('model_size/total_prunable',
                                          num_prunable_params, niter)
                    dev_writer.add_scalar('model_size/current_prunable',
                                          pruned_size, niter)
                    dev_writer.add_scalar('model_size/total', num_params,
                                          niter)
                    dev_writer.add_scalar(
                        'model_size/current',
                        num_params - num_prunable_params + pruned_size, niter)
                    dev_writer.add_scalar(
                        'model_size/current_embedding',
                        model.embedding_layer.num_parameters(train=False),
                        niter)
                    dev_writer.add_scalar(
                        'model_size/current_output_layer',
                        model.output_layer.num_parameters(train=False), niter)
                sys.stdout.write(
                    "\rnum_batches={}  lr={:.5f}  train_loss={:.4f}  dev_loss={:.4f}"
                    "  dev_bpc={:.2f}  sparsity={:.2f}\t[{:.1f}m]\n".format(
                        nbatch, optimizer.param_groups[0]['lr'], loss,
                        dev_loss, dev_ppl, sparsity, elapsed_time))
                if dev_ppl < best_dev:
                    best_dev = dev_ppl
                    checkpoint = copy_model(model)
                sys.stdout.write("\n")
                sys.stdout.flush()

            nbatch += 1
            if args.noam:
                lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5))
                optimizer.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5)
            if args.noam and start_prune:
                niter_ = niter - args.prune_start_epoch * N
                lr = min(1.0 / (niter_**0.5), niter_ / (args.warmup_steps**1.5))
                optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr / (args.n_d**0.5)
                optimizer_hc.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5)

        if args.save and (epoch + 1) % 5 == 0:
            torch.save(
                checkpoint,
                "{}.{}.{}.pt".format(args.save, epoch + 1,
                                     int(dev_ppl)
                                     #sparsity
                                     ))

    train_writer.close()
    dev_writer.close()

    if checkpoint is not None:
        model.load_state_dict(checkpoint)
        model.cuda()
    #dev = create_batches(dev_, 1)
    #test = create_batches(test_, 1)
    dev_ppl, dev_loss = eval_model(model, dev)
    test_ppl, test_loss = eval_model(model, test)
    sys.stdout.write("dev_ppl={:.3f}  test_ppl={:.3f}\n".format(
        dev_ppl, test_ppl))