def train_loop_fn(loader):
    tracker = xm.RateTracker()
    positions = torch.arange(SEQUENCE_LENGTH).long().view(
        1, SEQUENCE_LENGTH).to(device)
    # Upper-triangular mask that hides future positions (causal attention).
    causal_mask = torch.triu(
        torch.ones(
            SEQUENCE_LENGTH,
            SEQUENCE_LENGTH,
            dtype=torch.uint8,
            device=device),
        diagonal=1).unsqueeze(0)
    model.train()
    for iteration, batch in enumerate(loader):
        # Predict token t+1 from tokens <= t.
        input = batch[:, :-1].long()
        target = batch[:, 1:].long()
        optimizer.zero_grad()  # was missing: without it gradients accumulate across steps
        loss = model(input, positions, target, batch_mask=causal_mask)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(BATCH_SIZE)
        if iteration % LOG_STEPS == 0:
            # Divide by log(2) to report the loss in bits rather than nats.
            print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                device, iteration, loss.item() / math.log(2), tracker.rate()))
        if iteration % METRICS_STEP == 0:
            xm.master_print(met.metrics_report())
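# The loops in this collection assume a torch_xla host process has already
# created `model`, `optimizer` and a per-device loader in enclosing scope.
# A minimal sketch of that harness, assuming the xmp.spawn / MpDeviceLoader
# API; `MyModel`, `make_train_loader` and `NUM_EPOCHS` are placeholders,
# not names taken from the snippets themselves.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    model = MyModel().to(device)  # placeholder model class
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # MpDeviceLoader feeds batches to the XLA device in the background.
    device_loader = pl.MpDeviceLoader(make_train_loader(), device)
    for epoch in range(NUM_EPOCHS):
        train_loop_fn(device_loader)


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())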
def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    tracker = xm.RateTracker()
    for x, batch in enumerate(loader):
        # batch = tuple(t.to(device) for t in batch)
        logits = model(*batch[:-1])
        lengths = batch[1]  # assumes batch[1] holds per-example valid lengths
        gold_ids = batch[2]
        for index in range(lengths.shape[0]):
            length = int(lengths[index])
            # Take the argmax per position before comparing with the gold
            # labels; comparing raw logits to label ids would never match.
            pred = logits[index][:length].argmax(dim=-1)
            target = gold_ids[index, :length]
            total_samples += length
            correct += (pred == target).sum().item()
        if xm.get_ordinal() == 0 and x % FLAGS.log_steps == 0:
            print(
                '[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                .format(xm.get_ordinal(), x, correct * 1.0 / total_samples,
                        tracker.rate(), tracker.global_rate(), time.asctime()),
                flush=True)
    accuracy = 100.0 * correct / total_samples
    if xm.get_ordinal() == 0:
        print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy),
              flush=True)
    return accuracy, data, pred, target
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        if dynamic_graph:
            # testing purpose only: dynamic batch size and graph.
            index = max(-step, -flags.batch_size + 1)  # non-empty
            data, target = data[:-index, :, :, :], target[:-index]
        if step >= 15 and training_started:
            # testing purpose only: set event for synchronization.
            training_started.set()
        with xp.StepTrace('train_mnist', step_num=step):
            with xp.Trace('build_graph'):
                optimizer.zero_grad()
                output = model(data)
                loss = loss_fn(output, target)
                loss.backward()
            xm.optimizer_step(optimizer)
            if fetch_often:
                # testing purpose only: fetch XLA tensors to CPU.
                loss_i = loss.item()
            tracker.add(flags.batch_size)
            if step % flags.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, writer))
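# Several loops here pass `_train_update` to xm.add_step_closure without
# defining it. A plausible definition, mirroring the helper used in
# torch_xla's example scripts (some variants in this collection also take
# an `epoch` argument):
import torch_xla.test.test_utils as test_utils


def _train_update(device, step, loss, tracker, writer):
    # Runs as a step closure, after the step's tensors are materialized,
    # so loss.item() does not trigger an extra graph execution.
    test_utils.print_training_update(
        device,
        step,
        loss.item(),
        tracker.rate(),
        tracker.global_rate(),
        summary_writer=writer)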
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=1e-4))
    lr_scheduler = context.getattr_or(
        'lr_scheduler',
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=lr_scheduler_type,
            scheduler_divisor=lr_scheduler_divisor,
            scheduler_divide_every_n_epochs=lr_scheduler_divide_every_n_epochs,
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=None))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:  # loader yields (step, batch) pairs
        optimizer.zero_grad()
        data = data.permute(0, 3, 1, 2)  # NHWC -> NCHW
        output = model(data)
        print('passed through model')
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(batch_size)
        if x % log_steps == 0:
            print('device: {}, x: {}, loss: {}, tracker: {}, tracker_global: {}'
                  .format(device, x, loss.item(), tracker.rate(),
                          tracker.global_rate()))
        if lr_scheduler:
            lr_scheduler.step()
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    total_samples = 0
    correct = 0
    top5_accuracys = 0
    losses = 0
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum().item()
        losses += loss.item()
        total_samples += data.size()[0]
        top5_accuracys += topk_accuracy(output, target, topk=5).item()
        if lr_scheduler:
            lr_scheduler.step()
        if x % FLAGS.log_steps == 0:
            test_utils.print_training_update(device, x, loss.item(),
                                             tracker.rate(),
                                             tracker.global_rate())
    return (
        losses / (x + 1),
        (100.0 * correct / total_samples),
        (top5_accuracys / (x + 1)),
    )
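# `topk_accuracy` above is not defined in this collection. A minimal sketch
# of a helper consistent with the call site (`topk_accuracy(output, target,
# topk=5)` returning a scalar tensor); reporting a percentage is an
# assumption made to match the top-1 figure returned next to it:
def topk_accuracy(output, target, topk=5):
    with torch.no_grad():
        _, pred = output.topk(topk, dim=1)   # [batch, topk] class ids
        hits = pred.eq(target.unsqueeze(1))  # True where target is in top-k
        return hits.any(dim=1).float().mean() * 100.0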
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    lr_scheduler = context.getattr_or(
        'lr_scheduler',
        lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
            scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
            scheduler_divide_every_n_epochs=getattr(
                FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=writer if xm.is_master_ordinal() else None))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:  # loader yields (step, batch) pairs
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if x % FLAGS.log_steps == 0:
            test_utils.print_training_update(device, x, loss.item(),
                                             tracker.rate(),
                                             tracker.global_rate())
        if lr_scheduler:
            lr_scheduler.step()
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, batch in enumerate(loader):
        batch = tuple(t.to(device) for t in batch)
        # loss = self.criterion(logits, batch[-1])
        start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(
            *batch)
        loss1 = criterion(start_logits, batch[6]) + criterion(
            end_logits, batch[7])  # y1, y2
        loss2 = config.type_lambda * criterion(type_logits, batch[8])  # q_type
        # sent_num_in_batch = batch[9].sum()  # is_support
        # sent_num_in_batch = 1.0 + sent_num_in_batch  # to avoid divide by zero
        # loss3 = self.sp_loss_fct(sp_logits.view(-1), batch[10].float().view(-1)).sum() * self.config.sp_lambda / sent_num_in_batch
        loss = loss1 + loss2
        loss.backward()
        del batch  # try to save CPU memory.
        tracker.add(FLAGS.batch_size)
        if (x + 1) % config.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           config.max_grad_norm)
            # Gradient accumulation: run backward() several times before the
            # optimizer update so gradients add up in parameter.grad, then
            # apply the accumulated gradient in a single optimizer step.
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
        if xm.get_ordinal() == 0 and x % FLAGS.log_steps == 0:
            print(
                '[xla:{}]({}) Loss1={:.5f} Loss2={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                .format(xm.get_ordinal(), x, loss1.item(), loss2.item(),
                        tracker.rate(), tracker.global_rate(), time.asctime()),
                flush=True)
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, batch in enumerate(loader):
        # batch = tuple(t.to(self.device) for t in batch)
        output = model(*batch[:-1])  # the last one is the label
        loss = criterion(output, batch[-1])
        loss.backward()
        # xm.optimizer_step(optimizer)
        # optimizer.zero_grad()
        tracker.add(FLAGS.batch_size)
        if (x + 1) % config.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           config.max_grad_norm)
            # Gradient accumulation: run backward() several times before the
            # optimizer update so gradients add up in parameter.grad, then
            # apply the accumulated gradient in a single optimizer step.
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
        if xm.get_ordinal() == 0 and x % FLAGS.log_steps == 0:
            print(
                '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                        tracker.global_rate(), time.asctime()),
                flush=True)
def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    tracker = xm.RateTracker()
    for x, batch in enumerate(loader):
        output = model(*batch[:-1])  # the last one is the label
        target = batch[-1]
        # pred = output.max(1, keepdim=True)[1]
        # correct += pred.eq(target.view_as(pred)).sum().item()
        for i in range(len(output)):
            logits = output[i]
            pred = int(torch.argmax(logits, dim=-1))
            if pred == target[i]:
                correct += 1
        total_samples += len(output)
        if xm.get_ordinal() == 0 and x % FLAGS.log_steps == 0:
            print(
                '[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                .format(xm.get_ordinal(), x, correct * 1.0 / total_samples,
                        tracker.rate(), tracker.global_rate(), time.asctime()),
                flush=True)
    accuracy = 100.0 * correct / total_samples
    if xm.get_ordinal() == 0:
        print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy),
              flush=True)
    return accuracy, data, pred, target
def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        import resource
        print(f" CPU Usage After: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")
        if step % FLAGS.log_steps == 0:
            # _train_update(device, step, loss, tracker, epoch, writer)
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, epoch, writer))
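# The AMP loops in this collection bypass xm.optimizer_step() because the
# scaler has to unscale the gradients itself; the manual xm.all_reduce over
# xm._fetch_gradients(optimizer) reproduces the gradient reduction that
# xm.optimizer_step() would otherwise perform. A minimal sketch of the
# `autocast`/`scaler` objects they assume, using the torch_xla.amp API
# (which mirrors torch.cuda.amp; verify the import against your version):
from torch_xla.amp import GradScaler, autocast

scaler = GradScaler()
# Newer torch_xla releases expect the device argument, e.g.:
#     with autocast(xm.xla_device()):
#         ...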
def tpu_training_loop(model, loader, device, context):
    """Called by torch_xla_py.data_parallel. This function is executed on
    each core of the TPU once per epoch."""
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    # One optimizer and scheduler per TPU core. Both objects are saved in
    # `context` to be reused the next epoch.
    optimizer = context.getattr_or(
        'optimizer',
        AdamW(optimizer_grouped_parameters,
              lr=args.learning_rate,
              eps=args.adam_epsilon,
              betas=tuple(args.betas)))

    # Derive warmup info.
    if args.warmup_proportion is not None:
        warmup_steps = int(
            args.warmup_proportion * num_train_optimization_steps + 0.5)
    elif args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        raise Exception('What is the warmup?? Specify either warmup proportion or steps')
    scheduler = context.getattr_or(
        'scheduler',
        WarmupLinearSchedule(optimizer,
                             warmup_steps=warmup_steps,
                             t_total=num_train_optimization_steps))

    tr_loss = None
    pbar = None
    if str(pbar_device) == str(device):
        # All threads are in sync. Use the progress bar only on one of them.
        pbar = tqdm(total=int(pbar_steps),
                    desc=f"device {device}",
                    dynamic_ncols=True)

    tracker = tpu_xm.RateTracker()

    model.train()
    for step, batch in enumerate(loader):
        input_ids, input_mask, segment_ids, lm_label_ids, _ = batch
        outputs = model(input_ids, segment_ids, input_mask, lm_label_ids)
        loss = outputs[0]
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
        loss.sum().backward()  # for multiple tensors
        tracker.add(args.train_batch_size)
        tr_loss = (loss * args.gradient_accumulation_steps if step == 0
                   else tr_loss + loss * args.gradient_accumulation_steps)
        if pbar is not None:
            pbar.update(1)
            # pbar.set_description(desc=f'LR: {scheduler.get_lr()}')
        if (step + 1) % args.gradient_accumulation_steps == 0:
            tpu_xm.optimizer_step(optimizer)
            prev_lr = scheduler.get_last_lr()[0]
            scheduler.step()
            curr_lr = scheduler.get_last_lr()[0]
            if args.track_learning_rate and pbar is not None:
                pbar.set_description(f"Prev LR: {prev_lr} Curr LR: {curr_lr}")
            optimizer.zero_grad()
    # `.item()` requires a trip from TPU to CPU, which is very slow.
    # Use it only once per epoch.
    return tr_loss.sum().item() / step
def predict(model, dataloader, example_dict, feature_dict, prediction_file,
            test_loss_record, need_sp_logit_file=False):
    model.eval()
    answer_dict = {}
    sp_dict = {}
    dataloader.refresh()
    total_test_loss = [0] * 5
    tracker = xm.RateTracker()
    for x, batch in enumerate(dataloader):
        batch['context_mask'] = batch['context_mask'].float()
        start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(
            batch)
        loss_list = compute_loss(batch, start_logits, end_logits, type_logits,
                                 sp_logits, start_position, end_position)
        for i, l in enumerate(loss_list):
            if not isinstance(l, int):
                total_test_loss[i] += l.item()

        answer_dict_ = convert_to_tokens(
            example_dict, feature_dict, batch['ids'],
            start_position.data.cpu().numpy().tolist(),
            end_position.data.cpu().numpy().tolist(),
            np.argmax(type_logits.data.cpu().numpy(), 1))
        answer_dict.update(answer_dict_)

        predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy()
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = []
            cur_id = batch['ids'][i]
            cur_sp_logit_pred = []  # for sp logit output
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                if need_sp_logit_file:
                    temp_title, temp_id = example_dict[cur_id].sent_names[j]
                    cur_sp_logit_pred.append(
                        (temp_title, temp_id, predict_support_np[i, j]))
                if predict_support_np[i, j] > args.sp_threshold:
                    cur_sp_pred.append(example_dict[cur_id].sent_names[j])
            sp_dict.update({cur_id: cur_sp_pred})

        if xm.get_ordinal() == 0 and x % 100 == 0:
            # The original log line referenced `epoch`, `results` and `keys`,
            # which are not defined in this scope; log only the rates.
            print('[xla:{}]({}) Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                xm.get_ordinal(), x, tracker.rate(), tracker.global_rate(),
                time.asctime()), flush=True)

    new_answer_dict = {}
    for key, value in answer_dict.items():
        new_answer_dict[key] = value.replace(" ", "")
    prediction = {'answer': new_answer_dict, 'sp': sp_dict}
    with open(prediction_file, 'w', encoding='utf8') as f:
        json.dump(prediction, f, indent=4, ensure_ascii=False)

    for i, l in enumerate(total_test_loss):
        print("Test Loss{}: {}".format(i, l / len(dataloader)))
    test_loss_record.append(sum(total_test_loss[:3]) / len(dataloader))
def train_loop_fn(loader, epoch):
    if FLAGS.fine_grained_metrics:
        epoch_start_time = time.time()
        step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
    else:
        tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        if FLAGS.fine_grained_metrics:
            step_start_time = time.time()
        optimizer.zero_grad()
        if FLAGS.fine_grained_metrics:
            fwd_start_time = time.time()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        if FLAGS.fine_grained_metrics:
            fwd_end_time = time.time()
            fwd_latency = fwd_end_time - fwd_start_time
            bwd_start_time = time.time()
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        if lr_scheduler:
            lr_scheduler.step()
        if FLAGS.fine_grained_metrics:
            bwd_end_time = time.time()
            bwd_latency = bwd_end_time - bwd_start_time
            step_latency = bwd_end_time - step_start_time
            step_latency_tracker.append(step_latency)
            bwd_latency_tracker.append(bwd_latency)
            fwd_latency_tracker.append(fwd_latency)
        else:
            tracker.add(FLAGS.batch_size)
        if step % FLAGS.log_steps == 0:
            if FLAGS.fine_grained_metrics:
                print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
                    epoch, step, FLAGS.batch_size / p50(step_latency_tracker),
                    FLAGS.batch_size, p50(step_latency_tracker),
                    p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
            else:
                # _train_update(device, step, loss, tracker, epoch, writer)
                xm.add_step_closure(
                    _train_update,
                    args=(device, step, loss, tracker, epoch, writer))
    if FLAGS.fine_grained_metrics:
        epoch_end_time = time.time()
        epoch_latency = epoch_end_time - epoch_start_time
        # Note: the original format spec `{:.}` is invalid; `{:.2f}` is used.
        print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
            epoch, epoch_latency, FLAGS.batch_size / p50(step_latency_tracker),
            FLAGS.batch_size, p50(step_latency_tracker),
            p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
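# `p50` above is not defined in this collection; a minimal sketch of a
# matching helper (median over the recorded per-step latencies):
import numpy as np


def p50(latencies):
    return float(np.percentile(latencies, 50))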
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, writer))
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:  # loader yields (step, batch) pairs
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if x % FLAGS.log_steps == 0:
            test_utils.print_training_update(device, x, loss.item(),
                                             tracker.rate(),
                                             tracker.global_rate())
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data, x)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS['batch_size'])
        if x % FLAGS['log_steps'] == 0:
            print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                xm.get_ordinal(), x, loss.item(), tracker.rate(),
                tracker.global_rate(), time.asctime()), flush=True)
def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        if step % FLAGS.log_steps == 0:
            _train_update(device, step, loss, tracker, epoch, writer)
def train_loop(data_loader, writes=0):
    if torch.cuda.is_available():
        torch.cuda.synchronize(device=config.device)
    train_loss = 0.
    last_train_loss = 0.
    new_writes = 0
    time_ = time.time()
    if config.use_tpu:
        tracker = xm.RateTracker()
    # Normalizer for bits-per-dim logging; hoisted out of the loop so the
    # final return does not reference an undefined name when no logging
    # interval was hit.
    deno = config.print_every * config.batch_size * np.prod(
        input_shape) * np.log(2.)
    model.train()
    for batch_idx, (input, _) in enumerate(data_loader):
        input = input.to(config.device, non_blocking=True)
        if config.noising_factor is not None:
            false_input = input + config.noising_factor * config.noise_function(
                input.shape)
            false_input.clamp_(min=-1, max=1)
            output = model(false_input)
        else:
            print('noising factor is None')
            output = model(input)
        loss = loss_function(input, output)
        optimizer.zero_grad()
        loss.backward()
        if config.use_tpu:
            xm.optimizer_step(optimizer)
            tracker.add(config.batch_size)
        else:
            optimizer.step()
        train_loss += loss
        if config.print_every and (batch_idx + 1) % config.print_every == 0:
            print(f'this prints every {config.print_every} times')
            if not config.use_tpu:
                writer.add_scalar('train/bpd', (train_loss / deno),
                                  writes + new_writes)
            print('\t{:3d}/{:3d} - loss : {:.4f}, time : {:.3f}s'.format(
                batch_idx // config.print_every + 1,
                len(train_loader) // config.print_every,
                (train_loss / deno), (time.time() - time_)))
            last_train_loss = train_loss
            train_loss = 0.
            new_writes += 1
            time_ = time.time()
    del input, _, loss, output
    return new_writes, (last_train_loss / deno)
def train_loop_fn(model, loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()  # do not reduce gradients on sharded params
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(
                _train_update,
                args=(device, step, loss, tracker, writer),
                run_async=FLAGS.async_closures)
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        if x % FLAGS.log_steps == 0:
            xm.add_step_closure(_train_update, args=(device, x, loss, tracker))
def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()  # do not reduce gradients on sharded params
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        if step % FLAGS.log_steps == 0:
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, epoch, writer))
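# The two loops above that call a plain optimizer.step() with the comment
# "do not reduce gradients on sharded params" presuppose a wrapper that
# shards parameters and handles gradient communication itself, so an
# all-reducing xm.optimizer_step() would be wrong. A minimal sketch of one
# such setup, assuming torch_xla's FSDP wrapper; `MyModel` is a placeholder:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(MyModel().to(device))
optimizer = optim.SGD(model.parameters(), lr=0.01)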
def train_loop_fn(loader, net, optimizer, loss_fn, batch_size, log_steps):
    tracker = xm.RateTracker()
    net.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        out = net(data)
        loss = loss_fn(out, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(batch_size)
        if x % log_steps == 0:
            print(
                '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                        tracker.global_rate(), time.asctime()),
                flush=True)
def train_bert(model_name, amp_enabled, xla_enabled, dataset_path,
               num_examples=500):
    tokenizer, model = generate_tokenizer_and_model(model_name)
    train_texts, train_labels = read_imdb_split(
        os.path.join(dataset_path, 'train'))
    test_texts, test_labels = read_imdb_split(
        os.path.join(dataset_path, 'test'))
    train_texts, train_labels = train_texts[:num_examples], train_labels[:num_examples]
    test_texts, test_labels = test_texts[:num_examples], test_labels[:num_examples]
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        train_texts, train_labels, test_size=.2)

    train_encodings = tokenizer(train_texts, truncation=True, padding=True)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True)
    test_encodings = tokenizer(test_texts, truncation=True, padding=True)

    train_dataset = IMDbDataset(train_encodings, train_labels)
    val_dataset = IMDbDataset(val_encodings, val_labels)
    test_dataset = IMDbDataset(test_encodings, test_labels)

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda")
    model.to(device)
    model.train()

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    optim = AdamW(model.parameters(), lr=5e-5)

    if amp_enabled:
        autocast, scaler = get_autocast_and_scaler(xla_enabled)

    tracker = xm.RateTracker()
    for epoch in range(3):
        for step, batch in enumerate(train_loader):
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            if amp_enabled:
                loss, optim = loop_with_amp(model, input_ids, attention_mask,
                                            labels, optim, xla_enabled,
                                            autocast, scaler)
            else:
                loss, optim = loop_without_amp(model, input_ids,
                                               attention_mask, labels, optim,
                                               xla_enabled)
            tracker.add(input_ids.shape[0])
            _train_update(device, step, loss, tracker, epoch, None)
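# `loop_with_amp` / `loop_without_amp` used above are not defined in this
# collection. Minimal sketches consistent with the call sites, assuming a
# HuggingFace-style model whose forward returns an object with a `.loss`
# attribute:
def loop_without_amp(model, input_ids, attention_mask, labels, optim,
                     xla_enabled):
    loss = model(input_ids, attention_mask=attention_mask, labels=labels).loss
    loss.backward()
    if xla_enabled:
        xm.optimizer_step(optim)  # all-reduces gradients, then steps
    else:
        optim.step()
    return loss, optim


def loop_with_amp(model, input_ids, attention_mask, labels, optim,
                  xla_enabled, autocast, scaler):
    with autocast():
        loss = model(input_ids, attention_mask=attention_mask,
                     labels=labels).loss
    scaler.scale(loss).backward()
    if xla_enabled:
        # Reduce gradients manually; scaler.step() replaces xm.optimizer_step.
        gradients = xm._fetch_gradients(optim)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
    scaler.step(optim)
    scaler.update()
    if xla_enabled:
        xm.mark_step()
    return loss, optim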
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, writer))
def train_loop_fn(per_device_loader, e):
    tracker = xm.RateTracker()
    model.train()
    total_loss = 0.0
    total_steps = 0
    for step, batch in enumerate(per_device_loader):
        batch = tuple(t.to(device) for t in batch)
        input_ids, segment_ids, input_mask, next_sentence_labels, label_ids = batch
        optimizer.zero_grad()
        masked_lm_logits, auxiliary_logits = model(input_ids, segment_ids,
                                                   input_mask)
        # Is the operation of the view method different??
        masked_lm_loss = criterion_lm(
            masked_lm_logits.view(-1, label_ids.size(1)).transpose(0, 1),
            label_ids.view(-1))  # mlm
        if auxiliary_logits is None:
            loss = masked_lm_loss
        else:
            loss = masked_lm_loss + criterion_ns(
                auxiliary_logits.view(-1, 2),
                next_sentence_labels.view(-1))  # nsp
        loss.backward()
        total_steps += 1
        total_loss += loss.item()
        tracker.add(batch_size)
        if metrics_debug and step % cli_interval == 0:
            print(
                '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f}'
                .format(xm.get_ordinal(), step, loss.item(), tracker.rate(),
                        tracker.global_rate()),
                flush=True)
        if step % per_save_steps == 0:
            output_model_file = os.path.join(save_dir, "train_model.pt")
            if xm.get_ordinal() == 0:
                save(model, output_model_file, optimizer)
        xm.optimizer_step(optimizer)
        scheduler.step()
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, label) in enumerate(loader):
        optimizer.zero_grad()
        output = model(image=data, label=label,
                       get_embedding=args.get_embeddings)
        loss = loss_fn(output, label)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(args.batch_size)
        if x % 20 == 0:
            print(
                '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                        tracker.global_rate(), time.asctime()),
                flush=True)
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.Adam(model.parameters(), lr=lr))
    tracker = xm.RateTracker()
    model.train()
    print('# of iterations: {}'.format(maxItr))
    logger.info('# of iterations: {}'.format(maxItr))
    optimizer.zero_grad()
    for x, (data, target) in enumerate(loader):
        data = target[0].permute(0, 3, 1, 2)
        target = target[1]
        output = model(data)
        loss = loss_fn(output, target.long())
        # _, preds = torch.max(output, 1)
        loss.backward()
        # backprop every log_steps iterations
        if x % log_steps == 0:
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()
        tracker.add(batch_size)
        # compute the confusion matrix and IoU
        # print(preds.shape)
        # print(target.shape)
        # val_conf = np.zeros((num_classes, num_classes))
        # val_conf = val_conf + confusion_matrix(
        #     target[target >= 0].view(-1).cpu().numpy(),
        #     preds[target >= 0].view(-1).cpu().numpy())
        # pos = np.sum(val_conf, 1)
        # res = np.sum(val_conf, 0)
        # tp = np.diag(val_conf)
        # iou = np.mean(tp / np.maximum(1, pos + res - tp))
        # logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
        print('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'
              .format(device, x, loss.item(), tracker.rate(),
                      tracker.global_rate()))
        if x % log_steps == 0:
            logger.info(
                'device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'
                .format(device, x, loss.item(), tracker.rate(),
                        tracker.global_rate()))
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if x % FLAGS.log_steps == 0:
            test_utils.print_training_update(device, x, loss.item(),
                                             tracker.rate(),
                                             tracker.global_rate())
def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # optimizer.step()
        # xm.optimizer_step(optimizer)
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        if step % FLAGS.log_steps == 0:
            _train_update(device, step, loss, tracker, epoch, writer)
def train(embedder, model, optimizer, trainloader, writer, logger, epoch,
          pt_dir, device):
    try:
        tracker = xm.RateTracker()
        criterion = nn.MSELoss()
        model.train()
        step = 0
        for batch_idx, (dvec_mel, target_mag, mixed_mag) in enumerate(trainloader):
            target_mag, mixed_mag = target_mag.to(device), mixed_mag.to(device)
            dvec_list = list()
            for mel in dvec_mel:
                mel = mel.to(device)
                dvec = embedder(mel)
                dvec_list.append(dvec)
            dvec = torch.stack(dvec_list, dim=0)
            dvec = dvec.detach()

            # mask model
            optimizer.zero_grad()
            mask = model(mixed_mag, dvec)
            output = mixed_mag * mask

            # calculate loss; the paper says it uses a power-law compression,
            # but we don't do it here
            loss = criterion(output, target_mag)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(len(output))
            loss = loss.item()

            # log
            step += len(output)
            logger.info('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                xm.get_ordinal(), batch_idx, loss, tracker.rate(),
                tracker.global_rate(), time.asctime()))
            if step % config.train['ckpt_interval'] == 0:
                model_saver(model, optimizer, pt_dir, epoch)
                logger.info("Saved Checkpoint at Epoch%d,Step%d" % (epoch, step))
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()