Example #1
def main(args):

    if args.local_rank == 0:
        log_path = "{}_{}".format(args.log, random.randint(1, 100))
        train_writer = SummaryWriter(log_dir=log_path + "/train")
        dev_writer = SummaryWriter(log_dir=log_path + "/dev")

    # set up distributed training
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl")
    set_seed(1234)
    args.n_gpu = 1
    args.device = device
    local_rank = args.local_rank

    corpus = get_lm_corpus(args.data, 'wt103')
    n_token = args.n_token = len(corpus.vocab)
    args.eval_batch_size = args.eval_batch_size or args.batch_size
    args.eval_unroll_size = args.eval_unroll_size or args.unroll_size
    unroll_size = args.unroll_size
    eval_unroll_size = args.eval_unroll_size
    batch_size = args.batch_size
    eval_batch_size = args.eval_batch_size
    n_nodes = torch.cuda.device_count()
    train = corpus.get_distributed_iterator('train',
                                            batch_size,
                                            unroll_size,
                                            n_nodes=n_nodes,
                                            rank=local_rank,
                                            device=device)
    dev = corpus.get_iterator('valid',
                              eval_batch_size,
                              eval_unroll_size,
                              device=device)
    if local_rank == 0:
        print("vocab size: {}".format(n_token))

    model = Model(args)
    if args.load:
        model.load_state_dict(torch.load(args.load))
    lr = 1.0 if not args.noam else 1.0 / (args.n_d**0.5) / (args.warmup_steps**
                                                            1.5)
    if args.prune:
        # in-place substitution of linear ops in SRU
        flop.make_hard_concrete(model.rnn,
                                in_place=True,
                                init_mean=args.prune_init_mean)
        model.embedding_layer = HardConcreteAdaptiveEmbedding.from_module(
            model.embedding_layer, init_mean=args.prune_init_mean)
        model.output_layer = HardConcreteAdaptiveLogSoftmax.from_module(
            model.output_layer, init_mean=args.prune_init_mean)
        # tie weights again
        model.tie_weights()
        model.to(device)
        hc_modules = flop.get_hardconcrete_modules(
            model.rnn) + flop.get_hardconcrete_modules(model.embedding_layer)
        #print(len(flop.get_hardconcrete_modules(model)))
        #print(len(hc_modules))
        hc_parameters = [
            p for m in hc_modules for p in m.parameters() if p.requires_grad
        ]
        optimizer_hc = torch.optim.Adam(hc_parameters,
                                        lr=lr * args.prune_lr,
                                        weight_decay=0)

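        # Lagrangian multipliers for the sparsity constraint; optimizer_max is
        # given a negative learning rate below, so its steps perform gradient
        # ascent on the multipliers (max over lambdas, min over model weights)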
        lambda_1 = nn.Parameter(torch.tensor(0.).cuda())
        lambda_2 = nn.Parameter(torch.tensor(0.).cuda())
        optimizer_max = torch.optim.Adam([lambda_1, lambda_2],
                                         lr=lr,
                                         weight_decay=0)
        optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr
        hc_linear_modules = flop.get_hardconcrete_linear_modules(model) + \
                [model.embedding_layer]

        num_hardconcrete_params = sum(x.numel() for x in hc_parameters)
        num_prunable_params = sum(m.num_prunable_parameters()
                                  for m in hc_linear_modules)
        if local_rank == 0:
            print("num of hardconcrete paramters: {}".format(
                num_hardconcrete_params))
            print("num of prunable paramters: {}".format(num_prunable_params))
    else:
        model.to(device)
        args.prune_start_epoch = args.max_epoch

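    # main optimizer over all trainable weights except the hard-concrete gate
    # parameters (named 'log_alpha'), which are updated by optimizer_hc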
    m_parameters = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and 'log_alpha' not in i[0]
    ]
    optimizer = torch.optim.Adam(m_parameters,
                                 lr=lr * args.lr,
                                 weight_decay=args.weight_decay)
    num_params = sum(x.numel() for x in m_parameters if x.requires_grad)

    model_ = model
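    # model_ keeps a handle on the unwrapped module (used for init_hidden,
    # eval and checkpointing); dim=1 in the DDP wrapper below because the LM
    # iterator yields time-major batches of shape (unroll_size, batch)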
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        dim=1,
        device_ids=[local_rank],
        output_device=local_rank,
    )

    nbatch = 1
    niter = 1
    best_dev = 1e+8
    unroll_size = args.unroll_size
    batch_size = args.batch_size
    N = train.n_batch
    checkpoint = None
    if local_rank == 0:
        print(model)
        print("num of parameters: {}".format(num_params))
        print("num of mini-batches: {}".format(N))

    model.zero_grad()
    if args.prune:
        optimizer_max.zero_grad()
        optimizer_hc.zero_grad()

    for epoch in range(args.max_epoch):
        start_time = time.time()
        model.train()
        total_loss = 0.0
        hidden = model_.init_hidden(batch_size)
        start_prune = epoch >= args.prune_start_epoch
        i = 0

        for x, y, seq_len in train:
            i += 1
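            # truncated BPTT: detach the recurrent state so gradients do not
            # flow across batch boundaries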
            hidden.detach_()

            # language model forward and backward
            loss, hidden = model(x, y, hidden)
            loss = loss.mean()
            (loss / args.update_param_freq).backward()
            loss = loss.item()
            lagrangian_loss = 0
            target_sparsity = 0
            expected_sparsity = 0

            # add lagrangian loss (regularization) when pruning
            if start_prune:
                # compute target sparsity with (optionally) linear warmup
                target_sparsity = args.prune_sparsity
                if args.prune_warmup > 0:
                    niter_ = niter - args.prune_start_epoch * N
                    target_sparsity *= min(1.0, niter_ / args.prune_warmup)

                # compute expected model size and sparsity
                expected_size = sum(
                    m.num_parameters(train=True) for m in hc_linear_modules)
                expected_sparsity = 1.0 - expected_size / num_prunable_params

                # compute lagrangian loss
                lagrangian_loss = lambda_1 * (expected_sparsity - target_sparsity) + \
                                  lambda_2 * (expected_sparsity - target_sparsity)**2 * args.prune_beta
                (lagrangian_loss / args.update_param_freq).backward()
                expected_sparsity = expected_sparsity.item()
                lagrangian_loss = lagrangian_loss.item()

            #  log training stats
            if local_rank == 0 and (
                    niter -
                    1) % 100 == 0 and nbatch % args.update_param_freq == 0:
                if args.prune:
                    train_writer.add_scalar('sparsity/expected_sparsity',
                                            expected_sparsity, niter)
                    train_writer.add_scalar('sparsity/target_sparsity',
                                            target_sparsity, niter)
                    train_writer.add_scalar('loss/lagrangian_loss',
                                            lagrangian_loss, niter)
                    train_writer.add_scalar('lambda/1', lambda_1.item(), niter)
                    train_writer.add_scalar('lambda/2', lambda_2.item(), niter)
                    if (nbatch - 1) % 3000 == 0:
                        for index, layer in enumerate(hc_modules):
                            train_writer.add_histogram(
                                'log_alpha/{}'.format(index),
                                layer.log_alpha,
                                niter,
                                bins='sqrt',
                            )
                sys.stderr.write("\r{:.4f} {:.2f} {:.2f} eta={:.1f}m".format(
                    math.exp(loss),
                    lagrangian_loss,
                    expected_sparsity,
                    (time.time() - start_time) / 60.0 / (i + 1) * (N - i - 1),
                ))
                train_writer.add_scalar('loss/ppl', math.exp(loss), niter)
                train_writer.add_scalar('loss/lm_loss', loss, niter)
                train_writer.add_scalar('loss/total_loss',
                                        loss + lagrangian_loss, niter)
                train_writer.add_scalar(
                    'parameter_norm',
                    calc_norm([x.data for x in m_parameters]), niter)
                train_writer.add_scalar(
                    'gradient_norm',
                    calc_norm(
                        [x.grad for x in m_parameters if x.grad is not None]),
                    niter)

            #  perform gradient descent once every few backward() calls
            if nbatch % args.update_param_freq == 0:
                if args.clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(m_parameters, args.clip_grad)
                optimizer.step()
                if start_prune:
                    optimizer_max.step()
                    optimizer_hc.step()
                #  clear gradient
                model.zero_grad()
                if args.prune:
                    optimizer_max.zero_grad()
                    optimizer_hc.zero_grad()
                niter += 1

            if local_rank == 0 and (nbatch % args.log_period == 0 or i == N):
                elapsed_time = (time.time() - start_time) / 60.0
                dev_ppl, dev_loss = eval_model(model_, dev)
                dev_writer.add_scalar('loss/lm_loss', dev_loss, niter)
                dev_writer.add_scalar('loss/ppl', dev_ppl, niter)
                dev_writer.add_scalar('ppl', dev_ppl, niter)
                sparsity = 0
                if args.prune:
                    pruned_size = sum(
                        m.num_parameters(train=False)
                        for m in hc_linear_modules)
                    sparsity = 1.0 - pruned_size / num_prunable_params
                    dev_writer.add_scalar('sparsity/hard_sparsity', sparsity,
                                          niter)
                    dev_writer.add_scalar('model_size/total_prunable',
                                          num_prunable_params, niter)
                    dev_writer.add_scalar('model_size/current_prunable',
                                          pruned_size, niter)
                    dev_writer.add_scalar('model_size/total', num_params,
                                          niter)
                    dev_writer.add_scalar(
                        'model_size/current',
                        num_params - num_prunable_params + pruned_size, niter)
                    dev_writer.add_scalar(
                        'model_size/current_embedding',
                        model_.embedding_layer.num_parameters(train=False),
                        niter)
                    dev_writer.add_scalar(
                        'model_size/current_output_layer',
                        model_.output_layer.num_parameters(train=False), niter)
                sys.stdout.write(
                    "\rnum_batches={}  lr={:.5f}  train_loss={:.4f}  dev_loss={:.4f}"
                    "  dev_bpc={:.2f}  sparsity={:.2f}\t[{:.1f}m]\n".format(
                        nbatch, optimizer.param_groups[0]['lr'], loss,
                        dev_loss, dev_ppl, sparsity, elapsed_time))
                if dev_ppl < best_dev:
                    if (not args.prune
                        ) or sparsity > args.prune_sparsity - 0.005:
                        best_dev = dev_ppl
                        checkpoint = copy_model(model_)
                sys.stdout.write("\n")
                sys.stdout.flush()

            nbatch += 1
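            # Noam-style LR schedule: linear warmup then inverse square-root decay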
            if args.noam:
                lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5))
                optimizer.param_groups[0]['lr'] = lr * args.lr / (args.n_d**
                                                                  0.5)
            if args.noam and start_prune:
                niter_ = niter - args.prune_start_epoch * N
                lr = min(1.0 / (niter_**0.5),
                         niter_ / (args.warmup_steps**1.5))
                optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr / (
                    args.n_d**0.5)
                optimizer_hc.param_groups[0]['lr'] = lr * args.lr / (args.n_d**
                                                                     0.5)

        if local_rank == 0 and args.save and checkpoint is not None:
            torch.save(checkpoint, "{}.pt".format(args.save))

    if local_rank == 0:
        train_writer.close()
        dev_writer.close()

        if checkpoint is not None:
            model_.load_state_dict(checkpoint)
            model_.to(device)
        #dev = create_batches(dev_, 1)
        #test = create_batches(test_, 1)
        test = corpus.get_iterator('test',
                                   eval_batch_size,
                                   eval_unroll_size,
                                   device=device)
        dev_ppl, dev_loss = eval_model(model_, dev)
        test_ppl, test_loss = eval_model(model_, test)
        sys.stdout.write("dev_ppl={:.3f}  test_ppl={:.3f}\n".format(
            dev_ppl, test_ppl))
Example #2
if args.fp16:
    if not args.cuda:
        print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
        args.fp16 = False
    else:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except ImportError:
            print('WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

device = torch.device('cuda' if args.cuda else 'cpu')

###############################################################################
# Load data
###############################################################################
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
args.n_token = ntokens

eval_batch_size = 10
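# build train/valid/test iterators; tgt_len is the target (unroll) length and
# ext_len the extra context carried over from the previous segment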
tr_iter = corpus.get_iterator('train',
                              args.batch_size,
                              args.tgt_len,
                              device=device,
                              ext_len=args.ext_len)
va_iter = corpus.get_iterator('valid',
                              eval_batch_size,
                              args.eval_tgt_len,
                              device=device,
                              ext_len=args.ext_len)
te_iter = corpus.get_iterator('test',