def main(args):
    """Train a language model (with optional hard-concrete pruning) under
    single-node distributed data parallelism.

    Workflow:
      1. Rank 0 creates TensorBoard writers (train/dev).
      2. Initializes NCCL process group and builds wt103 train/valid iterators.
      3. Optionally wraps the model's linear ops, embedding and softmax layers
         with hard-concrete gates (FLOP pruning) plus a Lagrangian controller
         (lambda_1, lambda_2) that drives sparsity toward args.prune_sparsity.
      4. Runs the training loop with gradient accumulation
         (args.update_param_freq), periodic dev evaluation and checkpointing.
      5. Reloads the best checkpoint and reports dev/test perplexity.

    Args:
        args: parsed command-line namespace; must provide local_rank, log,
            data, batch_size, unroll_size, eval_* sizes, lr, noam/warmup_steps,
            prune* options, max_epoch, update_param_freq, clip_grad,
            log_period, save, load, n_d, weight_decay.

    Side effects: writes TensorBoard logs, prints to stdout/stderr, may save
    a checkpoint to "{args.save}.pt" (rank 0 only).
    """
    # Only rank 0 owns the TensorBoard writers; a random suffix avoids
    # clobbering logs from a previous run with the same --log prefix.
    if args.local_rank == 0:
        log_path = "{}_{}".format(args.log, random.randint(1, 100))
        train_writer = SummaryWriter(log_dir=log_path + "/train")
        dev_writer = SummaryWriter(log_dir=log_path + "/dev")

    # set up distributed training (one process per GPU, NCCL backend)
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl")
    set_seed(1234)
    args.n_gpu = 1
    args.device = device
    local_rank = args.local_rank

    # Load corpus; eval sizes fall back to training sizes when unset (0/None).
    corpus = get_lm_corpus(args.data, 'wt103')
    n_token = args.n_token = len(corpus.vocab)
    args.eval_batch_size = args.eval_batch_size or args.batch_size
    args.eval_unroll_size = args.eval_unroll_size or args.unroll_size
    unroll_size = args.unroll_size
    eval_unroll_size = args.eval_unroll_size
    batch_size = args.batch_size
    eval_batch_size = args.eval_batch_size
    # NOTE(review): n_nodes is taken from the visible GPU count — assumes one
    # process per local GPU on a single node; confirm for multi-node runs.
    n_nodes = torch.cuda.device_count()
    # Each rank gets its own shard of the training stream.
    train = corpus.get_distributed_iterator('train', batch_size,
                                            unroll_size, n_nodes=n_nodes,
                                            rank=local_rank, device=device)
    dev = corpus.get_iterator('valid', eval_batch_size, eval_unroll_size,
                              device=device)
    if local_rank == 0:
        print("vocab size: {}".format(n_token))

    model = Model(args)
    if args.load:
        model.load_state_dict(torch.load(args.load))
    # Base LR scale: 1.0, or the Noam transformer scale when --noam is set
    # (the per-step warmup factor is applied inside the training loop).
    lr = 1.0 if not args.noam else 1.0 / (args.n_d**0.5) / (args.warmup_steps**1.5)

    if args.prune:
        # In-place substitution of linear ops in SRU with hard-concrete
        # (stochastically gated) versions; embedding and output layers get
        # their own hard-concrete wrappers.
        flop.make_hard_concrete(model.rnn, in_place=True,
                                init_mean=args.prune_init_mean)
        model.embedding_layer = HardConcreteAdaptiveEmbedding.from_module(
            model.embedding_layer, init_mean=args.prune_init_mean)
        model.output_layer = HardConcreteAdaptiveLogSoftmax.from_module(
            model.output_layer, init_mean=args.prune_init_mean)
        # tie weights again (the wrappers replaced the original modules, so
        # embedding/output weight sharing must be re-established)
        model.tie_weights()
        model.to(device)
        # Gate modules whose log_alpha parameters are trained by a separate
        # optimizer. Output-layer gates are shared via weight tying, hence
        # only rnn + embedding are collected here.
        hc_modules = flop.get_hardconcrete_modules(
            model.rnn) + flop.get_hardconcrete_modules(model.embedding_layer)
        #print(len(flop.get_hardconcrete_modules(model)))
        #print(len(hc_modules))
        hc_parameters = [
            p for m in hc_modules for p in m.parameters() if p.requires_grad
        ]
        optimizer_hc = torch.optim.Adam(hc_parameters,
                                        lr=lr * args.prune_lr,
                                        weight_decay=0)
        # Lagrangian multipliers for the sparsity constraint. optimizer_max
        # runs gradient ASCENT on them (negative LR below) — a min-max game
        # against the model/gate parameters.
        lambda_1 = nn.Parameter(torch.tensor(0.).cuda())
        lambda_2 = nn.Parameter(torch.tensor(0.).cuda())
        optimizer_max = torch.optim.Adam([lambda_1, lambda_2],
                                         lr=lr,
                                         weight_decay=0)
        optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr
        # Modules whose parameter counts define the prunable budget.
        hc_linear_modules = flop.get_hardconcrete_linear_modules(model) + \
            [model.embedding_layer]
        num_hardconcrete_params = sum(x.numel() for x in hc_parameters)
        num_prunable_params = sum(m.num_prunable_parameters()
                                  for m in hc_linear_modules)
        if local_rank == 0:
            print("num of hardconcrete paramters: {}".format(
                num_hardconcrete_params))
            print("num of prunable paramters: {}".format(num_prunable_params))
    else:
        model.to(device)
        # No pruning: push prune_start_epoch past the last epoch so that
        # start_prune below is always False.
        args.prune_start_epoch = args.max_epoch

    # Main-model parameters exclude the hard-concrete gate params
    # ('log_alpha'), which are owned by optimizer_hc.
    m_parameters = [
        i[1] for i in model.named_parameters()
        if i[1].requires_grad and 'log_alpha' not in i[0]
    ]
    optimizer = torch.optim.Adam(m_parameters,
                                 lr=lr * args.lr,
                                 weight_decay=args.weight_decay)
    num_params = sum(x.numel() for x in m_parameters if x.requires_grad)

    # Keep a handle to the unwrapped model (model_) for init_hidden /
    # eval_model / checkpointing; `model` becomes the DDP wrapper.
    model_ = model
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        dim=1,  # batch dimension is 1 (sequence-first layout)
        device_ids=[local_rank],
        output_device=local_rank,
    )

    nbatch = 1        # count of backward() calls (mini-batches seen)
    niter = 1         # count of optimizer steps (after accumulation)
    best_dev = 1e+8
    unroll_size = args.unroll_size
    batch_size = args.batch_size
    N = train.n_batch
    checkpoint = None
    if local_rank == 0:
        print(model)
        print("num of parameters: {}".format(num_params))
        print("num of mini-batches: {}".format(N))

    model.zero_grad()
    if args.prune:
        optimizer_max.zero_grad()
        optimizer_hc.zero_grad()

    for epoch in range(args.max_epoch):
        start_time = time.time()
        model.train()
        total_loss = 0.0  # NOTE(review): never updated/read below — dead local
        hidden = model_.init_hidden(batch_size)
        # Pruning (and its Lagrangian loss) kicks in at prune_start_epoch.
        start_prune = epoch >= args.prune_start_epoch

        i = 0
        for x, y, seq_len in train:
            i += 1
            # Truncated BPTT: carry hidden state across batches but cut the
            # autograd graph in place.
            hidden.detach_()

            # language model forward and backward; loss is scaled so that
            # gradients accumulate correctly over update_param_freq batches
            loss, hidden = model(x, y, hidden)
            loss = loss.mean()
            (loss / args.update_param_freq).backward()
            loss = loss.item()
            lagrangian_loss = 0
            target_sparsity = 0
            expected_sparsity = 0
            # add lagrangian loss (regularization) when pruning
            if start_prune:
                # compute target sparsity with (optionally) linear warmup
                target_sparsity = args.prune_sparsity
                if args.prune_warmup > 0:
                    niter_ = niter - args.prune_start_epoch * N
                    target_sparsity *= min(1.0, niter_ / args.prune_warmup)

                # compute expected model size and sparsity under the current
                # (stochastic) gates
                expected_size = sum(
                    m.num_parameters(train=True) for m in hc_linear_modules)
                expected_sparsity = 1.0 - expected_size / num_prunable_params

                # compute lagrangian loss: linear + quadratic penalty on the
                # gap between expected and target sparsity
                lagrangian_loss = lambda_1 * (expected_sparsity - target_sparsity) + \
                    lambda_2 * (expected_sparsity - target_sparsity)**2 * args.prune_beta
                (lagrangian_loss / args.update_param_freq).backward()
                expected_sparsity = expected_sparsity.item()
                lagrangian_loss = lagrangian_loss.item()

            # log training stats (rank 0, every 100 optimizer steps, only on
            # the batch that completes an accumulation window)
            if local_rank == 0 and (
                    niter - 1) % 100 == 0 and nbatch % args.update_param_freq == 0:
                if args.prune:
                    train_writer.add_scalar('sparsity/expected_sparsity',
                                            expected_sparsity, niter)
                    train_writer.add_scalar('sparsity/target_sparsity',
                                            target_sparsity, niter)
                    train_writer.add_scalar('loss/lagrangian_loss',
                                            lagrangian_loss, niter)
                    train_writer.add_scalar('lambda/1', lambda_1.item(), niter)
                    train_writer.add_scalar('lambda/2', lambda_2.item(), niter)
                    # Gate-distribution histograms are expensive; log rarely.
                    if (nbatch - 1) % 3000 == 0:
                        for index, layer in enumerate(hc_modules):
                            train_writer.add_histogram(
                                'log_alpha/{}'.format(index),
                                layer.log_alpha,
                                niter,
                                bins='sqrt',
                            )
                # Progress line: train ppl, lagrangian loss, expected
                # sparsity, ETA in minutes for the rest of the epoch.
                sys.stderr.write("\r{:.4f} {:.2f} {:.2f} eta={:.1f}m".format(
                    math.exp(loss),
                    lagrangian_loss,
                    expected_sparsity,
                    (time.time() - start_time) / 60.0 / (i + 1) * (N - i - 1),
                ))
                train_writer.add_scalar('loss/ppl', math.exp(loss), niter)
                train_writer.add_scalar('loss/lm_loss', loss, niter)
                train_writer.add_scalar('loss/total_loss',
                                        loss + lagrangian_loss, niter)
                train_writer.add_scalar(
                    'parameter_norm',
                    calc_norm([x.data for x in m_parameters]),
                    niter)
                train_writer.add_scalar(
                    'gradient_norm',
                    calc_norm(
                        [x.grad for x in m_parameters if x.grad is not None]),
                    niter)

            # perform gradient descent every few number of backward()
            if nbatch % args.update_param_freq == 0:
                if args.clip_grad > 0:
                    # NOTE(review): clip_grad_norm is the deprecated alias of
                    # clip_grad_norm_ — consider updating.
                    torch.nn.utils.clip_grad_norm(m_parameters, args.clip_grad)
                optimizer.step()
                if start_prune:
                    optimizer_max.step()  # ascent on lambdas (negative lr)
                    optimizer_hc.step()
                # clear gradient
                model.zero_grad()
                if args.prune:
                    optimizer_max.zero_grad()
                    optimizer_hc.zero_grad()
                niter += 1

            # Periodic dev evaluation + checkpointing (rank 0 only).
            if local_rank == 0 and (nbatch % args.log_period == 0 or i == N):
                elapsed_time = (time.time() - start_time) / 60.0
                dev_ppl, dev_loss = eval_model(model_, dev)
                dev_writer.add_scalar('loss/lm_loss', dev_loss, niter)
                dev_writer.add_scalar('loss/ppl', dev_ppl, niter)
                dev_writer.add_scalar('ppl', dev_ppl, niter)
                sparsity = 0
                if args.prune:
                    # Hard (deterministic, train=False) sparsity, i.e. the
                    # size of the model after actually removing pruned units.
                    pruned_size = sum(
                        m.num_parameters(train=False)
                        for m in hc_linear_modules)
                    sparsity = 1.0 - pruned_size / num_prunable_params
                    dev_writer.add_scalar('sparsity/hard_sparsity', sparsity,
                                          niter)
                    dev_writer.add_scalar('model_size/total_prunable',
                                          num_prunable_params, niter)
                    dev_writer.add_scalar('model_size/current_prunable',
                                          pruned_size, niter)
                    dev_writer.add_scalar('model_size/total', num_params,
                                          niter)
                    dev_writer.add_scalar(
                        'model_size/current',
                        num_params - num_prunable_params + pruned_size,
                        niter)
                    dev_writer.add_scalar(
                        'model_size/current_embedding',
                        model_.embedding_layer.num_parameters(train=False),
                        niter)
                    dev_writer.add_scalar(
                        'model_size/current_output_layer',
                        model_.output_layer.num_parameters(train=False),
                        niter)
                sys.stdout.write(
                    "\rnum_batches={} lr={:.5f} train_loss={:.4f} dev_loss={:.4f}"
                    " dev_bpc={:.2f} sparsity={:.2f}\t[{:.1f}m]\n".format(
                        nbatch, optimizer.param_groups[0]['lr'], loss,
                        dev_loss, dev_ppl, sparsity, elapsed_time))
                # Only checkpoint a better dev ppl if the pruned model is
                # already (near) the target sparsity — otherwise a large,
                # barely-pruned model would always win.
                if dev_ppl < best_dev:
                    if (not args.prune
                        ) or sparsity > args.prune_sparsity - 0.005:
                        best_dev = dev_ppl
                        checkpoint = copy_model(model_)
                sys.stdout.write("\n")
                sys.stdout.flush()

            nbatch += 1
            # Noam (inverse-sqrt with warmup) LR schedule, keyed on optimizer
            # steps for the main optimizer ...
            if args.noam:
                lr = min(1.0 / (niter**0.5), niter / (args.warmup_steps**1.5))
                optimizer.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5)
            # ... and on steps-since-pruning-started for the pruning
            # optimizers.
            if args.noam and start_prune:
                niter_ = niter - args.prune_start_epoch * N
                lr = min(1.0 / (niter_**0.5),
                         niter_ / (args.warmup_steps**1.5))
                optimizer_max.param_groups[0]['lr'] = -lr * args.prune_lr / (
                    args.n_d**0.5)
                # NOTE(review): this scales optimizer_hc by args.lr although
                # it was created with lr * args.prune_lr — confirm intended.
                optimizer_hc.param_groups[0]['lr'] = lr * args.lr / (args.n_d**0.5)

        # Persist the best checkpoint seen so far at the end of each epoch.
        if local_rank == 0 and args.save and checkpoint is not None:
            torch.save(checkpoint, "{}.pt".format(args.save, ))

    if local_rank == 0:
        train_writer.close()
        dev_writer.close()

    # Final evaluation: reload the best checkpoint (a state dict, per the
    # load_state_dict call) and report dev/test perplexity.
    if checkpoint is not None:
        model_.load_state_dict(checkpoint)
        model_.to(device)
        #dev = create_batches(dev_, 1)
        #test = create_batches(test_, 1)
        test = corpus.get_iterator('test', eval_batch_size, eval_unroll_size,
                                   device=device)
        dev_ppl, dev_loss = eval_model(model_, dev)
        test_ppl, test_loss = eval_model(model_, test)
        sys.stdout.write("dev_ppl={:.3f} test_ppl={:.3f}\n".format(
            dev_ppl, test_ppl))
if not args.cuda: print('WARNING: --fp16 requires --cuda, ignoring --fp16 option') args.fp16 = False else: try: from apex.fp16_utils import FP16_Optimizer except: print('WARNING: apex not installed, ignoring --fp16 option') args.fp16 = False device = torch.device('cuda' if args.cuda else 'cpu') ############################################################################### # Load data ############################################################################### corpus = get_lm_corpus(args.data, args.dataset) ntokens = len(corpus.vocab) args.n_token = ntokens eval_batch_size = 10 tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len) va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, device=device, ext_len=args.ext_len) te_iter = corpus.get_iterator('test',