Example #1
  def train_loop_fn(loader):
    tracker = xm.RateTracker()

    positions = torch.arange(SEQUENCE_LENGTH).long().view(
        1, SEQUENCE_LENGTH).to(device)
    causal_mask = torch.triu(
        torch.ones(
            SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.uint8, device=device),
        diagonal=1).unsqueeze(0)

    model.train()
    for iteration, batch in enumerate(loader):
      input = batch[:, :-1].long()
      target = batch[:, 1:].long()

      loss = model(input, positions, target, batch_mask=causal_mask)
      loss.backward()
      xm.optimizer_step(optimizer)

      tracker.add(BATCH_SIZE)
      if iteration % LOG_STEPS == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
            device, iteration,
            loss.item() / math.log(2), tracker.rate()))
      if iteration % METRICS_STEP == 0:
        xm.master_print(met.metrics_report())
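All of these loops assume a per-core setup (device, model, optimizer, data loader) that the snippets do not show. Below is a minimal, self-contained sketch of that surrounding wrapper using `xmp.spawn` and `pl.MpDeviceLoader`; the linear model, random data, and hyperparameters are stand-ins, not taken from any of the examples:

    import torch
    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index):
        device = xm.xla_device()
        model = nn.Linear(32, 4).to(device)              # stand-in model
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss_fn = nn.CrossEntropyLoss()

        dataset = torch.utils.data.TensorDataset(
            torch.randn(512, 32), torch.randint(0, 4, (512,)))
        loader = torch.utils.data.DataLoader(dataset, batch_size=16)
        device_loader = pl.MpDeviceLoader(loader, device)  # moves batches to the XLA device

        def train_loop_fn(loader):                       # same shape as the loops in these examples
            tracker = xm.RateTracker()
            model.train()
            for step, (data, target) in enumerate(loader):
                optimizer.zero_grad()
                loss = loss_fn(model(data), target)
                loss.backward()
                xm.optimizer_step(optimizer)             # all-reduces gradients, then optimizer.step()
                tracker.add(data.size(0))
                if step % 10 == 0:
                    print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                        device, step, loss.item(), tracker.rate()))

        train_loop_fn(device_loader)

    if __name__ == '__main__':
        xmp.spawn(_mp_fn, args=())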
Example #2
    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None  # defined here so the final return below is well-defined
        tracker = xm.RateTracker()

        for x, batch in enumerate(loader):
            #batch = tuple(t.to(device) for t in batch)
            logits = model(*batch[:-1])
            input_ids = batch[1]
            gold_ids = batch[2]
            for index in range(input_ids.shape[0]):
                length = input_ids[index]
                item = logits[index][:length]
                label = gold_ids[index,:length]
                total_samples += length
                for i in range(length):
                    if item[i] == label[i]:
                        correct += 1
            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print('[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                        xm.get_ordinal(), x, correct*1.0/total_samples, tracker.rate(),
                        tracker.global_rate(), time.asctime()), flush=True)

        accuracy = 100.0 * correct / total_samples

        if xm.get_ordinal() == 0:
            print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(), accuracy), flush=True)
        return accuracy, data, pred, target
Example #3
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if dynamic_graph:
                # testing purpose only: dynamic batch size and graph.
                index = max(-step, -flags.batch_size + 1)  # non-empty
                data, target = data[:-index, :, :, :], target[:-index]
            if step >= 15 and training_started:
                # testing purpose only: set event for synchronization.
                training_started.set()

            with xp.StepTrace('train_mnist', step_num=step):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)
                if fetch_often:
                    # testing purpose only: fetch XLA tensors to CPU.
                    loss_i = loss.item()
                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              writer))
Example #4
 def train_loop_fn(model, loader, device, context):
     loss_fn = nn.CrossEntropyLoss()
     optimizer = context.getattr_or(
         'optimizer', lambda: optim.SGD(model.parameters(),
                                        lr=lr,
                                        momentum=momentum,
                                        weight_decay=1e-4))
     lr_scheduler = context.getattr_or(
         'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
             optimizer,
             scheduler_type=lr_scheduler_type,
             scheduler_divisor=lr_scheduler_divisor,
             scheduler_divide_every_n_epochs=
             lr_scheduler_divide_every_n_epochs,
             num_steps_per_epoch=num_training_steps_per_epoch,
             summary_writer=None))
     tracker = xm.RateTracker()
     model.train()
     for x, (data, target) in loader:
         optimizer.zero_grad()
         data = data.permute(0, 3, 1, 2)
         output = model(data)
         print('passed through model')
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(batch_size)
         if x % log_steps == 0:
             print(
                 'device: {}, x: {}, loss: {}, tracker: {}, tracker_global: {} '
                 .format(device, x, loss.item(), tracker.rate(),
                         tracker.global_rate()))
         if lr_scheduler:
             lr_scheduler.step()
Example #5
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     total_samples = 0
     correct = 0
     top5_accuracys = 0
     losses = 0
     for x, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS.batch_size)
         pred = output.max(1, keepdim=True)[1]
         correct += pred.eq(target.view_as(pred)).sum().item()
         losses += loss.item()
         total_samples += data.size()[0]
         top5_accuracys += topk_accuracy(output, target, topk=5).item()
         if lr_scheduler:
             lr_scheduler.step()
         if x % FLAGS.log_steps == 0:
             test_utils.print_training_update(device, x, loss.item(),
                                              tracker.rate(),
                                              tracker.global_rate())
     return (
         losses / (x + 1),
         (100.0 * correct / total_samples),
         (top5_accuracys / (x + 1)),
     )
Example #6
 def train_loop_fn(model, loader, device, context):
     loss_fn = nn.CrossEntropyLoss()
     optimizer = context.getattr_or(
         'optimizer', lambda: optim.SGD(model.parameters(),
                                        lr=FLAGS.lr,
                                        momentum=FLAGS.momentum,
                                        weight_decay=5e-4))
     lr_scheduler = context.getattr_or(
         'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
             optimizer,
             scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
             scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
             scheduler_divide_every_n_epochs=getattr(
                 FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
             num_steps_per_epoch=num_training_steps_per_epoch,
             summary_writer=writer if xm.is_master_ordinal() else None))
     tracker = xm.RateTracker()
     model.train()
     for x, (data, target) in loader:
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS.batch_size)
         if x % FLAGS.log_steps == 0:
             test_utils.print_training_update(device, x, loss.item(),
                                              tracker.rate(),
                                              tracker.global_rate())
         if lr_scheduler:
             lr_scheduler.step()
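Examples #4, #6, #27 and #28 pull their optimizer and scheduler out of a `context` object whose `getattr_or` lazily builds and caches per-core state so it survives across epochs. The context comes from the older `torch_xla` data-parallel API and is not shown here; a tiny stand-in with the behavior these call sites rely on (inferred only from how it is used in these snippets) could look like:

    class TrainingContext:
        """Caches per-core objects so they are created once and reused every epoch."""

        def getattr_or(self, name, default):
            if not hasattr(self, name):
                # `default` may be a ready value (Example #11) or a zero-argument
                # factory (Examples #4 and #6); build it only on the first call.
                setattr(self, name, default() if callable(default) else default)
            return getattr(self, name)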
Example #7
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            batch = tuple(t.to(device) for t in batch)
            # loss = self.criterion(logits, batch[-1])
            start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(
                *batch)
            loss1 = criterion(start_logits, batch[6]) + criterion(
                end_logits, batch[7])  # y1, y2
            loss2 = config.type_lambda * criterion(type_logits,
                                                   batch[8])  # q_type
            # sent_num_in_batch = batch[9].sum()  # is_support
            # sent_num_in_batch = 1.0 + sent_num_in_batch # to avoid devide by zero
            # loss3 = self.sp_loss_fct(sp_logits.view(-1), batch[10].float().view(-1)).sum() * self.config.sp_lambda / sent_num_in_batch
            loss = loss1 + loss2
            loss.backward()
            del batch  #try to save cpu mem.
            tracker.add(FLAGS.batch_size)
            if (x + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm)
                # Gradient accumulation: run several backward passes before the
                # optimizer update (i.e. before optimizer.step()) so the gradients
                # accumulate in parameter.grad, then update with the accumulated gradients.
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print(
                        '[xla:{}]({}) Loss1={:.5f} Loss2={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                        .format(xm.get_ordinal(), x, loss1.item(),
                                loss2.item(), tracker.rate(),
                                tracker.global_rate(), time.asctime()),
                        flush=True)
Example #8
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            # batch = tuple(t.to(self.device) for t in batch)
            output = model(*batch[:-1])  # the last one is label
            loss = criterion(output, batch[-1])
            loss.backward()
            # xm.optimizer_step(optimizer)
            # optimizer.zero_grad()

            tracker.add(FLAGS.batch_size)
            if (x + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.max_grad_norm)
                # Gradient accumulation: run several backward passes before the
                # optimizer update (i.e. before optimizer.step()) so the gradients
                # accumulate in parameter.grad, then update with the accumulated gradients.
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print(
                        '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                        .format(xm.get_ordinal(), x, loss.item(),
                                tracker.rate(), tracker.global_rate(),
                                time.asctime()),
                        flush=True)
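Examples #7 and #8 call `loss.backward()` on every micro-batch but step the optimizer only every `gradient_accumulation_steps` iterations, as the translated comment explains. A stripped-down sketch of that pattern, scaling the loss by the accumulation factor as Example #11 does (the names `accum_steps` and `criterion` and the batch layout are illustrative, not taken verbatim from either example):

    accum_steps = 4  # illustrative value

    def train_loop_fn(loader):
        model.train()
        optimizer.zero_grad()
        for step, batch in enumerate(loader):
            output = model(*batch[:-1])                # last element of the batch is the label
            loss = criterion(output, batch[-1]) / accum_steps
            loss.backward()                            # gradients accumulate in parameter.grad
            if (step + 1) % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                xm.optimizer_step(optimizer)           # all-reduce + step once per accumulation window
                optimizer.zero_grad()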
Example #9
    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        data, pred, target = None, None, None
        tracker = xm.RateTracker()
        for x, batch in enumerate(loader):
            output = model(*batch[:-1])  # the last one is label
            target = batch[-1]
            # pred = output.max(1, keepdim=True)[1]
            # correct += pred.eq(target.view_as(pred)).sum().item()
            for i in range(len(output)):
                logits = output[i]
                pred = int(torch.argmax(logits, dim=-1))
                if pred == target[i]:
                    correct += 1
            total_samples += len(output)

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print(
                        '[xla:{}]({}) Acc={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                        .format(xm.get_ordinal(), x,
                                correct * 1.0 / total_samples, tracker.rate(),
                                tracker.global_rate(), time.asctime()),
                        flush=True)

        accuracy = 100.0 * correct / total_samples
        if xm.get_ordinal() == 0:
            print('[xla:{}] Accuracy={:.2f}%'.format(xm.get_ordinal(),
                                                     accuracy),
                  flush=True)
        return accuracy, data, pred, target
Example #10
    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)

            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()

            import resource
            print(f" CPU Usage After: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")

            if step % FLAGS.log_steps == 0:
                # _train_update(device, step, loss, tracker, epoch, writer)
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer)
                )
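Examples #10, #13 and #24 mix AMP with XLA. Because `scaler.step(optimizer)` invokes `optimizer.step()` directly, it bypasses the cross-core gradient all-reduce that `xm.optimizer_step` would normally perform, so those loops reduce the gradients by hand before stepping. A condensed sketch of just that inner step, assuming `autocast`, `scaler`, `model`, `optimizer` and `loss_fn` are set up as in those examples:

    def amp_step(data, target):
        optimizer.zero_grad()
        with autocast():                        # forward pass in mixed precision
            loss = loss_fn(model(data), target)
        scaler.scale(loss).backward()           # scale the loss to avoid gradient underflow
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())  # manual cross-core reduction
        scaler.step(optimizer)                  # unscales gradients, then optimizer.step()
        scaler.update()
        xm.mark_step()                          # cut and execute the XLA graph for this step
        return loss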
Example #11
    def tpu_training_loop(model, loader, device, context):
        """ Called by torch_xla_py.data_parallel. This function is executed on each core of the TPU once per epoch"""

        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        # one optimizer and scheduler per TPU core. Both objects are saved in `context` to be reused the next epoch
        optimizer = context.getattr_or(
            'optimizer',
            AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=tuple(args.betas)))

        # derive warmup info
        if args.warmup_proportion is not None:
            warmup_steps = int(args.warmup_proportion * num_train_optimization_steps + 0.5)
        elif args.warmup_steps is not None:
            warmup_steps = args.warmup_steps
        else:
            raise Exception('What is the warmup?? Specify either warmup proportion or steps')
        scheduler = context.getattr_or(
            'scheduler',
            WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps))

        tr_loss = None
        pbar = None
        if str(pbar_device) == str(device):  # All threads are in sync. Use progress bar only on one of them
            pbar = tqdm(total=int(pbar_steps), desc=f"device {device}", dynamic_ncols=True)

        tracker = tpu_xm.RateTracker()

        model.train()

        for step, batch in enumerate(loader):
            input_ids, input_mask, segment_ids, lm_label_ids, _ = batch
            outputs = model(input_ids, segment_ids, input_mask, lm_label_ids)
            loss = outputs[0]
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            
            loss.sum().backward() # for multiple tensors
            tracker.add(args.train_batch_size)

            tr_loss = loss * args.gradient_accumulation_steps if step == 0 else  tr_loss + loss * args.gradient_accumulation_steps
            if pbar is not None:
                pbar.update(1)
                # pbar.set_description(desc=f'LR: {scheduler.get_lr()}')
            if (step + 1) % args.gradient_accumulation_steps == 0:
                tpu_xm.optimizer_step(optimizer)
                prev_lr = scheduler.get_last_lr()[0]
                scheduler.step()
                curr_lr = scheduler.get_last_lr()[0]
                if args.track_learning_rate:
                    if pbar is not None:
                        pbar.set_description(f"Prev LR: {prev_lr} Curr LR: {curr_lr}")
                optimizer.zero_grad()
        return tr_loss.sum().item() / step  # `.item()` requires a trip from TPU to CPU, which is very slow. Use it only once per epoch.
Example #12
def predict(model, dataloader, example_dict, feature_dict, prediction_file, test_loss_record, need_sp_logit_file=False):

    model.eval()
    answer_dict = {}
    sp_dict = {}
    dataloader.refresh()
    total_test_loss = [0] * 5
    tracker = xm.RateTracker()

    for x, batch in enumerate(dataloader):

        batch['context_mask'] = batch['context_mask'].float()
        start_logits, end_logits, type_logits, sp_logits, start_position, end_position = model(batch)

        loss_list = compute_loss(batch, start_logits, end_logits, type_logits, sp_logits, start_position, end_position)

        for i, l in enumerate(loss_list):
            if not isinstance(l, int):
                total_test_loss[i] += l.item()


        answer_dict_ = convert_to_tokens(example_dict, feature_dict, batch['ids'], start_position.data.cpu().numpy().tolist(),
                                         end_position.data.cpu().numpy().tolist(), np.argmax(type_logits.data.cpu().numpy(), 1))
        answer_dict.update(answer_dict_)

        predict_support_np = torch.sigmoid(sp_logits).data.cpu().numpy()
        for i in range(predict_support_np.shape[0]):
            cur_sp_pred = []
            cur_id = batch['ids'][i]

            cur_sp_logit_pred = []  # for sp logit output
            for j in range(predict_support_np.shape[1]):
                if j >= len(example_dict[cur_id].sent_names):
                    break
                if need_sp_logit_file:
                    temp_title, temp_id = example_dict[cur_id].sent_names[j]
                    cur_sp_logit_pred.append((temp_title, temp_id, predict_support_np[i, j]))
                if predict_support_np[i, j] > args.sp_threshold:
                    cur_sp_pred.append(example_dict[cur_id].sent_names[j])
            sp_dict.update({cur_id: cur_sp_pred})

        if xm.get_ordinal() == 0:
            if x % 100 == 0:
                print('[xla:{}]({}) =={}== Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                    xm.get_ordinal(), x, ','.join([str(epoch)] + [str(results[s]) for s in keys]), tracker.rate(),
                    tracker.global_rate(), time.asctime()), flush=True)

    new_answer_dict = {}
    for key, value in answer_dict.items():
        new_answer_dict[key] = value.replace(" ", "")
    prediction = {'answer': new_answer_dict, 'sp': sp_dict}
    with open(prediction_file, 'w', encoding='utf8') as f:
        json.dump(prediction, f, indent=4, ensure_ascii=False)

    for i, l in enumerate(total_test_loss):
        print("Test Loss{}: {}".format(i, l / len(dataloader)))
    test_loss_record.append(sum(total_test_loss[:3]) / len(dataloader))
Example #13
    def train_loop_fn(loader, epoch):
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time

                bwd_start_time = time.time()
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time

                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                                                epoch, step, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(\
                                            epoch, epoch_latency, FLAGS.batch_size/p50(step_latency_tracker), FLAGS.batch_size, p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
Example #14
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(flags.batch_size)
         if step % flags.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, step, loss, tracker, writer))
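Several of these loops log via `xm.add_step_closure(_train_update, ...)`, which defers host-side reads such as `loss.item()` until the step's graph has actually executed instead of forcing an early device-to-host sync. The `_train_update` helper itself never appears in the snippets; a plausible minimal version (an assumption, not the original helper) would be:

    def _train_update(device, step, loss, tracker, writer=None):
        # Runs as a step closure, so `loss` has already been materialized on the host.
        print('[{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f}'.format(
            device, step, loss.item(), tracker.rate(), tracker.global_rate()),
            flush=True)
        if writer is not None:
            writer.add_scalar('train/loss', loss.item(), step)  # optional TensorBoard logging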
Example #15
  def train_loop_fn(loader):
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                         tracker.global_rate())
Example #16
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for x, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data, x)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS['batch_size'])
         if x % FLAGS['log_steps'] == 0:
             print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                 xm.get_ordinal(), x, loss.item(), tracker.rate(),
                 tracker.global_rate(), time.asctime()), flush=True)
Example #17
 def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS.batch_size)
         if lr_scheduler:
             lr_scheduler.step()
         if step % FLAGS.log_steps == 0:
             _train_update(device, step, loss, tracker, epoch, writer)
Example #18
    def train_loop(data_loader, writes=0):
        if torch.cuda.is_available():
            torch.cuda.synchronize(device=config.device)
        train_loss = 0.
        last_train_loss = 0.
        new_writes = 0
        time_ = time.time()
        if config.use_tpu:
            tracker = xm.RateTracker()
        model.train()
        for batch_idx, (input, _) in enumerate(data_loader):
            input = input.to(config.device, non_blocking=True)
            if config.noising_factor is not None:
                false_input = input + config.noising_factor * config.noise_function(
                    input.shape)
                false_input.clamp_(min=-1, max=1)
                output = model(false_input)
            else:
                print('noising factor is None')
                output = model(input)
            loss = loss_function(input, output)
            optimizer.zero_grad()
            loss.backward()
            if config.use_tpu:
                xm.optimizer_step(optimizer)
                tracker.add(config.batch_size)
            else:
                optimizer.step()
            train_loss += loss
            if config.print_every and (batch_idx +
                                       1) % config.print_every == 0:
                print(f'this prints every {config.print_every} times')
                deno = config.print_every * config.batch_size * np.prod(
                    input_shape) * np.log(2.)
                if not config.use_tpu:
                    writer.add_scalar('train/bpd', (train_loss / deno),
                                      writes + new_writes)

                print('\t{:3d}/{:3d} - loss : {:.4f}, time : {:.3f}s'.format(
                    batch_idx // config.print_every + 1,
                    len(train_loader) // config.print_every,
                    (train_loss / deno), (time.time() - time_)))
                last_train_loss = train_loss
                train_loss = 0.
                new_writes += 1
                time_ = time.time()
            del input, _, loss, output

        return new_writes, (last_train_loss / deno)
Example #19
 def train_loop_fn(model, loader):
   tracker = xm.RateTracker()
   model.train()
   for step, (data, target) in enumerate(loader):
     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()
     optimizer.step()  # do not reduce gradients on sharded params
     tracker.add(flags.batch_size)
     if step % flags.log_steps == 0:
       xm.add_step_closure(
           _train_update,
           args=(device, step, loss, tracker, writer),
           run_async=FLAGS.async_closures)
Example #20
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for x, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(FLAGS.batch_size)
         if lr_scheduler:
             lr_scheduler.step()
         if x % FLAGS.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, x, loss, tracker))
Example #21
 def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         optimizer.step()  # do not reduce gradients on sharded params
         tracker.add(FLAGS.batch_size)
         if lr_scheduler:
             lr_scheduler.step()
         if step % FLAGS.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, step, loss, tracker, epoch,
                                       writer))
Example #22
def train_loop_fn(loader, net, optimizer, loss_fn, batch_size, log_steps):
    tracker = xm.RateTracker()
    net.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        out = net(data)
        loss = loss_fn(out, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        tracker.add(batch_size)
        if x % log_steps == 0:
            print(
                '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                        tracker.global_rate(), time.asctime()),
                flush=True)
Example #23
def train_bert(model_name, amp_enabled, xla_enabled, dataset_path, num_examples=500):
    tokenizer, model = generate_tokenizer_and_model(model_name)

    train_texts, train_labels = read_imdb_split(os.path.join(dataset_path, 'train'))
    test_texts, test_labels = read_imdb_split(os.path.join(dataset_path,'test'))

    train_texts, train_labels = train_texts[:num_examples], train_labels[:num_examples]
    test_texts, test_labels = test_texts[:num_examples], test_labels[:num_examples]

    train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)


    train_encodings = tokenizer(train_texts, truncation=True, padding=True)
    val_encodings = tokenizer(val_texts, truncation=True, padding=True)
    test_encodings = tokenizer(test_texts, truncation=True, padding=True)

    train_dataset = IMDbDataset(train_encodings, train_labels)
    val_dataset = IMDbDataset(val_encodings, val_labels)
    test_dataset = IMDbDataset(test_encodings, test_labels)

    if xla_enabled:
        device = xm.xla_device()
    else:
        device = torch.device("cuda")
    model.to(device)
    model.train()

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    optim = AdamW(model.parameters(), lr=5e-5)

    if amp_enabled: 
        autocast, scaler = get_autocast_and_scaler(xla_enabled)
    
    tracker = xm.RateTracker()
    for epoch in range(3):
        for step, batch in enumerate(train_loader):
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            if amp_enabled:
                loss, optim = loop_with_amp(model, input_ids, attention_mask, labels, optim, xla_enabled, autocast, scaler)
            else:
                loss, optim = loop_without_amp(model, input_ids, attention_mask, labels, optim, xla_enabled)
            tracker.add(input_ids.shape[0])
            _train_update(device, step, loss, tracker, epoch, None)
Example #24
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         with autocast():
             output = model(data)
             loss = loss_fn(output, target)
         scaler.scale(loss).backward()
         gradients = xm._fetch_gradients(optimizer)
         xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
         scaler.step(optimizer)
         scaler.update()
         tracker.add(flags.batch_size)
         if step % flags.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, step, loss, tracker, writer))
Example #25
        def train_loop_fn(per_device_loader, e):
            tracker = xm.RateTracker()
            model.train()
            total_loss = 0.0
            total_steps = 0
            for step, batch in enumerate(per_device_loader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, segment_ids, input_mask, next_sentence_labels, label_ids = batch
                optimizer.zero_grad()
                masked_lm_logists, auxiliary_logits = model(
                    input_ids, segment_ids, input_mask)

                # Is the operation of the view method different??
                masked_lm_loss = criterion_lm(
                    masked_lm_logists.view(-1,
                                           label_ids.size(1)).transpose(0, 1),
                    label_ids.view(-1))
                # mlm
                if auxiliary_logits is None:
                    loss = masked_lm_loss
                # nsp
                else:
                    loss = masked_lm_loss + criterion_ns(
                        auxiliary_logits.view(-1, 2),
                        next_sentence_labels.view(-1))
                loss.backward()

                total_steps += 1
                total_loss += loss.item()
                tracker.add(batch_size)
                if metrics_debug:
                    if step % cli_interval == 0:
                        print(
                            '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f}'
                            .format(xm.get_ordinal(), step, loss.item(),
                                    tracker.rate(), tracker.global_rate()),
                            flush=True)
                if step % per_save_steps == 0:
                    output_model_file = os.path.join(save_dir,
                                                     "train_model.pt")
                    if xm.get_ordinal() == 0:
                        save(model, output_model_file, optimizer)

                xm.optimizer_step(optimizer)
                scheduler.step()
Example #26
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for x, (data, label) in enumerate(loader):
         optimizer.zero_grad()
         output = model(image=data,
                        label=label,
                        get_embedding=args.get_embeddings)
         loss = loss_fn(output, label)
         loss.backward()
         xm.optimizer_step(optimizer)
         tracker.add(args.batch_size)
         if x % 20 == 0:
             print(
                 '[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'
                 .format(xm.get_ordinal(), x, loss.item(), tracker.rate(),
                         tracker.global_rate(), time.asctime()),
                 flush=True)
Example #27
  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.Adam(model.parameters(), lr=lr))
    tracker = xm.RateTracker()

    model.train()
    print('# of iterations: {}'.format(maxItr))
    logger.info('# of iterations: {}'.format(maxItr))
    optimizer.zero_grad()
    for x, (data, target) in enumerate(loader):
      data = target[0].permute(0,3,1,2)
      target = target[1]
      output = model(data)
      loss = loss_fn(output, target.long())
      #_, preds = torch.max(output, 1)
      loss.backward()
      
      # step the optimizer every log_steps iterations (gradients accumulate in between)
      if x % log_steps == 0:
        xm.optimizer_step(optimizer)
        optimizer.zero_grad()

      tracker.add(batch_size)

      # compute the confusion matrix and IoU
      #print(preds.shape)
      #print(target.shape)
      
      #val_conf = np.zeros((num_classes, num_classes))
      #val_conf = val_conf + confusion_matrix(
      #    target[target >= 0].view(-1).cpu().numpy(),
      #    preds[target >= 0].view(-1).cpu().numpy())
      #pos = np.sum(val_conf, 1)
      #res = np.sum(val_conf, 0)
      #tp = np.diag(val_conf)
      #iou = np.mean(tp / np.maximum(1, pos + res - tp))

      #logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
      print('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
      
      if x % log_steps == 0:
        logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
Example #28
    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(
                model.parameters(), lr=lr, momentum=FLAGS.momentum))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
Example #29
    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # optimizer.step()
            # xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                _train_update(device, step, loss, tracker, epoch, writer)
Example #30
def train(embedder, model, optimizer, trainloader, writer, logger, epoch, pt_dir,device):
    try:
        tracker = xm.RateTracker()
        criterion = nn.MSELoss()
        model.train()
        step = 0
        for batch_idx, (dvec_mel, target_mag, mixed_mag) in enumerate(trainloader):
            target_mag, mixed_mag = target_mag.to(device), mixed_mag.to(device)

            dvec_list = list()
            for mel in dvec_mel:
                mel = mel.to(device)
                dvec = embedder(mel)
                dvec_list.append(dvec)
            dvec = torch.stack(dvec_list, dim=0)
            dvec = dvec.detach()
            #mask model
            optimizer.zero_grad()
            mask = model(mixed_mag, dvec)
            output = mixed_mag * mask
            # calculate loss; the paper says it uses power-law compression, but we don't apply it here
            loss = criterion(output, target_mag)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(len(output))
            loss = loss.item()
            #log
            step += len(output)
            logger.info('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), batch_idx, loss, tracker.rate(),
            tracker.global_rate(), time.asctime()))
            if step % config.train['ckpt_interval'] == 0:
                model_saver(model,optimizer,pt_dir,epoch)
                logger.info("Saved Checkpoint at Epoch%d,Step%d" % (epoch, step))
            
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()