def go(arg):
    """Train a GTransformer character-level language model and report bits per byte.

    Configuration is read from ``arg`` (an argparse-style namespace). Each step
    samples ``arg.batch_size`` random subsequences of length ``arg.context``
    from the training split and trains the model to predict the next
    character. Every ``arg.test_every`` steps, and on the final step, the
    bits-per-byte on (a prefix of) the validation split is computed with
    ``calculate_bpb`` and logged to TensorBoard. The weights with the best
    score so far are saved to ``<arg.tb_dir>/best_model.pt``. After training,
    ``finalize`` is called with the test split.

    Side effects: sets ``arg.path`` when it is None, seeds torch's RNG, writes
    TensorBoard events and a checkpoint under ``arg.tb_dir``, prints progress.

    :param arg: parsed command-line options.
    """
    # A negative seed means "pick one at random". The chosen seed must still be
    # applied, otherwise the printed value cannot be used to reproduce the run.
    if arg.seed < 0:
        seed = random.randint(0, 1000000)
        print('random seed: ', seed)
    else:
        seed = arg.seed
    torch.manual_seed(seed)

    # load the data
    arg.path = here('data') if arg.path is None else arg.path
    data_train, data_val, data_test = read_dataset(arg.path, arg.dataset)

    # create the model
    model = GTransformer(emb=arg.embedding_size, heads=arg.num_heads, depth=arg.depth,
                         seq_length=arg.context, num_tokens=NUM_TOKENS, wide=arg.wide)
    if torch.cuda.is_available():
        model.cuda()

    print("Model parameters = %d" % sum(p.numel() for p in model.parameters()))

    sch = None  # only the Adam path uses a learning-rate scheduler
    if arg.radam:
        opt = RAdam(model.parameters(), lr=arg.lr)
    else:
        opt = torch.optim.Adam(lr=arg.lr, params=model.parameters())
        # Linear learning-rate warmup. The scheduler is stepped once per batch,
        # so the warmup spans arg.lr_warmup / arg.batch_size steps.
        sch = torch.optim.lr_scheduler.LambdaLR(
            opt, lambda i: min(i / (arg.lr_warmup / arg.batch_size), 1.0))

    if USE_APEX:
        model, opt = amp.initialize(model, opt, opt_level="O1", verbosity=0)

    best_bpb = np.inf
    best_step = 0

    tbw = SummaryWriter(log_dir=arg.tb_dir)  # Tensorboard logging
    try:
        # training loop
        # - note: we don't loop over the data; instead we sample a batch of
        #   random subsequences each time.
        for i in tqdm.trange(arg.num_batches):
            opt.zero_grad()

            # Sample a batch of random subsequences. The target is the same
            # window as the source, shifted one character ahead.
            starts = torch.randint(size=(arg.batch_size,), low=0,
                                   high=data_train.size(0) - arg.context - 1)
            source = torch.stack(
                [data_train[start:start + arg.context] for start in starts]).to(torch.long)
            target = torch.stack(
                [data_train[start + 1:start + arg.context + 1] for start in starts]).to(torch.long)
            if torch.cuda.is_available():
                source, target = source.cuda(), target.cuda()

            output = model(source)
            # assumes output is (batch, context, num_tokens) log-probabilities —
            # nll_loss wants the class dimension second, hence the transpose
            loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')

            if USE_APEX:
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # clip gradients: if the total gradient norm exceeds
            # arg.gradient_clipping, rescale it down to that value
            if arg.gradient_clipping > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), arg.gradient_clipping)

            opt.step()
            if sch is not None:
                sch.step()

            # Validate every arg.test_every steps and on the last step: first
            # compute the compression on the validation data (or a prefix of
            # it), then generate some random text to monitor progress.
            if i != 0 and (i % arg.test_every == 0 or i == arg.num_batches - 1):
                upto = arg.test_subset if arg.test_subset else data_val.size(0)
                data_sub = data_val[:upto]
                bits_per_byte = calculate_bpb(arg, model, data_sub)

                # print validation performance
                print(f'epoch{i}: {bits_per_byte:.4} bits per byte')
                tag_scalar_dict = {
                    'train-loss': float(loss.item()) * LOG2E,
                    'eval-loss': bits_per_byte,
                }
                tbw.add_scalars('transformer/loss', tag_scalar_dict, i * arg.batch_size)

                if bits_per_byte < best_bpb:
                    best_bpb = bits_per_byte
                    best_step = i
                    torch.save(model.state_dict(), os.path.join(arg.tb_dir, 'best_model.pt'))
                print(f'best step {best_step}: {best_bpb:.4} bits per byte')

                generate_sequence(arg, model, data_val)

        # load the best model, calculate bpb of the test data and generate some random text
        finalize(arg, model, data_test)
    finally:
        # flush pending TensorBoard events even if training raised
        tbw.close()
def _masked_sequences(data, starts, context, error_count, mask_id):
    """Build (source, target) windows for the masked-character task.

    For each start, the target is ``data[start:start + context]``. The source
    is a copy of that window in which ``error_count`` randomly drawn positions
    (never position 0) are overwritten with ``mask_id``.

    :returns: a pair of lists ``(sources, targets)`` of 1-D tensors.
    """
    sources, targets = [], []
    for start in starts:
        window = data[start:start + context]
        # Clone only this window, not the whole corpus, so that masking never
        # writes into the training data itself.
        masked = window.clone()
        masked[torch.randint(1, context, (error_count,))] = mask_id
        sources.append(masked)
        targets.append(window)
    return sources, targets


def _print_predictions(model, messages, context):
    """Print each prompt and the model's per-position argmax prediction for it.

    Each message is encoded with ``char_to_id`` and right-padded to ``context``
    ids. The padded prompt is printed in brackets, followed by a ``PRED:`` line
    decoded with ``id_to_char``.
    """
    pad_id = 110  # presumably the id of a filler character — TODO confirm against char_to_id
    for message in messages:
        ids = np.full(context, pad_id)
        ids[:len(message)] = [char_to_id[ch] for ch in message]
        prompt = torch.from_numpy(ids).to(torch.long)
        if torch.cuda.is_available():
            prompt = prompt.cuda()

        print('[' + ''.join(str(id_to_char[c.item()]) for c in prompt) + ']')
        output = model(prompt[None, :])
        predicted = output[0].max(dim=1).indices
        print("PRED: " + ''.join(id_to_char[ind.item()] for ind in predicted))
        print()


def go(arg):
    """Train (or resume training) a GTransformer on the Ukrainian wiki corpus.

    If a checkpoint exists at ``MODEL_PATH``, it is loaded first. Each step
    samples ``arg.batch_size`` random windows of length ``arg.context``:
      * with ``arg.masked``: the target is the window itself and the source is
        the window with ``arg.error_count`` characters replaced by ``'$'``;
      * otherwise: the target is the source shifted one character ahead.
    Every ``arg.test_every`` steps, and on the final step, the model's
    predictions for a few fixed prompts are printed and the weights are saved
    to ``MODEL_PATH``.

    Side effects: sets ``arg.data`` when it is None, seeds torch's RNG, writes
    TensorBoard events to ``arg.tb_dir``, writes ``MODEL_PATH``, prints progress.

    :param arg: parsed command-line options.
    """
    # A negative seed means "pick one at random". The chosen seed must still be
    # applied, otherwise the printed value cannot be used to reproduce the run.
    if arg.seed < 0:
        seed = random.randint(0, 1000000)
        print('random seed: ', seed)
    else:
        seed = arg.seed
    torch.manual_seed(seed)

    # Load the data. Normally data_test is the validation split. With
    # arg.final, train on train+validation and keep the real test split.
    # TODO(review): data_test is not evaluated on here (no bits-per-byte pass).
    arg.data = here('../wiki_uk.txt') if arg.data is None else arg.data
    data_train, data_val, data_test = ukwiki(arg.data)
    if arg.final:
        data_train = torch.cat([data_train, data_val], dim=0)
    else:
        data_test = data_val

    # create the model, resuming from an earlier checkpoint if one exists
    model = GTransformer(emb=arg.embedding_size, heads=arg.num_heads, depth=arg.depth,
                         seq_length=arg.context, num_tokens=NUM_TOKENS, wide=arg.wide)
    if os.path.exists(MODEL_PATH):
        # Map to CPU first so a GPU-saved checkpoint also loads on a CPU-only
        # machine. The model is moved to the GPU below when one is available.
        model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
    if torch.cuda.is_available():
        model.cuda()

    opt = torch.optim.Adam(lr=arg.lr, params=model.parameters())
    # Linear learning-rate warmup. The scheduler is stepped once per batch,
    # so the warmup spans arg.lr_warmup / arg.batch_size steps.
    sch = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda i: min(i / (arg.lr_warmup / arg.batch_size), 1.0))

    # fixed prompts ('$' marks a masked character) used to monitor progress
    test_msgs = [
        "купила м$ма коника, а коник і шо",
        "як тебе не лю$ити Києве мій коли",
        "у л$сі лісі темному де ходить як"
    ]

    tbw = SummaryWriter(log_dir=arg.tb_dir)  # Tensorboard logging
    try:
        # training loop
        # - note: we don't loop over the data; instead we sample a batch of
        #   random subsequences each time.
        for i in tqdm.trange(arg.num_batches):
            opt.zero_grad()

            starts = torch.randint(size=(arg.batch_size,), low=0,
                                   high=data_train.size(0) - arg.context - 1)
            if arg.masked:
                seqs_source, seqs_target = _masked_sequences(
                    data_train, starts, arg.context, arg.error_count, char_to_id['$'])
            else:
                seqs_source = [data_train[start:start + arg.context] for start in starts]
                seqs_target = [data_train[start + 1:start + arg.context + 1] for start in starts]

            source = torch.stack(seqs_source).to(torch.long)
            target = torch.stack(seqs_target).to(torch.long)
            if torch.cuda.is_available():
                source, target = source.cuda(), target.cuda()

            output = model(source)
            # assumes output is (batch, context, num_tokens) log-probabilities —
            # nll_loss wants the class dimension second, hence the transpose
            loss = F.nll_loss(output.transpose(2, 1), target, reduction='mean')
            tbw.add_scalar('transformer/train-loss', float(loss.item()) * LOG2E, i * arg.batch_size)
            loss.backward()

            # clip gradients: if the total gradient norm exceeds
            # arg.gradient_clipping, rescale it down to that value
            if arg.gradient_clipping > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), arg.gradient_clipping)

            opt.step()
            sch.step()

            # every arg.test_every steps (and on the last step) show the
            # model's predictions for the fixed prompts and checkpoint it
            if i != 0 and (i % arg.test_every == 0 or i == arg.num_batches - 1):
                model.eval()  # inference behaviour (e.g. no dropout) for the demo
                with torch.no_grad():
                    _print_predictions(model, test_msgs, arg.context)
                model.train()

                # Save model
                torch.save(model.state_dict(), MODEL_PATH)
    finally:
        # flush pending TensorBoard events even if training raised
        tbw.close()