Example #1
    def optimizer_step(self,
                       epoch,
                       batch_idx,
                       optimizer,
                       optimizer_idx,
                       second_order_closure=None,
                       using_native_amp=None):
        if optimizer_idx == 0:
            if self.trainer.use_tpu and XLA_AVAILABLE:
                xm.optimizer_step(optimizer)
            elif isinstance(optimizer, torch.optim.LBFGS):
                optimizer.step(second_order_closure)
            else:
                optimizer.step()

            # clear gradients
            optimizer.zero_grad()

        elif optimizer_idx == 1:
            pass
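A note on the XLA_AVAILABLE flag tested above: it is usually set once at import time with a guarded import. A minimal sketch of that pattern (Lightning's actual check differs in detail):

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    # torch_xla is not installed; the hook falls back to plain optimizer.step().
    XLA_AVAILABLE = False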
Example #2
    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            "optimizer",
            lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum),
        )
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(
                    device, x, loss.item(), tracker.rate(), tracker.global_rate()
                )
Example #3
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([3,3,3,3,3,5]).to(device))
    
    log_loss = nn.BCEWithLogitsLoss(weight=torch.FloatTensor([1,1,1,1,1,2]).to(device), reduction='none')
    def metric_fn(outputs, target):
        return (log_loss(outputs, target).sum(-1) / log_loss.weight.sum()).mean()
    
    if args.metric_loss: loss_fn = metric_fn
    optimizer = context.getattr_or(
      'optimizer',
      lambda: torch.optim.AdamW(model.parameters(), lr=args.lr,
                                betas=(0.9, 0.999), weight_decay=args.weight_decay) 
    )

    lr_scheduler = context.getattr_or(
        'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type='WarmupAndExponentialDecayScheduler',
            scheduler_divisor=args.slr_divisor,
            scheduler_divide_every_n_epochs=args.slr_div_epochs,
            num_warmup_epochs=args.n_warmup,
            min_delta_to_update_lr=args.min_lr,
            num_steps_per_epoch=num_steps_per_epoch))
    
    score = MovingAverage(maxlen=500)
    metric = MovingAverage(maxlen=500)
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        if args.model_name == 'inception_v3': output = output.logits
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        score(loss.item())
        metric(metric_fn(output, target).item())
        if x % args.log_steps == 0:
            logging.info('[{}]({:5d}) Moving average loss: {:.5f}, metric: {:.5f}'
                             .format(device, x, score.mean(), metric.mean()))
        if lr_scheduler:
            lr_scheduler.step()
Example #4
    def train_one_epoch(self, train_loader, e, save_flag):
        self.model.train()

        losses = AverageMeter()
        final_scores = RocAucMeter()
        t = time.time()
        for step, (targets, inputs, attention_masks,
                   ids) in enumerate(train_loader):
            if self.config.verbose:
                if step % self.config.verbose_step == 0:
                    self.log(
                        f'Train Step {step}, loss: ' + \
                        f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                        f'time: {(time.time() - t):.5f}'
                    )

            inputs = inputs.to(self.device, dtype=torch.long)
            attention_masks = attention_masks.to(self.device, dtype=torch.long)
            targets = targets.to(self.device, dtype=torch.float)

            self.optimizer.zero_grad()

            outputs = self.model(inputs, attention_masks)
            loss = self.criterion(outputs, targets)

            batch_size = inputs.size(0)

            final_scores.update(targets, outputs)

            losses.update(loss.detach().item(), batch_size)

            loss.backward()
            xm.optimizer_step(self.optimizer)

            if self.config.step_scheduler:
                self.scheduler.step()

        self.model.eval()
        if save_flag == 1:
            self.save(f'{FILE_NAME}_epoch_{e}.bin')
        return losses, final_scores
Example #5
 def train_one_epoch(loader):
     model.train()
     running_loss = 0
     max_idx = 0
     xm.master_print("-" * 40)
     xm.master_print("Step\t|\tTime")
     xm.master_print("-" * 40)
     for idx, (images, targets) in enumerate(loader):
         optimizer.zero_grad()
         y_pred = model(images.float())
         loss = criterion(y_pred, targets)
         running_loss += float(loss)
         loss.backward()
         xm.optimizer_step(optimizer)
         # xm.mark_step() call every step for grad accum
         max_idx = float(idx)
         if idx % FLAGS["log_steps"] == 0 and idx != 0:
             xm.master_print("({})\t|\t{}".format(
                 idx, time.asctime(time.localtime())))
     xm.master_print("-" * 40)
     return running_loss / (max_idx + 1)
Example #6
    def step(self, curr):
        #selects the loss
        if self.schedule_coeff[self.i][0] < curr:
            self.i = self.i + 1
        self.which = self.l - self.i
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        t = time()
        try:
            data = next(self.data)
        except StopIteration:
            self.Data_Generator.reset_generator()
            self.data = self.Data_Generator.next_batch()
            data = next(self.data)
        self.IO_time += time() - t
        t = time()
        loss = self.Network.score(imgL=data[0],
                                  imgR=data[1],
                                  which=self.which,
                                  lp=self.i + 1,
                                  train=True)
        self.forward_time += time() - t
        t = time()
        self.sm += loss[0].detach()
        self.re += loss[2].detach()
        self.ds += loss[1].detach()
        self.em += loss[3].detach()
        l = 0
        for i in loss:
            l += i
        l = l.mul(self.schedule_coeff[self.i][1])
        l.backward()
        if self.tpu:
            for optimizer in self.optimizers:
                xm.optimizer_step(optimizer)  #, barrier=True)
        else:
            for optimizer in self.optimizers:
                optimizer.step()

        self.backward_time += time() - t
Example #7
def tpu_train_fn(data_loader,
                 model,
                 optimizer,
                 device,
                 num_batches,
                 scheduler=None,
                 loss_fn=None):
    model.train()
    tk0 = tqdm(data_loader, total=len(data_loader), desc="Training")
    for bi, d in enumerate(tk0):
        ids = d["ids"]
        token_type_ids = d["token_type_ids"]
        mask = d["mask"]
        targets_start = d["targets_start"]
        targets_end = d["targets_end"]
        sentiment = d["sentiment"]
        orig_selected = d["orig_selected"]
        orig_tweet = d["orig_tweet"]
        offsets = d["offsets"]

        ids = ids.to(device, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets_start = targets_start.to(device, dtype=torch.long)
        targets_end = targets_end.to(device, dtype=torch.long)

        model.zero_grad()
        outputs_start, outputs_end = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids,
        )

        loss = loss_fn(outputs_start, outputs_end, targets_start, targets_end)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        scheduler.step()
        tk0.set_postfix(loss=loss.item())
Example #8
def train_loop_fn(train_loader,
                  args,
                  model,
                  criterion,
                  optimizer,
                  device,
                  scheduler=None):
    model.train()
    criterion.train()
    for i, sample in enumerate(train_loader):
        sample = _prepare_sample(sample, device)
        print(sample["target"].shape, sample["target"].device)
        optimizer.zero_grad()
        _, _, logging_output = criterion(model, sample)
        logging = criterion.aggregate_logging_outputs([logging_output])
        loss = logging["loss"]
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        if i % args.log_steps == 0:
            xm.master_print('bi={}, loss={:.4f}'.format(i, loss.item()))
            xm.master_print('MEM: {}'.format(psutil.virtual_memory()))
    print('End training: {}'.format(device))
Example #9
def train(embedder, model, optimizer, trainloader, writer, logger, epoch, pt_dir,device):
    try:
        tracker = xm.RateTracker()
        criterion = nn.MSELoss()
        model.train()
        step = 0
        for batch_idx, (dvec_mel, target_mag, mixed_mag) in enumerate(trainloader):
            target_mag, mixed_mag = target_mag.to(device), mixed_mag.to(device)

            dvec_list = list()
            for mel in dvec_mel:
                mel = mel.to(device)
                dvec = embedder(mel)
                dvec_list.append(dvec)
            dvec = torch.stack(dvec_list, dim=0)
            dvec = dvec.detach()
            #mask model
            optimizer.zero_grad()
            mask = model(mixed_mag, dvec)
            output = mixed_mag * mask
            #calculate loss, the paper says it use powerlaw, but we don't do it here
            loss = criterion(output, target_mag)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(len(output))
            loss = loss.item()
            #log
            step += len(output)
            logger.info('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                xm.get_ordinal(), batch_idx, loss, tracker.rate(),
                tracker.global_rate(), time.asctime()))
            if step % config.train['ckpt_interval'] == 0 :
                model_saver(model,optimizer,pt_dir,epoch)
                logger.info("Saved Checkpoint at Epoch%d,Step%d" % (epoch, step))
            
    except Exception as e:
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
Example #10
    def train_one_epoch(self, train_loader):

        self.model.train()
        summary_loss = AverageMeter()
        final_scores = RocAucMeter()
        t = time.time()

        for step, (images, targets) in enumerate(train_loader):

            t0 = time.time()
            batch_size = images.shape[0]
            outputs = self.model(images)

            self.optimizer.zero_grad()
            loss = self.criterion(outputs, targets)
            loss.backward()                         # compute and sum gradients on params
            #torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=global_config.CLIP_GRAD_NORM) 
            
            xm.optimizer_step(self.optimizer)
            if self.config.step_scheduler:
                self.scheduler.step()

            try: 
                final_scores.update(targets, outputs)
            except:
                # xm.master_print("outputs: ", list(outputs.data.cpu().numpy())[:10])
                pass
            summary_loss.update(loss.detach().item(), batch_size)

            if self.config.verbose:
                if step % self.config.verbose_step == 0:

                    t1 = time.time()
                    effNet_lr = np.format_float_scientific(self.optimizer.param_groups[0]['lr'], unique=False, precision=1)
                    head_lr   = np.format_float_scientific(self.optimizer.param_groups[1]['lr'], unique=False, precision=1)
                    xm.master_print(f":::({str(step).rjust(4, ' ')}/{len(train_loader)}) | Loss: {summary_loss.avg:.4f} | AUC: {final_scores.avg:.5f} | LR: {effNet_lr}/{head_lr} | BTime: {t1-t0 :.2f}s | ETime: {int((t1-t0)*(len(train_loader)-step)//60)}m")

        return summary_loss, final_scores
Example #11
    def train_fn(epoch, train_dataloader, optimizer, criterion, scheduler,
                 device):
        model.train()

        for batch_idx, batch_data in enumerate(train_dataloader):
            optimizer.zero_grad()

            batch_data = any2device(batch_data, device)
            outputs = model(**batch_data)

            y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
            y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]

            loss = criterion(y_pred, y_true)

            if batch_idx % 100 == 0:
                xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

            loss.backward()
            xm.optimizer_step(optimizer)

            if scheduler is not None:
                scheduler.step()
Example #12
def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    model.train()
    for bi, d in enumerate(data_loader):  # bi --> batch index
        ids = d['ids']
        mask = d['mask']
        segment_ids = d['segment_ids']
        targets = d['targets']

        ids = ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        segment_ids = segment_ids.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=segment_ids)
        loss = loss_fn(outputs, targets)
        loss.backward()
        xm.optimizer_step(
            optimizer, barrier=True)  # used here in place of optimizer.step()
        if scheduler is not None:
            scheduler.step()
        if bi % 10 == 0:
            print(f"bi={bi}, loss={loss}")
Example #13
def train():
    net.train()  # enter train mode
    loss_avg = 0.0
    for bx, by in tqdm(train_loader):
        # print(xmetrics.metrics_report())
        bx, by = bx.to(xla_device), by.to(xla_device)
        curr_batch_size = bx.size(0)

        # forward
        logits = net(bx * 2 - 1)

        # backward
        optimizer.zero_grad()
        loss = F.cross_entropy(logits, by)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)

        scheduler.step()

        # exponential moving average
        loss_avg = loss_avg * 0.9 + float(loss) * 0.1

    state['train_loss'] = loss_avg
Example #14
    def loop_fn(model, loader, device, context):
      loss_fn = nn.NLLLoss()
      optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

      for data, target in loader:
        with xu.TimedScope(msg='Training loop: ', printfn=None):
          optimizer.zero_grad()
          output = xu.timed(lambda: model(data), msg='Model: ', printfn=None)
          loss = xu.timed(
              lambda: loss_fn(output, target), msg='Loss: ', printfn=None)
          xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
          xu.timed(
              lambda: xm.optimizer_step(optimizer), msg='Step: ', printfn=None)
          self.assertLess(loss.cpu().item(), 3.0)
Example #15
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        total_samples_train, correct_train = 0, 0

        # Training and calculating train accuracy and loss
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            train_loss = loss_fn(output, target)
            train_loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(data.shape[0])

            pred_train = output.max(1, keepdim=True)[1]
            correct_train += pred_train.eq(target.view_as(pred_train)).sum().item()
            total_samples_train += data.size()[0]

            scheduler.step()
            if x % 40 == 0:
                print(
                    "[xla:{}]({})\tLoss={:.3f}\tRate={:.2f}\tGlobalRate={:.2f}".format(
                        xm.get_ordinal(),
                        x,
                        train_loss.item(),
                        tracker.rate(),
                        tracker.global_rate(),
                    ),
                    flush=True,
                )

        train_accuracy = 100.0 * correct_train / total_samples_train
        print(
            "[xla:{}] Accuracy={:.2f}%".format(xm.get_ordinal(), train_accuracy),
            flush=True,
        )
        return train_accuracy
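The accuracy above is computed per core; under xmp.spawn each process only sees its own shard of the data. A small sketch of one way to aggregate it globally, assuming the usual torch_xla setup (the helper name is mine):

import numpy as np
import torch_xla.core.xla_model as xm

def global_accuracy(correct_train, total_samples_train):
    # Reduce the per-core accuracy to a single mean across all TPU cores.
    local_accuracy = 100.0 * correct_train / total_samples_train
    return xm.mesh_reduce('train_accuracy', local_accuracy, np.mean)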
Example #16
def train_loop_fn(model, loader, device, context):
    loss_fn = nn.BCEWithLogitsLoss(
        reduction='mean', pos_weight=torch.FloatTensor([7]).to(device))
    optimizer = context.getattr_or(
        'optimizer', lambda: torch.optim.AdamW(model.parameters(),
                                               lr=args.lr,
                                               eps=1e-08,
                                               betas=(0.9, 0.999),
                                               weight_decay=args.weight_decay))

    lr_scheduler = context.getattr_or(
        'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type='WarmupAndExponentialDecayScheduler',
            scheduler_divisor=args.slr_divisor,
            scheduler_divide_every_n_epochs=args.slr_divide_n_epochs,
            num_warmup_epochs=args.num_warmup_epochs,
            min_delta_to_update_lr=args.num_warmup_epochs,
            num_steps_per_epoch=num_steps_per_epoch))

    score = []
    model.train()
    for x, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output['out'], target)
        loss.backward()
        xm.optimizer_step(optimizer)
        score.append(loss.item())
        if (args.log_step) and (x % args.log_step) == 0:
            logging.info('[{}]({}) Loss={:.4f}'.format(device, x, loss.item()))
        if lr_scheduler:
            lr_scheduler.step()

    score = sum(score) / len(score)
    return score
Example #17
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, batch in enumerate(loader):
            # batch = tuple(t.to(self.device) for t in batch)
            loss = model(*batch)  # the last one is label
            #loss = criterion(output, batch[-1])
            loss.backward()
            # xm.optimizer_step(optimizer)
            # optimizer.zero_grad()

            tracker.add(FLAGS.batch_size)
            if (x + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.max_grad_norm)
                # Gradient accumulation: run several backward passes before the
                # optimizer update (i.e. before optimizer.step()), so that the
                # gradients accumulate in parameter.grad; a single parameter
                # update is then made with the accumulated gradients.
                xm.optimizer_step(optimizer)
                optimizer.zero_grad()

            if xm.get_ordinal() == 0:
                if x % FLAGS.log_steps == 0:
                    print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
                        xm.get_ordinal(), x, loss.item(), tracker.rate(),
                        tracker.global_rate(), time.asctime()), flush=True)
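A stripped-down version of the accumulation pattern above. As in the example, the model is assumed to return the loss directly; accumulation_steps and max_grad_norm stand in for the config values, and scaling the loss by the number of accumulation steps is a common refinement that the example itself does not apply.

import torch
import torch_xla.core.xla_model as xm

def train_with_accumulation(model, loader, optimizer, accumulation_steps, max_grad_norm):
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        loss = model(*batch)
        # Scale so the accumulated gradient approximates one large-batch update.
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            xm.optimizer_step(optimizer)  # one optimizer update per accumulation window
            optimizer.zero_grad()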
Example #18
def train_iteration(model, optimizer, dataset, train_pairs, qrels):
    model.train()
    total = 0
    total_loss = 0.
    with tqdm('training',
              total=BATCH_SIZE * BATCHES_PER_EPOCH,
              ncols=80,
              desc='train') as pbar:
        for n_iter, record in enumerate(
                data.iter_train_pairs(model, dataset, train_pairs, qrels,
                                      GRAD_ACC_SIZE)):
            # if n_iter > 15:
            # return
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'])
            count = len(record['query_id']) // 2
            # scores = scores.reshape(count, 2)

            # loss = torch.mean(1. - scores.softmax(dim=1)[:, 0]) # pairwise softmax
            # loss.backward()
            # total_loss += loss.item()
            # total_loss += loss
            total += count

            # if n_iter > 0:
            # print(n_iter, [(record[x].size(), record[x].device) for x in ['query_tok', 'query_mask', 'doc_tok', 'doc_mask']])
            # import torch_xla.debug.metrics as met
            # print(met.metrics_report())

            if total % BATCH_SIZE == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()

            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
Example #19
 def train_loop_fn(model, loader, device, context):
   loss_fn = nn.CrossEntropyLoss()
   optimizer = context.getattr_or(
       'optimizer', lambda: optim.SGD(
           model.parameters(),
           lr=FLAGS.lr,
           momentum=FLAGS.momentum,
           weight_decay=1e-4))
   lr_scheduler = context.getattr_or(
       'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
           optimizer,
           scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
           scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
           scheduler_divide_every_n_epochs=getattr(
               FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
           num_steps_per_epoch=num_training_steps_per_epoch,
           summary_writer=writer if xm.is_master_ordinal() else None))
   tracker = xm.RateTracker()
   model.train()
   for x, (data, target) in enumerate(loader):
     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()
     xm.optimizer_step(optimizer)
     tracker.add(FLAGS.batch_size)
     if x % FLAGS.log_steps == 0:
       test_utils.print_training_update(
           device,
           x,
           loss.item(),
           tracker.rate(),
           tracker.global_rate(),
           summary_writer=writer)
     if lr_scheduler:
       lr_scheduler.step()
Example #20
    def __optimizer_step(self,
                         *args,
                         closure: Optional[Callable] = None,
                         profiler_name: str = None,
                         **kwargs):
        trainer = self._trainer
        optimizer = self._optimizer
        model = trainer.get_model()

        if trainer.on_tpu:
            with trainer.profiler.profile(profiler_name):
                xm.optimizer_step(optimizer,
                                  optimizer_args={
                                      'closure': closure,
                                      **kwargs
                                  })

        elif trainer.amp_backend is not None:
            trainer.precision_connector.backend.optimizer_step(
                trainer, optimizer, closure)

        else:
            with trainer.profiler.profile(profiler_name):
                optimizer.step(closure=closure, *args, **kwargs)

        accelerator_backend = trainer.accelerator_backend
        if accelerator_backend is not None and accelerator_backend.rpc_enabled:
            if accelerator_backend.ddp_plugin.is_main_rpc_process:
                # Initialize optimizer step on main process
                accelerator_backend.ddp_plugin.worker_optimizer_step(
                    model=model, opt_idx=self._optimizer_idx, *args, **kwargs)

        trainer.train_loop.on_before_zero_grad(self)

        model.optimizer_zero_grad(trainer.current_epoch, trainer.batch_idx,
                                  optimizer, self._optimizer_idx)
Example #21
    def train(self):
        bar_total = tqdm(range(self.start_epoch, self.end_epoch),
                         desc='Training',
                         leave=False)
        n_samples = len(self.train_loader.sampler)
        for self.epoch in bar_total:
            total_loss = 0
            for data in self.train_loader:
                inputs, labels = data
                inputs, labels = Variable(inputs), Variable(labels)
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                #inputs = inputs.transpose(1, 3)
                y_pred = self.model(inputs)
                loss = self.criterion(y_pred, labels)
                self.optimizer.zero_grad()
                loss.backward()
                if self.tpu:
                    xm.optimizer_step(self.optimizer, barrier=True)
                else:
                    self.optimizer.step()

                total_loss += loss.item()

            train_loss = total_loss / len(self.train_loader)
            bar_total.set_description("Loss: {}".format(train_loss))
            bar_total.refresh()

            if self.epoch % self.summary_write == 0:
                accuracy = self.evaluate()
                self.summary.add_scalar('Train loss', train_loss, self.epoch)
                self.summary.add_scalar('Validation accuracy', accuracy,
                                        self.epoch)
                self.summary.close()

            if self.epoch % self.save_model == 0:
                self.save_checkpoint()
Example #22
    def train_loop_fn(loader):
        tracker = xm.RateTracker()

        positions = torch.arange(SEQUENCE_LENGTH).long().view(
            1, SEQUENCE_LENGTH).to(device)
        causal_mask = torch.triu(torch.ones(SEQUENCE_LENGTH,
                                            SEQUENCE_LENGTH,
                                            dtype=torch.uint8,
                                            device=device),
                                 diagonal=1).unsqueeze(0)

        model.train()
        for iteration, batch in enumerate(loader):
            optimizer.zero_grad()
            input = batch[:, :-1].long()
            target = batch[:, 1:].long()
            if not xla_enabled:
                input = input.to(device)
                target = target.to(device)

            if amp_enabled:
                loss = loop_with_amp(model, input, positions, target,
                                     causal_mask, optimizer, xla_enabled,
                                     autocast, scaler)
            else:
                loss = model(input, positions, target, batch_mask=causal_mask)
                loss.backward()
                if xla_enabled:
                    xm.optimizer_step(optimizer)
                else:
                    optimizer.step()
            tracker.add(BATCH_SIZE)
            if iteration % LOG_STEPS == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, iteration,
                    loss.item() / math.log(2), tracker.rate()))
Example #23
 def optimizer_step(self, model: Union["pl.LightningModule", Module],
                    optimizer: Optimizer, optimizer_idx: int,
                    closure: Callable[[], Any], **kwargs: Any) -> None:
     if isinstance(model, pl.LightningModule):
         closure = partial(self._wrap_closure, model, optimizer,
                           optimizer_idx, closure)
     closure_result = xm.optimizer_step(optimizer,
                                        optimizer_args={
                                            "closure": closure,
                                            **kwargs
                                        })
     skipped_backward = closure_result is None
     # in manual optimization, the closure does not return a value
     if isinstance(model, pl.LightningModule
                   ) and model.automatic_optimization and skipped_backward:
         # we lack coverage here so disable this - something to explore if there's demand
         raise MisconfigurationException(
             "Skipping backward by returning `None` from your `training_step` is not implemented for TPUs."
             " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
             " requesting this feature.")
Example #24
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = (self.args.train_batch_size *
                                      xm.xrt_world_size())
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss -
                                        logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss = tr_loss

                        self._log(logs)

                        if self.args.evaluate_during_training:
                            self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(
                            self.args.output_dir,
                            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss / self.global_step)
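Example #24 assumes it is already running inside a per-core TPU process (hence xm.xrt_world_size() and pl.ParallelLoader). For context, a bare-bones sketch of how such an entry point is typically spawned; the toy model and synthetic dataset are placeholders, not part of the Trainer above.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, flags):
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()
    model = nn.Linear(32, 4).to(device)  # toy stand-in for a real model
    optimizer = torch.optim.SGD(model.parameters(), lr=flags['lr'])
    dataset = TensorDataset(torch.randn(512, 32), torch.randint(0, 4, (512,)))
    sampler = DistributedSampler(dataset,
                                 num_replicas=xm.xrt_world_size(),
                                 rank=xm.get_ordinal())
    loader = DataLoader(dataset, batch_size=16, sampler=sampler)
    device_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    model.train()
    for data, target in device_loader:
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduces gradients across the cores

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=({'lr': 0.01, 'seed': 1},), nprocs=8)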
Example #25
    def _train_one_epoch(self, loader):
        loader_time = .0
        train_time = .0
        curr_time = time.time()

        self.epoch_storage = defaultdict(list)
        for key in ['approx', 'target', 'loss', 'batch_metric']:
            self.epoch_storage[key] = []
        if self.fp16:
            scaler = amp.GradScaler()

        self.model.train()
        if self.progress_bar and self.rank == 0:
            iterator = enumerate(tqdm(loader, desc='train'))
        else:
            iterator = enumerate(loader)

        for batch_i, inputs in iterator:
            loader_time += time.time() - curr_time
            curr_time = time.time()

            self.optimizer.zero_grad()
            batches_done = len(loader) * (self.global_epoch - 1) + batch_i
            inputs = [t.to(self.device) for t in inputs]

            # forward and backward
            if self.fp16:
                with amp.autocast():
                    loss, approx = self.forward_train(self, inputs)
                    self.evaluate_batch(self, inputs, approx)  # evaluation
                loss = loss / self.grad_accumulations
                scaler.scale(loss).backward()
                if (batch_i + 1) % self.grad_accumulations == 0:
                    if self.sam:
                        # first step
                        optimizer_state = scaler._per_optimizer_states[id(
                            self.optimizer)]
                        scaler.unscale_(self.optimizer)
                        if not sum(v.item() for v in
                                   optimizer_state["found_inf_per_device"].
                                   values()):
                            self.optimizer.first_step(zero_grad=True)
                        optimizer_state["stage"] = 2
                        scaler.update()
                        # second step
                        with amp.autocast():
                            loss2, _ = self.forward_train(self, inputs)
                        scaler.scale(loss2).backward()
                        scaler.unscale_(self.optimizer)
                        if not sum(v.item() for v in
                                   optimizer_state["found_inf_per_device"].
                                   values()):
                            self.optimizer.second_step(zero_grad=True)
                        optimizer_state["stage"] = 2
                        scaler.update()
                    else:
                        scaler.step(self.optimizer)
                        scaler.update()
            else:
                loss, approx = self.forward_train(self, inputs)
                self.evaluate_batch(self, inputs, approx)  # evaluation
                loss = loss / self.grad_accumulations
                loss.backward()
                if (batch_i + 1) % self.grad_accumulations == 0:
                    if self.xla:
                        if self.sam:
                            raise RuntimeError(
                                'SAM optimizer on XLA device is not available.'
                            )
                        else:
                            xm.optimizer_step(self.optimizer, barrier=True)
                    else:
                        if self.sam:
                            self.optimizer.first_step(zero_grad=True)
                            loss2, _ = self.forward_train(self, inputs)
                            loss2.backward()
                            self.optimizer.second_step(zero_grad=True)
                        else:
                            self.optimizer.step()
                    if self.batch_scheduler:
                        self.scheduler.step()

            if self.parallel == 'ddp' and self.ddp_average_loss:
                if self.xla:
                    loss_batch = xm.all_gather(
                        loss.detach().clone().view(1)).mean().item()
                else:
                    loss_batch = comm.gather_tensor(
                        loss.detach().clone().view(1)).mean().item()
            else:  # Use loss on device: 0
                loss_batch = loss.item()

            # logging
            learning_rate = [
                param_group['lr']
                for param_group in self.optimizer.param_groups
            ]
            logs = [('batch_train_loss', loss_batch),
                    ('batch_train_lr', learning_rate)]
            if len(self.epoch_storage['batch_metric']) > 0:
                metric = self.epoch_storage['batch_metric'][-1]
                logs.append(('batch_valid_metric', metric))
            self.tb_logger.list_of_scalars_summary(logs, batches_done)
            self.epoch_storage['loss'].append(loss_batch)

            train_time += time.time() - curr_time
            curr_time = time.time()

        if self.debug and self.rank == 0:
            self.logger(
                f'loader: {loader_time:.1f} s | train: {train_time:.1f} s')

        for key, val in self.epoch_storage.items():
            if len(val) > 0:
                if isinstance(val[0], torch.Tensor):
                    self.epoch_storage[key] = torch.cat(val)
                else:
                    self.epoch_storage[key] = torch.tensor(val).to(self.device)

        loss_total = self.epoch_storage['loss'].mean().item()

        if self.parallel == 'ddp':
            # gather tensors
            for key, val in self.epoch_storage.items():
                if len(val) > 0:
                    if self.xla:
                        self.epoch_storage[key] = xm.all_gather(val)
                    else:
                        self.epoch_storage[key] = comm.gather_tensor(val)

            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        else:
            metric_total, monitor_metrics_total = self.evaluate_epoch(self)

        if metric_total is None:
            metric_total = loss_total

        # logging
        logs = [
            ('epoch_train_loss', loss_total),
            ('epoch_train_metric', metric_total),
        ]
        self.tb_logger.list_of_scalars_summary(logs, self.global_epoch)
        return loss_total, metric_total, monitor_metrics_total
Example #26
def train(epoch):
    logger.info('\nEpoch: %d' % epoch)
    net.train()
    train_loss = AverageMeter(100)
    acc = AverageMeter(100)
    batch_time = AverageMeter()
    reg_loss = AverageMeter(100)
    train_loss_avg = 0
    correct = 0
    total = 0
    mean = 0
    var = 0
    lambda_ = 0
    xi_ = 0

    for m in net.modules():
        if isinstance(m, Constraint_Norm):
            m.reset_norm_statistics()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        start = time.time()
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        else:
            inputs = inputs.to(device)
            targets = targets.to(device)
        bsz = inputs.size(0)

        outputs = net(inputs)
        if args.optim_loss == 'mse':
            targets = targets.float()
        loss = criterion(outputs, targets)

        # constraint loss
        weight_mean = 0
        weight_var = 0
        weight_mean_abs = 0
        weight_var_abs = 0
        for m in net.modules():
            if isinstance(m, Constraint_Lagrangian):
                weight_mean_, weight_var_ = m.get_weight_mean_var()
                weight_mean_abs_, weight_var_abs_ = m.get_weight_mean_var_abs()
                weight_mean += weight_mean_
                weight_var += weight_var_
                weight_mean_abs += weight_mean_abs_
                weight_var_abs += weight_var_abs_

        constraint_loss = args.lambda_weight_mean * weight_mean + weight_var
        constraint_loss = args.lambda_constraint_weight * constraint_loss
        weight_mean_abs = args.lambda_constraint_weight * weight_mean_abs
        weight_var_abs = args.lambda_constraint_weight * weight_var_abs

        # optimize constraint loss

        train_loss.update(loss.item())
        train_loss_avg += loss.item()
        loss += constraint_loss

        # optimize
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct_idx = predicted.eq(targets.data).cpu().sum().float()
        correct += correct_idx
        acc.update(100. * correct_idx / float(targets.size(0)))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
        if use_cuda:
            optimizer.step()
        else:
            xm.optimizer_step(optimizer, barrier=True)
        batch_time.update(time.time() - start)
        remain_iter = args.epoch * len(trainloader) - (
            epoch * len(trainloader) + batch_idx)
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        if (batch_idx + 1) % args.print_freq == 0:
            logger.info('Train: [{0}][{1}/{2}]\t'
                        'Loss {train_loss.avg:.3f}\t'
                        'acc {acc.avg:.3f}\t'
                        'correct: [{correct}/{total}]\t'
                        'Constraint mean {corat_mean:.4f}\t'
                        'Constraint var {corat_var:.4f}\t'
                        'Constraint lambda {corat_lambda:.4f}\t'
                        'Constraint xi {corat_xi:.4f}\t'
                        'mean {mean:.4f}\t'
                        'var {var:.4f}\t'
                        'remain_time: {remain_time}'.format(
                            epoch,
                            batch_idx,
                            len(trainloader),
                            train_loss=train_loss,
                            corat_mean=-1 * weight_mean.item(),
                            corat_var=-1 * weight_var.item(),
                            corat_lambda=lambda_,
                            corat_xi=xi_,
                            mean=mean,
                            var=var,
                            acc=acc,
                            correct=int(correct),
                            total=total,
                            remain_time=remain_time,
                        ))

        if (batch_idx + 1) % args.print_freq == 0:
            mean = []
            var = []
            for m in net.modules():
                if isinstance(m, Constraint_Norm):
                    mean_, var_ = m.get_mean_var()
                    mean.append(mean_.abs())
                    var.append(var_.abs())
            mean = torch.mean(torch.stack(mean))
            var = torch.mean(torch.stack(var))
            curr_idx = epoch * len(trainloader) + batch_idx
            tb_logger.add_scalar("train/train_loss", train_loss.avg, curr_idx)
            tb_logger.add_scalar("train/train_acc", acc.avg, curr_idx)
            tb_logger.add_scalar("train/norm_mean(abs)", mean, curr_idx)
            tb_logger.add_scalar("train/norm_var-1(abs)", var, curr_idx)
            tb_logger.add_scalar("train/weight_mean(abs)",
                                 weight_mean_abs.item(), curr_idx)
            tb_logger.add_scalar("train/weight_var-1(abs)",
                                 weight_var_abs.item(), curr_idx)
            tb_logger.add_scalar("train/constraint_loss_mean",
                                 -1 * weight_mean.item(), curr_idx)
            tb_logger.add_scalar("train/constraint_loss_var",
                                 -1 * weight_var.item(), curr_idx)

            # get the constraint weight
            lambda_ = []
            xi_ = []
            for m in net.modules():
                if isinstance(m, Constraint_Lagrangian):
                    lambda_.append(m.lambda_.data.abs().mean())
                    xi_.append(m.xi_.data.abs().mean())
            lambda_ = torch.max(torch.stack(lambda_))
            xi_ = torch.max(torch.stack(xi_))
            tb_logger.add_scalar("train/constraint_lambda_", lambda_.item(),
                                 curr_idx)
            tb_logger.add_scalar("train/constraint_xi_", xi_.item(), curr_idx)

    tb_logger.add_scalar("train/train_loss_epoch",
                         train_loss_avg / len(trainloader), epoch)
    tb_logger.add_scalar("train/train_acc_epoch", 100. * correct / total,
                         epoch)
    wandb.log({"train/acc_epoch": 100. * correct / total}, step=epoch)
    wandb.log({"train/loss_epoch": train_loss_avg / len(trainloader)},
              step=epoch)
    wandb.log({"train/norm_mean(abs)": mean.item()}, step=epoch)
    wandb.log({"train/norm_var-1(abs)": var.item()}, step=epoch)
    wandb.log({"train/weight_mean(abs)": weight_mean_abs.item()}, step=epoch)
    wandb.log({"train/weight_var-1(abs)": weight_var_abs.item()}, step=epoch)
    wandb.log({"train/constraint_loss_mean": -1 * weight_mean.item()},
              step=epoch)
    wandb.log({"train/constraint_loss_var": -1 * weight_var.item()},
              step=epoch)
    logger.info("epoch: {} acc: {}, loss: {}".format(
        epoch, 100. * correct / total, train_loss_avg / len(trainloader)))

    for m in net.modules():
        if isinstance(m, Constraint_Norm):
            m.reset_norm_statistics()
    return (train_loss.avg, reg_loss.avg, 100. * correct / total)
Example #27
 def step(self, closure: Optional[Callable] = None) -> None:
     xm.optimizer_step(self.wrapped_optimizer, barrier=True)
Example #28
def train_model(model,
                criterion,
                optimizer,
                dataloaders,
                dataset_sizes,
                num_epochs=10,
                model_type="VS",
                weight_file="best_modelweights.dat",
                L1_loss=0,
                suppress_log=False,
                hyperparam_search=False,
                use_tpu=False,
                multigpu=False,
                tensorboard=True):

    if use_tpu:
        print(
            "using TPU acceleration, model and optimizer should already be loaded onto tpu device"
        )
        device = xm.xla_device()
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            print("using GPU acceleration")
        if multigpu and torch.cuda.device_count() > 1:
            print("multigpu enabled")
            model = nn.DataParallel(model)
            model = model.to(device, dtype=torch.float)
        else:
            model = model.to(device, dtype=torch.float)

    since = time.time()
    best_loss = np.Inf

    #train_losses = np.zeros(num_epochs*dataset_sizes['train'])
    #val_losses = np.zeros(num_epochs*dataset_sizes['val'])
    train_losses = np.zeros(num_epochs * len(dataloaders['train']))
    val_losses = np.zeros(num_epochs * len(dataloaders['val']))

    it_val = 0
    it_train = 0

    if tensorboard:
        writer = SummaryWriter()

    for epoch in range(num_epochs):
        if suppress_log == False:
            print('Epoch {}/{}'.format(epoch + 1, num_epochs))
            print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode
                # initialize the predictions
                if dataloaders[phase].dataset.include_torque:
                    predictions = np.empty((0, 6))
                else:
                    predictions = np.empty((0, 3))

            running_loss = 0.0

            # Iterate over data.
            batch_size = 0
            it = 1

            for inputs, aug_inputs, labels in dataloaders[phase]:
                # zero the parameter gradients
                optimizer.zero_grad()

                if model_type != "S":
                    inputs = inputs.to(device, dtype=torch.float)

                if (model_type != "V") or (model_type != "V_RNN"):
                    aug_inputs = aug_inputs.to(device, dtype=torch.float)

                labels = labels.to(device, dtype=torch.float)

                # forward
                # track history if only in train
                if phase == 'train':
                    torch.set_grad_enabled(True)

                    if (model_type == "V") or (model_type == "V_RNN"):
                        outputs = model(inputs)
                    elif model_type == "VS":
                        outputs = model(inputs, aug_inputs)
                    else:
                        outputs = model(aug_inputs)

                    loss = criterion(outputs, labels)

                    if L1_loss:
                        L1 = 0
                        for param in model.parameters():
                            if param.requires_grad:
                                L1 += L1_loss * torch.sum(torch.abs(param))
                        loss = loss + L1

                    if multigpu:
                        loss.mean().backward()
                    else:
                        loss.backward()
                    if use_tpu:
                        xm.optimizer_step(optimizer, barrier=True)
                    else:
                        optimizer.step()
                else:
                    torch.set_grad_enabled(False)

                    if (model_type == "V") or (model_type == "V_RNN"):
                        outputs = model(inputs)
                    elif model_type == "VS":
                        outputs = model(inputs, aug_inputs)
                    else:
                        outputs = model(aug_inputs)

                    loss = criterion(outputs, labels)
                    predictions = np.vstack(
                        (predictions, outputs.cpu().detach().numpy()))

                # statistics: the criterion returns the batch mean; multiply by
                # inputs.size(0) here if the total batch loss is needed instead
                running_loss += loss.item()
                batch_size += inputs.size(0)
                avg_loss = running_loss / batch_size

                if phase == 'train':
                    train_losses[it_train] = avg_loss
                    if tensorboard:
                        writer.add_scalar('Loss/train', avg_loss, it_train)
                    it_train += 1
                else:
                    val_losses[it_val] = avg_loss
                    if tensorboard:
                        writer.add_scalar('Loss/val', avg_loss, it_val)
                    it_val += 1

                if it % 100 == 0 and not suppress_log:
                    print('average loss for batch {}: {}'.format(it, avg_loss))

                it += 1

            # divide by the dataset size to get the mean loss per instance
            epoch_loss = running_loss / dataset_sizes[phase]

            if tensorboard:
                if phase == "train":
                    writer.add_scalar('ELoss/train', epoch_loss, epoch)
                if phase == "val":
                    writer.add_scalar('ELoss/val', epoch_loss, epoch)

            if not suppress_log:
                print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:

                if not hyperparam_search:
                    print('Saving model... current loss: ' +
                          str(round(epoch_loss, 5)) + ' < best loss: ' +
                          str(round(best_loss, 5)))
                    print("Backing up the model")
                    # save straight to the weight file instead of leaking an open
                    # file handle on every improvement
                    torch.save(model.state_dict(), weight_file)
                    if tensorboard:
                        fig, ax = plt.subplots(3, 1, sharex=True, figsize=(50, 10))
                        plt.ioff()
                        for f_ax in range(3):
                            ax[f_ax].plot(
                                dataloaders[phase].dataset.label_array[:, f_ax + 1])
                            ax[f_ax].plot(predictions[:, f_ax], linewidth=1)
                        writer.add_figure('valPred/figure',
                                          fig,
                                          global_step=epoch,
                                          close=True)

                else:
                    print('current loss: ' + str(round(epoch_loss, 5)) +
                          ' < best loss: ' + str(round(best_loss, 5)))

                best_loss = epoch_loss

        if not suppress_log:
            time_elapsed = time.time() - since
            print('Epoch runtime {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))
            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))

    # load best model weights
    if not hyperparam_search:
        model.load_state_dict(torch.load(weight_file))

    return model, train_losses, val_losses, best_loss
Beispiel #29
def train(args, train_dataset, model, tokenizer, disable_logging=False):
    """ Train the model """
    if xm.is_master_ordinal():
        # Only master writes to Tensorboard
        tb_writer = SummaryWriter(args.tensorboard_logdir)

    train_sampler = get_sampler(train_dataset)
    dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total,
    )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(dataloader) * args.train_batch_size)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per TPU core = %d", args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        (args.train_batch_size * args.gradient_accumulation_steps * xm.xrt_world_size()),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    loss = None
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=disable_logging)
    set_seed(args.seed)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        # tpu-comment: Get TPU parallel loader which sends data to TPU in background.
        train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader), disable=disable_logging)
        for step, batch in enumerate(epoch_iterator):

            # Save model checkpoint.
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                logger.info("Saving model checkpoint to %s", output_dir)

                if xm.is_master_ordinal():
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))

                # Barrier to wait for saving checkpoint.
                xm.rendezvous("mid_training_checkpoint")
                # model.save_pretrained needs to be called by all ordinals
                model.save_pretrained(output_dir)

            model.train()
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
            if args.model_type != "distilbert":
                # XLM, DistilBERT and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
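                # clip gradients, then let xm.optimizer_step all-reduce them across
                # TPU cores and apply the update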
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                xm.optimizer_step(optimizer)
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics.
                    results = {}
                    if args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer, disable_logging=disable_logging)
                    loss_scalar = loss.item()
                    logger.info(
                        "global_step: {global_step}, lr: {lr:.6f}, loss: {loss:.3f}".format(
                            global_step=global_step, lr=scheduler.get_lr()[0], loss=loss_scalar
                        )
                    )
                    if xm.is_master_ordinal():
                        # tpu-comment: All values must be in CPU and not on TPU device
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar("loss", loss_scalar, global_step)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if xm.is_master_ordinal():
        tb_writer.close()
    return global_step, loss.item()
Beispiel #30
    def train(self, data_loader):
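        """Run one training epoch over data_loader; supports CPU/GPU (optionally
        with AMP and gradient accumulation) as well as TPU via torch_xla."""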
        losses = AverageMeter()
        self.model.train()
        print_idx = max(1, int(len(data_loader) * self.tpu_print / 100))  # avoid modulo by zero below
        if self.accumulation_steps > 1:
            self.optimizer.zero_grad()
        if self.use_tpu:
            para_loader = pl.ParallelLoader(data_loader, [self.device])
            tk0 = para_loader.per_device_loader(self.device)
        else:
            tk0 = tqdm(data_loader, total=len(data_loader))

        for b_idx, data in enumerate(tk0):
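            # with no accumulation, gradients are cleared on the first batch here;
            # with accumulation they were already cleared before the loop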
            if self.accumulation_steps == 1 and b_idx == 0:
                self.optimizer.zero_grad()

            if self.model_fn is None:
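                # default path: the model is expected to return (outputs, loss) itself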
                for key, value in data.items():
                    data[key] = value.to(self.device)
                _, loss = self.model(**data)
            else:
                if self.fp16:
                    with amp.autocast():
                        loss = self.model_fn(data, self.device, self.model)
                else:
                    loss = self.model_fn(data, self.device, self.model)

            if not self.use_tpu:
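                # CPU/GPU path: optional AMP loss scaling plus gradient accumulation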
                with torch.set_grad_enabled(True):
                    if self.use_mean_loss:
                        loss = loss.mean()

                    if self.fp16:
                        self.scaler.scale(loss).backward()
                    else:
                        loss.backward()

                    if (b_idx + 1) % self.accumulation_steps == 0:
                        if self.fp16:
                            self.scaler.step(self.optimizer)
                        else:
                            self.optimizer.step()

                        if self.scheduler is not None:
                            self.scheduler.step()

                        if b_idx > 0:
                            self.optimizer.zero_grad()

                    if self.fp16:
                        self.scaler.update()
            else:
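                # TPU path: xm.optimizer_step all-reduces gradients across cores
                # before applying the update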
                loss.backward()
                xm.optimizer_step(self.optimizer)
                if self.scheduler is not None:
                    self.scheduler.step()
                if b_idx > 0:
                    self.optimizer.zero_grad()
            if self.use_tpu:
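                # combine the per-core losses with reduce_fn (defined elsewhere)
                # so every replica logs the same value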
                reduced_loss = xm.mesh_reduce("loss_reduce", loss, reduce_fn)
                losses.update(reduced_loss.item(), data_loader.batch_size)
            else:
                losses.update(loss.item(), data_loader.batch_size)

            if not self.use_tpu:
                tk0.set_postfix(loss=losses.avg)
            else:
                if b_idx % print_idx == 0 or b_idx == len(data_loader) - 1:
                    xm.master_print(
                        f"{datetime.datetime.now()}: Batch {b_idx} / {len(data_loader)}, loss={losses.avg}"
                    )
        if not self.use_tpu:
            tk0.close()
        return losses.avg