Example #1
def train(
    model: GP,
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    num_iters: int,
    lr: float = 0.1,
    show_progress: bool = True,
):
    """Trains the provided model by maximising the marginal likelihood."""
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)
    mll = gp.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

    loss = torch.tensor(0.0)  # keeps the final .item() call valid even if num_iters == 0
    iterator = (
        tqdm(range(num_iters), desc="Epoch") if show_progress else range(num_iters)
    )

    for _ in iterator:
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()
        if show_progress:
            iterator.set_postfix(loss=loss.item())

    return loss.detach().cpu().item()
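A minimal usage sketch for the helper above, assuming the snippet's `GP`, `gp`, `AdamW`, and `tqdm` names resolve to a GPyTorch exact-GP model class, the `gpytorch` package, `torch.optim.AdamW`, and `tqdm` respectively; the `ExactGPModel` below is a hypothetical example model, not part of the original code.

import torch
import gpytorch as gp
from torch.optim import AdamW
from tqdm import tqdm

class ExactGPModel(gp.models.ExactGP):
    """Toy exact GP with a constant mean and an RBF kernel."""
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gp.means.ConstantMean()
        self.covar_module = gp.kernels.ScaleKernel(gp.kernels.RBFKernel())

    def forward(self, x):
        return gp.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(2 * torch.pi * train_x) + 0.1 * torch.randn(100)
likelihood = gp.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
final_loss = train(model, train_x, train_y, num_iters=100)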
Example #2
def train(path: str, epochs: int = 3) -> None:
    LR = 1e-3
    DECAY = 1e-4

    dataset = MNISTDataset()
    loader = MNISTLoader()

    model = superdupermodel().cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = AdamW(model.parameters(), lr=LR, weight_decay=DECAY)

    for epoch in tqdm(range(epochs), desc="Epoch", position=0):
        total_loss = 0.0
        total_acc = 0.0

        model = model.train()
        with tqdm(loader.train, "Train", position=1) as pbar:
            for inputs, labels in pbar:
                inputs, labels = inputs.cuda(), labels.cuda()
                optimizer.zero_grad()

                out = model(inputs)
                loss = criterion(out, labels)
                acc = (torch.argmax(torch.softmax(out, 1), 1) == labels).sum()

                loss.backward()
                optimizer.step()

                total_loss += loss.item() / len(loader.train)
                total_acc += acc.item() / len(dataset.train)

                pbar.set_postfix(
                    loss=f"{total_loss:.2e}",
                    acc=f"{total_acc * 100:.2f}%",
                )

        with torch.no_grad():
            total_loss = 0.0
            total_acc = 0.0

            model = model.eval()
            with tqdm(loader.test, "Valid", position=2) as pbar:
                for inputs, labels in pbar:
                    inputs, labels = inputs.cuda(), labels.cuda()

                    out = model(inputs)
                    loss = criterion(out, labels)
                    acc = (torch.argmax(torch.softmax(out, 1),
                                        1) == labels).sum()

                    total_loss += loss.item() / len(loader.test)
                    total_acc += acc.item() / len(dataset.test)

                    pbar.set_postfix(
                        loss=f"{total_loss:.2e}",
                        acc=f"{total_acc * 100:.2f}%",
                    )

    torch.save(model.state_dict(), f"{path}.pth")
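`MNISTDataset` and `MNISTLoader` are project-specific helpers that are not shown; a hypothetical torchvision-based stand-in is sketched below, exposing only the attributes the loop touches (`loader.train`, `loader.test`, `len(dataset.train)`, `len(dataset.test)`).

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MNISTDataset:
    """Hypothetical stand-in: exposes .train / .test torchvision datasets."""
    def __init__(self, root: str = "data"):
        tf = transforms.ToTensor()
        self.train = datasets.MNIST(root, train=True, download=True, transform=tf)
        self.test = datasets.MNIST(root, train=False, download=True, transform=tf)

class MNISTLoader:
    """Hypothetical stand-in: exposes .train / .test DataLoaders."""
    def __init__(self, batch_size: int = 128):
        data = MNISTDataset()
        self.train = DataLoader(data.train, batch_size=batch_size, shuffle=True)
        self.test = DataLoader(data.test, batch_size=batch_size)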
Example #3
    def decompose(self, conv, pw, dw, lr=0.001, steps=600):
        """
        GEP decompose standard convolution kernel
        
        :param conv: standard convolution kernel
        :param pw: decomposed pointwise convolution kernel
        :param dw: decomposed depthwise convolution kernel
        :param lr: learning rate
        :param steps: training steps for decomposing
        """

        conv.requires_grad = False
        pw.requires_grad = True
        dw.requires_grad = True

        criterion = nn.MSELoss()
        optimizer = AdamW([pw, dw], lr=lr)  # a list keeps the parameter order deterministic
        st = time.time()
        for s in range(steps):
            if s in (400, 700):  # step-wise learning-rate decay
                lr = lr / 10
                optimizer = AdamW([pw, dw], lr=lr)  # note: re-creating the optimizer resets its running moments
            optimizer.zero_grad()
            kernel_pred = pw.cuda() * dw.cuda()
            loss = criterion(kernel_pred, conv.cuda())
            loss.backward()
            optimizer.step()
            if s % 100 == 99:
                print('loss = %f, time = %d' % (loss, (time.time() - st)))
                st = time.time()
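The decomposition relies on broadcasting a pointwise factor against a depthwise factor so their product matches the full kernel's shape; the shapes below are an assumption for illustration (the real ones come from the surrounding class, which is not shown).

import torch

conv = torch.randn(64, 32, 3, 3)                     # full conv kernel (C_out, C_in, K, K)
pw = torch.randn(64, 32, 1, 1, requires_grad=True)   # pointwise factor
dw = torch.randn(1, 32, 3, 3, requires_grad=True)    # depthwise-style factor
approx = pw * dw                                     # broadcasts to (64, 32, 3, 3)
print(torch.nn.functional.mse_loss(approx, conv).item())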
Example #4
def train(path: str, epochs: int = 3) -> None:
    dataset = MNISTDataset()
    loader = MNISTLoader()

    model = LeNet5(n_classes=10).cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = AdamW(model.parameters(), lr=1e-3)

    for epoch in tqdm(range(epochs), desc="Epoch"):
        model.train()
        total_loss, total_acc = 0, 0

        pbar = tqdm(loader.train, desc="Train")
        for img, label in pbar:
            img, label = img.cuda(), label.cuda()
            optimizer.zero_grad()

            preds = model(img)
            loss = criterion(preds, label)
            acc = (torch.argmax(torch.softmax(preds, dim=1), dim=1) == label).sum()

            loss.backward()
            optimizer.step()

            total_loss += loss.item() / len(loader.train)
            total_acc += acc.item() / len(dataset.train)

            pbar.set_postfix(
                loss=f"{total_loss:.2e}",
                acc=f"{total_acc * 100:.2f}%",
            )

        model.eval()
        total_loss, total_acc = 0, 0

        with torch.no_grad():  # gradients are not needed during evaluation
            pbar = tqdm(loader.test, desc="Test")
            for img, label in pbar:
                img, label = img.cuda(), label.cuda()

                preds = model(img)
                loss = criterion(preds, label)
                acc = (torch.argmax(torch.softmax(preds, dim=1), dim=1) == label).sum()

                total_loss += loss.item() / len(loader.test)
                total_acc += acc.item() / len(dataset.test)

                pbar.set_postfix(
                    loss=f"{total_loss:.2e}",
                    acc=f"{total_acc * 100:.2f}%",
                )

    torch.save(model.state_dict(), f"{path}.pth")
Example #5
class Scheduler:
    def __init__(self, model, args):
        super(Scheduler, self).__init__()
        self.loss = Loss()
        self.optimizer = AdamW(model.parameters(),
                               lr=args.lr,
                               betas=(args.adam_beta1, args.adam_beta2),
                               weight_decay=args.adam_weight_decay)
        self.warm_up = args.warm_up
        self.curr_step = 0
        self.init_lr = args.lr
        self.curr_loss = None

    def __call__(self, out_mask_lm, out_nsp, target):

        mask_pos, mask_label, nsp_label = target
        mask_pos = mask_pos.unsqueeze(-1).expand(mask_pos.size(0),
                                                 mask_pos.size(1),
                                                 out_mask_lm.size(-1))
        out_mask_lm = torch.gather(out_mask_lm, 1, mask_pos)
        nsp_label = nsp_label.long()

        # calculate loss
        loss_nsp = self.loss(out_nsp, nsp_label)
        loss_mask_lm = self.loss(out_mask_lm.transpose(1, 2), mask_label)

        self.curr_loss = loss_mask_lm + loss_nsp

        # calculate acc
        pred_mask_lm = out_mask_lm.max(dim=-1)[1]
        pred_nsp_lm = out_nsp.max(dim=-1)[1]
        mask_lm_acc = pred_mask_lm.eq(mask_label).sum() / len(
            pred_mask_lm.view(-1))
        nsp_acc = pred_nsp_lm.eq(nsp_label).sum() / len(pred_nsp_lm.view(-1))

        return self.curr_loss.data, mask_lm_acc, nsp_acc

    def step(self, epoch):
        self.curr_loss.backward()
        self._update(epoch)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def _update(self, epoch):
        self.curr_step = epoch
        lr = self.init_lr * self._lr_scale()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def _lr_scale(self):

        # if self.curr_step < self.warm_up:
        #      return 1
        # else:
        #     return 2 ** -((self.curr_step - self.warm_up) // 35)

        return 1
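For reference, the commented-out schedule in `_lr_scale` corresponds to holding the base learning rate during warm-up and then halving it every 35 steps; a standalone sketch (the constants are taken from the comment, everything else is an assumption):

def lr_scale(curr_step: int, warm_up: int) -> float:
    # Hold the base lr during warm-up, then halve it every 35 steps afterwards.
    if curr_step < warm_up:
        return 1.0
    return 2.0 ** -((curr_step - warm_up) // 35)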
Example #6
def coteaching(train_xs, train_ys, test_xs, test_ys):
    train_xs = np.moveaxis(train_xs, 3, 1)
    test_xs = np.moveaxis(test_xs, 3, 1)

    batch_size = 1024
    train_loader = DataLoader(TensorDataset(FloatTensor(train_xs),
                                            LongTensor(train_ys)),
                              batch_size=batch_size)

    n_epoch, forget_rate = 100, 0.2
    rate_schedule = np.ones(n_epoch) * forget_rate
    rate_schedule[0] = 0.0

    device = 'cuda:0'
    model1 = models.resnet18(num_classes=2).to(device)
    optim1 = AdamW(model1.parameters())
    model2 = models.resnet18(num_classes=2).to(device)
    optim2 = AdamW(model2.parameters())

    for epoch in range(1, n_epoch):
        iters, acc1, acc2 = 0, 0, 0
        for (images, labels) in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            iters += 1
            logits1 = model1(images)
            acc1 += accuracy(logits1, labels, batch_size)
            logits2 = model2(images)
            acc2 += accuracy(logits2, labels, batch_size)

            loss_1, loss_2 = loss_coteaching(logits1, logits2, labels,
                                             rate_schedule[epoch])
            optim1.zero_grad()
            loss_1.backward()
            optim1.step()
            optim2.zero_grad()
            loss_2.backward()
            optim2.step()

        printr('Coteaching: Epoch {}: acc1 {:.4f}, acc2 {:.4f}'.format(
            epoch, acc1 / iters, acc2 / iters))
    printr('')
    test_loader = DataLoader(TensorDataset(FloatTensor(test_xs),
                                           LongTensor(test_ys)),
                             batch_size=1024)

    def _eval(model):
        total, correct = 0, 0
        model.eval()
        for images, labels in test_loader:
            _, preds = torch.max(F.softmax(model(images.cuda()), dim=1).data, 1)
            total += len(labels)
            correct += int((preds.cpu() == labels).sum())
        return correct / total

    return (_eval(model1) + _eval(model2)) / 2
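`loss_coteaching` is used but not defined in this snippet; below is a hedged sketch of the usual small-loss exchange from the co-teaching literature (an assumption about the helper, not the original implementation).

import torch
import torch.nn.functional as F

def loss_coteaching(logits1, logits2, labels, forget_rate):
    # Per-sample losses for both networks.
    loss1 = F.cross_entropy(logits1, labels, reduction='none')
    loss2 = F.cross_entropy(logits2, labels, reduction='none')
    num_keep = int((1.0 - forget_rate) * len(labels))
    idx1 = torch.argsort(loss1)[:num_keep]  # small-loss samples under model 1
    idx2 = torch.argsort(loss2)[:num_keep]  # small-loss samples under model 2
    # Exchange: each model is updated on the samples its peer considers clean.
    return (F.cross_entropy(logits1[idx2], labels[idx2]),
            F.cross_entropy(logits2[idx1], labels[idx1]))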
Example #7
def train(
    path: str,
    save_all: bool,
    epochs: int = 3,
) -> None:
    dataset = MNISTDataset()
    loader = MNISTLoader()

    # model = Model()
    model = LeNet5(10)
    criterion = nn.CrossEntropyLoss()
    optim = AdamW(model.parameters(), lr=1e-3)

    best_acc = 0.0
    acc = 0.0

    for epoch in tqdm(range(epochs), desc="Epoch"):
        model.train()
        with tqdm(loader.trainloader, desc="Train") as pbar:
            total_loss = 0.0
            acc = 0.0
            for img, label in pbar:
                optim.zero_grad()

                output = model(img)
                loss = criterion(output, label)
                loss.backward()
                optim.step()

                total_loss += loss.item() / len(loader.trainloader)
                acc += (torch.argmax(output, dim=1)
                        == label).sum().item() / len(dataset.trainset)
                pbar.set_postfix(loss=total_loss, acc=f"{acc * 100:.2f}%")

        model.eval()
        with tqdm(loader.validloader, desc="Valid") as pbar:
            total_loss = 0.0
            acc = 0.0
            with torch.no_grad():
                for img, label in pbar:

                    output = model(img)
                    loss = criterion(output, label)

                    total_loss += loss.item() / len(loader.validloader)
                    acc += (torch.argmax(output, dim=1)
                            == label).sum().item() / len(dataset.validset)
                    pbar.set_postfix(loss=total_loss, acc=f"{acc * 100:.2f}%")

        if acc > best_acc:
            torch.save(model.state_dict(), f"{path}/best.pt")
            tqdm.write("saved best")
            best_acc = acc

        if save_all:
            torch.save(model.state_dict(), f"{path}/mnist_{epoch+1:02d}.pt")
Example #8
    def test_memorize_minibatch(self):
        for db_name in self.db_names:
            db_info = get_db_info(db_name)
            train_data, val_data, _ = get_train_val_test_datasets(
                dataset_name=db_name,
                train_test_split='use_full_train',
                encoders=dict(CATEGORICAL='CategoricalOrdinalEnc',
                              SCALAR='ScalarRobustScalerEnc',
                              DATETIME='DatetimeScalarEnc',
                              LATLONG='LatLongScalarEnc',
                              TEXT='TextSummaryScalarEnc'),
            )
            train_loader = get_dataloader(
                dataset=train_data,
                batch_size=256,
                sampler_class_name='SequentialSampler',
                num_workers=0,
                max_nodes_per_graph=False)

            writer = DummyWriter()
            model = GCN(writer,
                        db_info=db_info,
                        hidden_dim=256,
                        n_init_layers=3,
                        activation_class_name='SELU',
                        activation_class_kwargs={},
                        loss_class_kwargs={},
                        loss_class_name='CrossEntropyLoss',
                        p_dropout=0.0,
                        drop_whole_embeddings=True,
                        n_layers=3,
                        readout_class_name='AvgPooling',
                        readout_kwargs={})
            if torch.cuda.is_available():
                model.cuda()
                model.device = torch.device('cuda:0')
            else:
                model.device = torch.device('cpu')
            model.train()
            optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.0)

            bdgl, features, label = next(iter(train_loader))
            recursive_to((bdgl, features, label), model.device)
            for _ in tqdm(range(200)):
                optimizer.zero_grad()
                output = model(bdgl, features)
                loss = model.loss_fxn(output, label)
                if loss < 1e-4:
                    break
                loss.backward()
                optimizer.step()
            else:
                tqdm.write(f'Loss: {loss}')
                self.fail("Didn't memorize minibatch")
Example #9
    def decompose_rank(self, kernel, lr=5e-3, steps=600):
        """
        GEP decompose standard convolution kernel with different rank
        
        :param conv: standard convolution kernel
        :param lr: learning rate
        :param steps: training steps for decomposing
        """

        kernel.requires_grad = False
        params = [self.pw0.weight, self.dw0.weight]  # a list keeps the optimizer's parameter order deterministic
        for i in range(self.rank):
            getattr(self, 'pw' + str(i)).weight.requires_grad = True
            getattr(self, 'dw' + str(i)).weight.requires_grad = True
            if i != 0:
                params.append(getattr(self, 'pw' + str(i)).weight)
                params.append(getattr(self, 'dw' + str(i)).weight)

        criterion = nn.MSELoss()
        optimizer = AdamW(params, lr=lr)
        st = time.time()
        for s in range(steps):
            if s in (400, 700):  # step-wise learning-rate decay
                lr = lr / 10
                optimizer = AdamW(params, lr=lr)
            optimizer.zero_grad()
            for i in range(self.rank):
                term = getattr(self, 'pw' + str(i)).weight.cuda() * \
                       getattr(self, 'dw' + str(i)).weight.cuda()
                kernel_pred = term if i == 0 else kernel_pred + term
            loss = criterion(kernel_pred, kernel.cuda())
            loss.backward()
            optimizer.step()
            if s % 100 == 99:
                print('step %d: loss = %f, time = %d' % ((s + 1), loss,
                                                         (time.time() - st)))
                st = time.time()
Example #10
    def train_deembeders(self, tuples: List[Tuple[torch.Tensor, PDGEmbedder,
                                                  PDGDeembedder]],
                         epochs: int) -> List[float]:

        acc_list = []

        for lab_data, embedder, deembedder in tuples:

            deembed_optimizer = AdamW(deembedder.parameters(), lr=1e-4)
            deemb_loss = MSELoss()

            num_classes = len(lab_data)
            real_one_hot = func.one_hot(lab_data,
                                        num_classes=num_classes).float()

            for param in embedder.parameters():
                param.requires_grad = False

            gen_one_hot = None
            for i in range(epochs):
                deembed_optimizer.zero_grad()
                embed = embedder(lab_data)
                gen_one_hot = deembedder(embed)
                err_deemb = deemb_loss(real_one_hot, gen_one_hot)
                err_deemb.backward()
                deembed_optimizer.step()

            acc = 0
            gen_one_hot = (gen_one_hot > .5).int()

            diffs = torch.eq(real_one_hot, gen_one_hot).all(dim=1).int()
            size = len(diffs)
            acc += diffs.int().sum().float()
            acc /= size

            for param in embedder.parameters():
                param.requires_grad = True

            acc_list.append(float(acc))  # return plain floats, matching the annotated return type

        return acc_list
Example #11
def main(
  data_dir,
  save_path,
  batch_size,
  n_workers,
  valid_steps,
  warmup_steps,
  total_steps,
  save_steps,
):
  """Main function."""
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"[Info]: Use {device} now!")

  train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)
  train_iterator = iter(train_loader)
  print(f"[Info]: Finish loading data!",flush = True)

  model = Classifier(n_spks=speaker_num).to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = AdamW(model.parameters(), lr=1e-3)
  scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
  print(f"[Info]: Finish creating model!",flush = True)

  best_accuracy = -1.0
  best_state_dict = None

  pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

  for step in range(total_steps):
    # Get data
    try:
      batch = next(train_iterator)
    except StopIteration:
      train_iterator = iter(train_loader)
      batch = next(train_iterator)

    loss, accuracy = model_fn(batch, model, criterion, device)
    batch_loss = loss.item()
    batch_accuracy = accuracy.item()

    # Update the model
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    
    # Log
    pbar.update()
    pbar.set_postfix(
      loss=f"{batch_loss:.2f}",
      accuracy=f"{batch_accuracy:.2f}",
      step=step + 1,
    )

    # Do validation
    if (step + 1) % valid_steps == 0:
      pbar.close()

      valid_accuracy = valid(valid_loader, model, criterion, device)

      # keep the best model
      if valid_accuracy > best_accuracy:
        best_accuracy = valid_accuracy
        best_state_dict = model.state_dict()

      pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

    # Save the best model so far.
    if (step + 1) % save_steps == 0 and best_state_dict is not None:
      torch.save(best_state_dict, save_path)
      pbar.write(f"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})")

  pbar.close()
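`Classifier`, `get_dataloader`, `model_fn`, and `valid` are defined elsewhere; a hedged sketch of what `model_fn` is assumed to do (forward one batch and return the loss together with the batch accuracy):

def model_fn(batch, model, criterion, device):
    """Assumed helper: forward one (features, labels) batch, return (loss, accuracy)."""
    feats, labels = batch
    feats, labels = feats.to(device), labels.to(device)
    outs = model(feats)
    loss = criterion(outs, labels)
    accuracy = (outs.argmax(dim=-1) == labels).float().mean()
    return loss, accuracy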
Example #12
class Distiller:
    def __init__(self, params: dict, dataset: LmSeqsDataset,
                 token_probs: torch.tensor, student: nn.Module,
                 teacher: nn.Module):
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            groups = create_lengths_groups(lengths=dataset.lengths,
                                           k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler,
                                          group_ids=groups,
                                          batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler,
                                   batch_size=params.batch_size,
                                   drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info(f'Using MLM loss for LM step.')
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            assert params.word_mask + params.word_keep + params.word_rand == 1.0
            self.pred_probs = torch.FloatTensor(
                [params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(
                f'cuda:{params.local_rank}'
            ) if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(
                f'cuda:{params.local_rank}'
            ) if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info(f'Using CLM loss for LM step.')

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.: self.last_loss_mse = 0
        if self.alpha_cos > 0.: self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = int(
            self.num_steps_epoch / params.gradient_accumulation_steps *
            params.n_epoch) + 1

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in student.named_parameters()
                if not any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            params.weight_decay
        }, {
            'params': [
                p for n, p in student.named_parameters()
                if any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            0.0
        }]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config/training',
                                      text_string=str(self.params),
                                      global_step=0)
            self.tensorboard.add_text(tag='config/student',
                                      text_string=str(self.student_config),
                                      global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1),
                                  dtype=torch.long,
                                  device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(),
                                    n_tgt,
                                    replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(
            self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs,
                                  len(_token_ids_real),
                                  replacement=True)
        _token_ids = _token_ids_mask * (
            probs == 0).long() + _token_ids_real * (
                probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[
            ~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1),
                                  dtype=torch.long,
                                  device=lengths.device) < lengths[:, None])
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[
            ~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids['pad_token']
            else:
                pad_id = self.params.special_tok_ids['unk_token']
            padding_tensor = torch.zeros(bs2,
                                         pad,
                                         dtype=torch.long,
                                         device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master: logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(
                    f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = tqdm(self.dataloader,
                            desc="-Iter",
                            disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(
                        t.to(f'cuda:{self.params.local_rank}') for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(
                        batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(
                        batch=batch)
                self.step(input_ids=token_ids,
                          attention_mask=attn_mask,
                          lm_labels=lm_labels)

                iter_bar.update()
                iter_bar.set_postfix({
                    'Last_loss':
                    f'{self.last_loss:.2f}',
                    'Avg_cum_loss':
                    f'{self.total_loss_epoch/self.n_iter:.2f}'
                })
            iter_bar.close()

            if self.is_master:
                logger.info(
                    f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
            self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
            logger.info('Training is finished')

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor,
             lm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
        """
        if self.mlm:
            s_logits, s_hidden_states = self.student(
                input_ids=input_ids,
                attention_mask=attention_mask)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, t_hidden_states = self.teacher(
                    input_ids=input_ids, attention_mask=attention_mask
                )  # (bs, seq_length, voc_size)
        else:
            s_logits, _, s_hidden_states = self.student(
                input_ids=input_ids,
                attention_mask=None)  # (bs, seq_length, voc_size)
            with torch.no_grad():
                t_logits, _, t_hidden_states = self.teacher(
                    input_ids=input_ids,
                    attention_mask=None)  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (lm_labels > -1).unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(
            s_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(
            t_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature,
                      dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)),
                                        lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_clm > 0.:
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm

        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(
                    0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_hidden_states)  # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(
                s_hidden_states, mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(
                t_hidden_states, mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)

            target = s_hidden_states_slct.new(
                s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct,
                                            t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_clm > 0.:
            self.last_loss_clm = loss_clm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if (loss != loss).data.any():
            logger.error('NaN detected')
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                               self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch",
                                    scalar_value=self.total_loss_epoch /
                                    self.n_iter,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm",
                                        scalar_value=self.last_loss_mlm,
                                        global_step=self.n_total_iter)
        if self.alpha_clm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_clm",
                                        scalar_value=self.last_loss_clm,
                                        global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse",
                                        scalar_value=self.last_loss_mse,
                                        global_step=self.n_total_iter)
        if self.alpha_cos > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_cos",
                                        scalar_value=self.last_loss_cos,
                                        global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=self.scheduler.get_lr()[0],
                                    global_step=self.n_total_iter)

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()['used'] / 1_000_000,
            global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="global/speed",
                                    scalar_value=time.time() - self.last_log,
                                    global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(
            f'{self.n_sequences_epoch} sequences have been trained during this epoch.'
        )

        if self.is_master:
            self.save_checkpoint(
                checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss',
                                        scalar_value=self.total_loss_epoch /
                                        self.n_iter,
                                        global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(
            self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
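For reference, the distillation term computed in `step()` boils down to a temperature-scaled KL divergence between student and teacher logits; a minimal standalone sketch (names here are illustrative, not part of the original class):

import torch.nn.functional as F

def distillation_loss(s_logits, t_logits, temperature=2.0):
    # KL(student || teacher) on temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes comparable.
    return F.kl_div(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.softmax(t_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * temperature ** 2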
Example #13
        loss = distr_dalle(text, images, return_loss=True)

        if args.deepspeed:
            distr_dalle.backward(loss)
        else:
            loss.backward()

        clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)

        if args.deepspeed:
            distr_dalle.step()
            # Gradients are automatically zeroed after the step
        else:
            opt.step()
            opt.zero_grad()

        # Collective loss, averaged
        avg_loss = deepspeed_utils.average_all(loss)

        if deepspeed_utils.is_root_worker():
            torch.cuda.empty_cache()
            log = {}

            if i % 10 == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
Example #14
class Trainer():
    def __init__(self, config, pretrained=True):

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.checkpoint = config['trainer']['checkpoint']
        self.export_weights = config['trainer']['export']
        self.metrics = config['trainer']['metrics']
        logger = config['trainer']['log']

        if logger:
            self.logger = Logger(logger)

        if pretrained:
            weight_file = download_weights(**config['pretrain'],
                                           quiet=config['quiet'])
            self.load_weights(weight_file)

        self.iter = 0

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer, **config['optimizer'])
        #        self.optimizer = ScheduledOptim(
        #            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
        #            #config['transformer']['d_model'],
        #            512,
        #            **config['optimizer'])

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        transforms = ImgAugTransform()

        self.train_gen = self.data_gen('train_{}'.format(self.dataset_name),
                                       self.data_root,
                                       self.train_annotation,
                                       transform=transforms)
        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                'valid_{}'.format(self.dataset_name), self.data_root,
                self.valid_annotation)

        self.train_losses = []

    def train(self):
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = 0

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_loss = self.validate()
                acc_full_seq, acc_per_char = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char)
                print(info)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq

    def validate(self):
        self.model.eval()

        total_loss = []

        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)
                #                loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt_output.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        pred_sents = []
        actual_sents = []
        img_files = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['img'], self.model)
            else:
                translated_sentence = translate(batch['img'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())

            img_files.extend(batch['filenames'])

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files

    def precision(self, sample=None):

        pred_sents, actual_sents, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')

        return acc_full_seq, acc_per_char

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):

        pred_sents, actual_sents, img_files = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

        for vis_idx in range(0, len(img_files)):
            img_path = img_files[vis_idx]
            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]

            img = Image.open(open(img_path, 'rb'))
            plt.figure()
            plt.imshow(img)
            plt.title('pred: {} - actual: {}'.format(pred_sent, actual_sent),
                      loc='left',
                      fontdict=fontdict)
            plt.axis('off')

        plt.show()

    def visualize_dataset(self, sample=16, fontname='serif'):
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                plt.figure()
                plt.title('sent: {}'.format(sent),
                          loc='center',
                          fontname=fontname)
                plt.imshow(img)
                plt.axis('off')

                n += 1
                if n >= sample:
                    plt.show()
                    return

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)

        optim = ScheduledOptim(
            Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            self.config['transformer']['d_model'], **self.config['optimizer'])

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} mismatched shape'.format(name))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames']
        }

        return batch

    def data_gen(self, lmdb_path, data_root, annotation, transform=None):
        dataset = OCRDataset(
            lmdb_path=lmdb_path,
            root_dir=data_root,
            annotation_path=annotation,
            vocab=self.vocab,
            transform=transform,
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        sampler = ClusterRandomSampler(dataset, self.batch_size, True)
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=False,
                         drop_last=False,
                         **self.config['dataloader'])

        return gen

    def data_gen_v1(self, lmdb_path, data_root, annotation):
        data_gen = DataGen(
            data_root,
            annotation,
            self.vocab,
            'cpu',
            image_height=self.config['dataset']['image_height'],
            image_min_width=self.config['dataset']['image_min_width'],
            image_max_width=self.config['dataset']['image_max_width'])

        return data_gen

    def step(self, batch):
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        # equivalent einops form:
        # loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # flatten (batch, time, vocab) -> (batch * time, vocab)
        tgt_output = tgt_output.view(-1)  # flatten (batch, time) -> (batch * time)

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item
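
# A minimal, self-contained sketch of the step/clip/schedule pattern used in step() above,
# here paired with AdamW to match the rest of this page. The model, data, and Noam-style
# warmup are placeholders (ScheduledOptim's exact schedule is not shown in this example).
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(10, 5)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-3)
d_model, warmup = 256, 4000  # assumed values
scheduler = LambdaLR(
    optimizer,
    lambda s: (d_model ** -0.5) * min(max(s, 1) ** -0.5, max(s, 1) * warmup ** -1.5))

def train_step(batch_x, batch_y):
    model.train()
    loss = criterion(model(batch_x), batch_y)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)  # clip before the update, as in step()
    optimizer.step()
    scheduler.step()
    return loss.item()

print(train_step(torch.randn(4, 10), torch.randint(0, 5, (4,))))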
Ejemplo n.º 15
0
def train(args):
    logger = log.get_logger(__name__)

    with open(Path(args.config_base_path, args.config).with_suffix(".yaml"), 'r') as f:
        config = yaml.safe_load(f)

    train_transforms = transforms.get_train_transforms()
    val_transforms = transforms.get_val_transforms()

    logger.info("Loading the dataset...")
    if config['dataset']['name'] == 'coco_subset':
        # TODO: Look into train_transforms hiding the objects
        # Transform in such a way that this can't be the case
        train_dataset = CocoSubset(config['dataset']['coco_path'],
                                   config['dataset']['target_classes'],
                                   train_transforms,
                                   'train',
                                   config['dataset']['train_val_split'])

        val_dataset = CocoSubset(config['dataset']['coco_path'],
                                 config['dataset']['target_classes'],
                                 val_transforms,
                                 'val',
                                 config['dataset']['train_val_split'])
    else:
        raise ValueError("Dataset name not recognized or implemented")

    train_loader = DataLoader(train_dataset,
                              config['training']['batch_size'],
                              shuffle=True,
                              collate_fn=data_utils.collate_fn)

    val_loader = DataLoader(val_dataset,
                            config['training']['batch_size'],
                            shuffle=True,
                            collate_fn=data_utils.collate_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_manager = CheckpointManager(args.config, args.save_every)

    logger.info("Loading model...")
    model = models.DETR(config['dataset']['num_classes'],
                        config['model']['dim_model'],
                        config['model']['n_heads'],
                        n_queries=config['model']['n_queries'],
                        head_type=config['model']['head_type'])

    # TODO: implement scheduler
    optim = AdamW(model.parameters(), config['training']['lr'])  # pending

    if args.mode == 'pretrained':
        model.load_demo_state_dict('data/state_dicts/detr_demo.pth')
    elif args.mode == 'checkpoint':
        state_dict, optim_dict = checkpoint_manager.load_checkpoint('latest')
        model.load_state_dict(state_dict)
        optim.load_state_dict(optim_dict)

    if args.train_section == 'head':
        to_train = ['ffn']
    elif args.train_section == 'backbone':
        to_train = ['backbone', 'conv']
    else:
        to_train = ['ffn', 'backbone', 'conv', 'transformer', 'row', 'col', 'object']

    # Freeze everything but the modules that are in to_train
    for name, param in model.named_parameters():
        if not any(map(name.startswith, to_train)):
            param.requires_grad = False

    model.to(device)

    matcher = models.HungarianMatcher(config['losses']['lambda_matcher_classes'],
                                      config['losses']['lambda_matcher_giou'],
                                      config['losses']['lambda_matcher_l1'])

    loss_fn = models.DETRLoss(config['losses']['lambda_loss_classes'],
                              config['losses']['lambda_loss_giou'],
                              config['losses']['lambda_loss_l1'],
                              config['dataset']['num_classes'],
                              config['losses']['no_class_weight'])

    # writer = SummaryWriter(log_dir=Path(__file__)/'logs/tensorboard')
    # maybe image with boxes every now and then
    # maybe look into add_hparams

    logger.info("Starting training...")
    loss_hist = deque()
    loss_desc = "Loss: n/a"

    update_every_n_steps = config['training']['effective_batch_size'] // config['training']['batch_size']
    steps = 1

    starting_epoch = checkpoint_manager.current_epoch

    for epoch in range(starting_epoch, config['training']['epochs']):
        epoch_desc = f"Epoch [{epoch}/{config['training']['epochs']}]"

        for images, labels in tqdm(train_loader, f"{epoch_desc} | {loss_desc}"):
            images = images.to(device)
            labels = data_utils.labels_to_device(labels, device)

            output = model(images)
            matching_indices = matcher(output, labels)
            matching_indices = data_utils.indices_to_device(matching_indices, device)

            loss = loss_fn(output, labels, matching_indices) / update_every_n_steps
            loss_hist.append(loss.item() * update_every_n_steps)
            loss.backward()

            if steps % update_every_n_steps == 0:
                optim.step()
                optim.zero_grad()

            steps += 1

        checkpoint_manager.step(model, optim, sum(loss_hist) / len(loss_hist))

        loss_desc = f"Loss: {sum(loss_hist)/len(loss_hist)}"
        loss_hist.clear()

        if (epoch % args.eval_every == 0) and epoch != 0:
            validation_loop(model, matcher, val_loader, loss_fn, device)

    # loss_hist is cleared at the end of each epoch, so guard against an empty deque here
    checkpoint_manager.save_checkpoint(model, optim, sum(loss_hist) / max(len(loss_hist), 1))
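
# A minimal sketch of the gradient-accumulation cadence used above: losses are scaled by
# update_every_n_steps and the optimizer only steps every update_every_n_steps batches.
# The model and data are placeholders, not the DETR pipeline.
import torch
import torch.nn as nn
from torch.optim import AdamW

model = nn.Linear(8, 2)
loss_fn = nn.CrossEntropyLoss()
optim = AdamW(model.parameters(), lr=1e-4)

batch_size, effective_batch_size = 4, 16
update_every_n_steps = effective_batch_size // batch_size
data = [(torch.randn(batch_size, 8), torch.randint(0, 2, (batch_size,))) for _ in range(8)]

optim.zero_grad()
for step, (x, y) in enumerate(data, start=1):
    loss = loss_fn(model(x), y) / update_every_n_steps  # scale so the accumulated gradient matches one large batch
    loss.backward()
    if step % update_every_n_steps == 0:
        optim.step()
        optim.zero_grad()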
Ejemplo n.º 16
0
class Model:
    def __init__(self, local_rank=-1, arbitrary=False):
        if arbitrary:
            self.flownet = IFNet_m()
        else:
            self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(
            self.flownet.parameters(), lr=1e-6,
            weight_decay=1e-3)  # a large weight decay may help avoid NaN loss
        self.epe = EPE()
        self.lap = LapLoss()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet,
                               device_ids=[local_rank],
                               output_device=local_rank)

    def train(self):
        self.flownet.train()

    def eval(self):
        self.flownet.eval()

    def device(self):
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        def convert(param):
            return {
                k.replace("module.", ""): v
                for k, v in param.items() if "module." in k
            }

        if rank <= 0:
            self.flownet.load_state_dict(
                convert(torch.load('{}/flownet.pkl'.format(path))))

    def save_model(self, path, rank=0):
        if rank == 0:
            torch.save(self.flownet.state_dict(),
                       '{}/flownet.pkl'.format(path))

    def inference(self,
                  img0,
                  img1,
                  scale_list=[4, 2, 1],
                  TTA=False,
                  timestep=0.5):
        imgs = torch.cat((img0, img1), 1)
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
            imgs, scale_list, timestep=timestep)
        if not TTA:
            return merged[2]
        else:
            flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(
                imgs.flip(2).flip(3), scale_list, timestep=timestep)
            return (merged[2] + merged2[2].flip(2).flip(3)) / 2

    def update(self,
               imgs,
               gt,
               learning_rate=0,
               mul=1,
               training=True,
               flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        if training:
            self.train()
        else:
            self.eval()
        flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(
            torch.cat((imgs, gt), 1), scale=[4, 2, 1])
        loss_l1 = (self.lap(merged[2], gt)).mean()
        loss_tea = (self.lap(merged_teacher, gt)).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_tea + loss_distill * 0.01
            loss_G.backward()
            self.optimG.step()
        else:
            flow_teacher = flow[2]
        return merged[2], {
            'merged_tea': merged_teacher,
            'mask': mask,
            'mask_tea': mask,
            'flow': flow[2][:, :2],
            'flow_tea': flow_teacher,
            'loss_l1': loss_l1,
            'loss_tea': loss_tea,
            'loss_distill': loss_distill,
        }
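
# A minimal sketch of driving the learning rate from outside the optimizer, as update()
# does by writing param_group['lr'] on every call. The cosine-with-warmup schedule and
# its constants are assumptions, not values taken from the class above.
import math
import torch.nn as nn
from torch.optim import AdamW

net = nn.Linear(4, 4)
optimG = AdamW(net.parameters(), lr=1e-6, weight_decay=1e-3)

def get_learning_rate(step, total_steps=300000, warmup=2000, base_lr=3e-4, min_lr=3e-5):
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / (total_steps - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

for step in range(1, 6):
    lr = get_learning_rate(step)
    for param_group in optimG.param_groups:  # same mechanism as Model.update()
        param_group['lr'] = lr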
Ejemplo n.º 17
0
def main():
    # use the GPU if available, otherwise fall back to the CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Use " + str(device))

    # image preprocessing pipeline
    img_transform = transforms.Compose([
        # convert the image to a tensor scaled into [0, 1]
        transforms.ToTensor(),
        # then rescale it into [-1, 1]
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    # create the output directory
    if not os.path.exists(config.output_path):
        os.mkdir(config.output_path)

    # build the dataset
    mnist_dataset = Digit_train_Dataset(pd.read_csv("MNIST.csv"),
                                        transform=img_transform)

    # build the dataloader
    mnist_loader = DataLoader(dataset=mnist_dataset,
                              batch_size=config.batchSize,
                              shuffle=True)

    # fetch the discriminator D and generator G networks from the model module
    G_model = get_G_model(config.from_old_model, device, config.G_model_path,
                          config.G_type)
    D_model = get_D_model(config.from_old_model, device, config.D_model_path)

    # optimizers for G and D: AdamW with a learning rate of 1e-4
    G_optimizer = AdamW(G_model.parameters(), lr=1e-4, weight_decay=1e-6)
    D_optimizer = AdamW(D_model.parameters(), lr=1e-4, weight_decay=1e-6)

    # loss function
    criterion = config.criterion

    # record the training start time
    train_start = time.time()

    # loop over every training epoch
    for epoch in range(config.epochs):
        print("start epoch " + str(epoch + 1) + ":")
        # variables for tracking progress and losses
        batch_num = len(mnist_loader)
        D_loss_sum = 0
        G_loss_sum = 0
        count = 0

        # draw batches from the dataloader
        for index, images in enumerate(mnist_loader):
            count += 1
            # move the images to the compute device
            images = images.to(device)

            # real labels with label smoothing: random values in [0.9, 1.0)
            real_labels = (1 - torch.rand(config.batchSize, 1) / 10).to(device)

            # fake labels: one-sided smoothing, so the fake side stays at 0
            fake_labels = Variable(torch.zeros(config.batchSize, 1)).to(device)

            # feed random seeds into the generator to produce fake images
            img_seeds = torch.randn(config.batchSize,
                                    config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds)

            # track whether the real/fake labels have been swapped
            exchange_labels = False

            # with some probability, swap the labels while training the discriminator
            if random.uniform(0, 1) < config.D_train_label_exchange:
                real_labels, fake_labels = fake_labels, real_labels
                exchange_labels = True

            # train the discriminator D
            D_optimizer.zero_grad()
            # feed real samples to the discriminator
            real_output = D_model(images)
            # the last batch may be shorter than the batch size, so trim the real labels to match
            if len(real_labels) > len(real_output):
                D_loss_real = criterion(real_output,
                                        real_labels[:len(real_output)])
            else:
                D_loss_real = criterion(real_output, real_labels)
            # feed fake samples to the discriminator
            fake_output = D_model(fake_images)
            D_loss_fake = criterion(fake_output, fake_labels)
            # the discriminator loss is the sum of the real and fake losses
            D_loss = D_loss_real + D_loss_fake
            D_loss_sum += D_loss.item()

            # reset the gradients
            D_optimizer.zero_grad()
            # update the discriminator D with the loss
            D_loss.backward()
            D_optimizer.step()

            # if the labels were swapped earlier, swap them back
            if exchange_labels:
                real_labels, fake_labels = fake_labels, real_labels

            # train the generator G
            # feed random seeds into the generator G to produce fake data
            img_seeds = torch.randn(config.batchSize,
                                    config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds)
            # pass the fake data to the discriminator
            fake_output = D_model(fake_images)
            # compare the discriminator output on fakes against the real labels to get the loss
            G_loss = criterion(fake_output, real_labels)
            G_loss_sum += G_loss.item()

            # reset the gradients
            G_optimizer.zero_grad()
            # update the generator G with the loss
            G_loss.backward()
            G_optimizer.step()

            # print training progress
            if (index + 1) % 200 == 0:
                print("Epoch: %2d, Batch: %4d / %4d" %
                      (epoch + 1, index + 1, batch_num))

        # save the model weights to disk at the end of each epoch
        torch.save(G_model.state_dict(), config.G_model_path)
        torch.save(D_model.state_dict(), config.D_model_path)

        # at the end of each epoch, write a batch of generated images to the output directory
        img_seeds = torch.randn(config.batchSize,
                                config.img_seed_dim).to(device)
        fake_images = G_model(img_seeds).detach()
        # rescale the fake images back to [0, 1]
        fake_images = 0.5 * (fake_images + 1)
        fake_images = fake_images.clamp(0, 1)
        # arrange the generated images and write them to disk with save_image()
        fake_images = fake_images.view(-1, 1, 28, 28)
        save_image(fake_images, config.output_path + str(epoch + 1) + '.png')

        # print this epoch's losses and elapsed time for reference
        print("D_loss:", round(D_loss_sum / count, 3))
        print("G_loss:", round(G_loss_sum / count, 3))
        current_time = time.time()
        pass_time = int(current_time - train_start)
        time_string = str(pass_time // 3600) + " hours, " + str(
            (pass_time % 3600) // 60) + " minutes, " + str(
                pass_time % 60) + " seconds."
        print("Time pass: " + time_string)

    print("Done.")
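
# A minimal sketch of the one-sided label smoothing and occasional label flipping used
# when training the discriminator above. The batch size and flip probability are
# assumptions standing in for config.batchSize and config.D_train_label_exchange.
import random
import torch

batch_size = 16
flip_prob = 0.05  # assumed value

real_labels = 1 - torch.rand(batch_size, 1) / 10  # smoothed into [0.9, 1.0)
fake_labels = torch.zeros(batch_size, 1)          # one-sided: the fake side stays at 0

if random.uniform(0, 1) < flip_prob:              # occasionally swap labels to regularise D
    real_labels, fake_labels = fake_labels, real_labels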
Ejemplo n.º 18
0
class Model:
    def __init__(self, local_rank=-1):
        self.flownet = IFNet()
        self.contextnet = ContextNet()
        self.fusionnet = FusionNet()
        self.device()
        self.optimG = AdamW(itertools.chain(self.flownet.parameters(),
                                            self.contextnet.parameters(),
                                            self.fusionnet.parameters()),
                            lr=1e-6,
                            weight_decay=1e-5)
        self.schedulerG = optim.lr_scheduler.CyclicLR(self.optimG,
                                                      base_lr=1e-6,
                                                      max_lr=1e-3,
                                                      step_size_up=8000,
                                                      cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet,
                               device_ids=[local_rank],
                               output_device=local_rank)
            self.contextnet = DDP(self.contextnet,
                                  device_ids=[local_rank],
                                  output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet,
                                 device_ids=[local_rank],
                                 output_device=local_rank)

    def train(self):
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self):
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank=-1):
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items() if "module." in k
                }
            else:
                return param

        if rank <= 0:
            self.flownet.load_state_dict(
                convert(
                    torch.load('{}/flownet.pkl'.format(path),
                               map_location=device)))
            self.contextnet.load_state_dict(
                convert(
                    torch.load('{}/contextnet.pkl'.format(path),
                               map_location=device)))
            self.fusionnet.load_state_dict(
                convert(
                    torch.load('{}/unet.pkl'.format(path),
                               map_location=device)))

    def save_model(self, path, rank):
        if rank == 0:
            torch.save(self.flownet.state_dict(),
                       '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(),
                       '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

    def predict(self, imgs, flow, training=True, flow_gt=None):
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        flow = F.interpolate(
            flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = merged_img + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
        else:
            return pred

    def inference(self, img0, img1):
        imgs = torch.cat((img0, img1), 1)
        flow, _ = self.flownet(imgs)
        return self.predict(imgs, flow, training=False)

    def update(self,
               imgs,
               gt,
               learning_rate=0,
               mul=1,
               training=True,
               flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(merged_img - gt).sum(
                    1, True).float().detach()
                loss_mask = F.interpolate(loss_mask,
                                          scale_factor=0.5,
                                          mode="bilinear",
                                          align_corners=False).detach()
                flow_gt = (F.interpolate(flow_gt,
                                         scale_factor=0.5,
                                         mode="bilinear",
                                         align_corners=False) * 0.5).detach()
            loss_cons = 0
            for i in range(3):
                loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
                loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        loss_l1 = (((pred - gt)**2 + 1e-6)**0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
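
# A minimal sketch of the CyclicLR setup used in __init__ above. cycle_momentum=False is
# the safe choice for Adam-family optimizers such as AdamW, which track momentum through
# betas rather than a 'momentum' entry in their param groups. The tiny net is a placeholder.
import torch
import torch.nn as nn
from torch.optim import AdamW

net = nn.Linear(4, 4)
optimG = AdamW(net.parameters(), lr=1e-6, weight_decay=1e-5)
schedulerG = torch.optim.lr_scheduler.CyclicLR(optimG,
                                               base_lr=1e-6,
                                               max_lr=1e-3,
                                               step_size_up=8000,
                                               cycle_momentum=False)
# during training: optimG.step() followed by schedulerG.step(), once per batch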
Ejemplo n.º 19
0
  def train_model(self,
                  train_x,
                  train_y,
                  train_mask,
                  dev_x,
                  dev_y,
                  dev_mask,
                  test_x,
                  test_y,
                  test_mask,
                  lr=1e-5,
                  batch_size=16,
                  aux_batch_size=4,
                  use_aux=False,
                  sampling='uniform',
                  all_neg=False,
                  model_path='models'):
    model_name = "model_{}".format("ns" if use_aux else "base")
    if use_aux:
      model_name += "_{}".format("allneg" if all_neg else "normal")
    all_params = [p for _, p in self.bert.named_parameters()] + [self.linear.weight, self.linear.bias]
    num_train_steps = len(train_x) // batch_size * 50
    optimizer = AdamW(all_params, lr=lr)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(num_train_steps * 0.1),
                                                num_training_steps=num_train_steps)
    train_y_numpy = [l.numpy() for l in train_y]
    tor = 0
    max_score = -1
    label_group = make_label_group(train_y_numpy)
    stack = []
    count = 1
    steps = 0
    train_losses_main = []
    train_losses_aux = []
    dev_losses_main = []
    dev_losses_aux = []
    all_idx = set(list(range(train_x.shape[0])))

    while tor < 10:
      if count > 50:
          break
      st = time.time()
      print('epoch ', count)
      count += 1
      self.train()
      for bx, by, bmask, prg, bidx in self.data_generator(train_x, train_y, train_mask):
        bx = bx.to(device)
        by = by.to(device)
        bmask = bmask.to(device)

        if use_aux:
          aux_x, aux_y, aux_label, aux_idx = self.aux_task_sampling(
              train_x,
              train_y,
              by,
              bidx,
              label_group,
              batch=aux_batch_size,
              at_random=('rand' in sampling),
              all_neg=all_neg)
          aux_mask = torch.cuda.LongTensor([train_mask[i].numpy() for i in aux_idx])
          aux_x = torch.stack(aux_x).type(torch.LongTensor).cuda()
          aux_y = torch.cuda.LongTensor(aux_y)
          all_loss, main_loss, aux_loss = self.calc_loss(bx, aux_x, bmask, aux_mask, by, aux_y, all_neg, use_aux=True)
        else:
          all_loss, main_loss = self.calc_loss(bx, None, bmask, None, by, None, False, use_aux=False)

        optimizer.zero_grad()
        all_loss.backward()
        loss_value = all_loss.detach().cpu().numpy()
        optimizer.step()
        scheduler.step()
        print('progress: {:.2f}%, loss = {:.5f}\r'.format(prg * 100, loss_value), end='', flush=True)
        steps += 1
      print('')
      self.eval()
      with torch.no_grad():
        score, _, losses = self.evaluate(dev_x, dev_y, dev_mask, all_neg=all_neg)
        score_test, _, _ = self.evaluate(test_x, test_y, test_mask, all_neg=all_neg)
      print('dev exact match = ', score, flush=True)
      print('test exact match = ', score_test, flush=True)

      if max_score < score:
        max_score = score
        tor = 0
        with torch.no_grad():
          max_score_test, preds_test, _ = self.evaluate(test_x, test_y, test_mask)
        torch.save(self.state_dict(), os.path.join(model_path, model_name) + "_{}".format(len(os.listdir(model_path))))
      else:
        tor += 1
      ed = time.time()
      print('time = ', ed - st)
      self.train()

    print('finish')
    all_losses = {
        'main_losses_train': train_losses_main,
        'aux_losses_train': train_losses_aux,
        'main_losses_dev': dev_losses_main,
        'aux_losses_dev': dev_losses_aux}
    return max_score_test, max_score, all_losses
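
# A minimal sketch of the AdamW + linear-warmup pairing used in train_model() above.
# The model and step counts are placeholders; the sketch assumes the transformers
# package for get_linear_schedule_with_warmup and uses torch.optim.AdamW.
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

model = nn.Linear(16, 2)
num_train_steps = 1000
optimizer = AdamW(model.parameters(), lr=1e-5)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=int(num_train_steps * 0.1),
                                            num_training_steps=num_train_steps)
# inside the batch loop: optimizer.zero_grad(); loss.backward(); optimizer.step(); scheduler.step()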
Ejemplo n.º 20
0
        def run(self):
            """Run the training"""
            start = timeit.default_timer()
            no_improve_count = 0

            AdamW_optim = AdamW(self.weights, lr=self.init_lr)
            SGD_optim = torch.optim.SGD(self.biases, lr=self.init_lr)

            AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                AdamW_optim, factor=0.5, patience=100, threshold=0)
            SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                SGD_optim, factor=0.5, patience=100, threshold=0)

            while True:
                rmse = self.evaluate(self.validation_set)
                learning_rate = AdamW_optim.param_groups[0]['lr']
                if learning_rate < self.min_lr or AdamW_scheduler.last_epoch > self.nmax:
                    break

                # checkpoint
                if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
                    no_improve_count = 0
                    torch.save(self.nn.state_dict(), self.model_checkpoint)
                else:
                    no_improve_count += 1

                if no_improve_count > self.max_nonimprove:
                    break

                AdamW_scheduler.step(rmse)
                SGD_scheduler.step(rmse)

                if self.tensorboard is not None:
                    self.tensorboard.add_scalar('validation_rmse', rmse,
                                                AdamW_scheduler.last_epoch)
                    self.tensorboard.add_scalar('best_validation_rmse',
                                                AdamW_scheduler.best,
                                                AdamW_scheduler.last_epoch)
                    self.tensorboard.add_scalar('learning_rate', learning_rate,
                                                AdamW_scheduler.last_epoch)
                    self.tensorboard.add_scalar('no_improve_count_vs_epoch',
                                                no_improve_count,
                                                AdamW_scheduler.last_epoch)

                for i, properties in self.tqdm(
                        enumerate(self.training_set),
                        total=len(self.training_set),
                        desc='epoch {}'.format(AdamW_scheduler.last_epoch)):
                    species = properties['species'].to(self.device)
                    coordinates = properties['coordinates'].to(
                        self.device).float()
                    true_energies = properties['energies'].to(
                        self.device).float()
                    num_atoms = (species >= 0).sum(dim=1,
                                                   dtype=true_energies.dtype)
                    _, predicted_energies = self.model((species, coordinates))
                    loss = (self.mse_se(predicted_energies, true_energies) /
                            num_atoms.sqrt()).mean()
                    AdamW_optim.zero_grad()
                    SGD_optim.zero_grad()
                    loss.backward()
                    AdamW_optim.step()
                    SGD_optim.step()

                    # write current batch loss to TensorBoard
                    if self.tensorboard is not None:
                        self.tensorboard.add_scalar(
                            'batch_loss', loss,
                            AdamW_scheduler.last_epoch * len(self.training_set)
                            + i)

                # log elapsed time
                elapsed = round(timeit.default_timer() - start, 2)
                if self.tensorboard is not None:
                    self.tensorboard.add_scalar('time_vs_epoch', elapsed,
                                                AdamW_scheduler.last_epoch)
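
# A minimal sketch of the two-optimizer setup in run() above: AdamW on the weight tensors,
# plain SGD on the biases, each with its own ReduceLROnPlateau stepped on a validation
# metric. The network and the metric value are placeholders.
import torch
import torch.nn as nn
from torch.optim import AdamW

net = nn.Sequential(nn.Linear(8, 16), nn.CELU(), nn.Linear(16, 1))
weights = [p for n, p in net.named_parameters() if n.endswith('weight')]
biases = [p for n, p in net.named_parameters() if n.endswith('bias')]

AdamW_optim = AdamW(weights, lr=1e-3)
SGD_optim = torch.optim.SGD(biases, lr=1e-3)
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    AdamW_optim, factor=0.5, patience=100, threshold=0)
SGD_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    SGD_optim, factor=0.5, patience=100, threshold=0)

validation_rmse = 0.1  # placeholder for the real validation metric
AdamW_scheduler.step(validation_rmse)
SGD_scheduler.step(validation_rmse)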
Ejemplo n.º 21
0
def prepare_data(dataset):
    data_path = 'data/{}'.format(dataset)
    data_full_path = '{}.npz'.format(data_path)
    if exists(data_full_path):
        data = np.load(data_full_path, allow_pickle=True)
        if 'cifar' in dataset or 'celeb' in dataset:
            return data['xs'], data['hs'], data['ys'], data['ps']
        else:
            return data['xs'], data['ys'], data['ps']

    if 'cifar' in dataset:
        device = 'cuda:0'
        model = models.resnet152(num_classes=10).to(device)
        model_path = 'data/cifar_resnet152'
        transform = Compose([
            ToTensor(),
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        data = CIFAR10(root='data', download=True, transform=transform)

        if exists(model_path):
            model.load_state_dict(load(model_path))
        else:
            loader = DataLoader(data, batch_size=1024, shuffle=True)
            loss = CrossEntropyLoss()
            optim = AdamW(model.parameters(), amsgrad=True)

            for epoch in range(100):
                epoch_loss = 0
                for _data in loader:
                    xs, ys = _data
                    optim.zero_grad()
                    _loss = loss(model(xs.to(device)), ys.to(device))
                    _loss.backward()
                    optim.step()

                    epoch_loss += _loss.item()
                print('Epoch {}: Loss: {}'.format(epoch + 1, epoch_loss))
            save(model.state_dict(), model_path)

        xs, ys = data.data, np.array(data.targets)
        first_digit, second_digit = int(dataset[5]), int(dataset[7])
        mask = np.logical_or(ys == first_digit, ys == second_digit)
        xs = xs[mask]
        ys = np.array([1 if y == first_digit else 0 for y in ys[mask]])

        loader = DataLoader(data, batch_size=1024)
        hs = None
        model.eval()
        for _data in loader:
            _xs, _ys = _data
            _mask = (_ys == first_digit) | (_ys == second_digit)
            if _mask.any():  # only embed batches that contain the two selected classes
                _hs = resnet_fc(model, _xs[_mask].to(device))
                _hs = _hs.cpu().detach().numpy()
                hs = _hs if hs is None else np.concatenate((hs, _hs), axis=0)

        clf = LogisticRegression(n_jobs=-1, max_iter=100000)
        clf.fit(hs, ys)
        ps = clf.predict_proba(hs)[:, 1]

        np.savez(data_path, xs=xs, hs=hs, ys=ys, ps=ps)
        return xs, hs, ys, ps

    if 'mnist' in dataset:
        xs, ys = fetch('mnist_784')
        first_digit, second_digit = dataset[5], dataset[7]
    elif 'fashion' in dataset:
        xs, ys = fetch('Fashion-MNIST')
        first_digit, second_digit = dataset[7], dataset[9]
    elif 'kuzushi' in dataset:
        xs, ys = fetch('Kuzushiji-MNIST')
        first_digit, second_digit = dataset[7], dataset[9]
    mask = np.logical_or(ys == first_digit, ys == second_digit)
    xs = xs[mask] / 255
    ys = np.array([1 if y == first_digit else 0 for y in ys[mask]])

    clf = LogisticRegression(n_jobs=-1, max_iter=100000)
    clf.fit(xs, ys)
    ps = clf.predict_proba(xs)[:, 1]

    np.savez(data_path, xs=xs, ys=ys, ps=ps)
    return xs, ys, ps
Ejemplo n.º 22
0
class Seq2seqKpGen(object):
    """High level model that handles intializing the underlying network
    architecture, saving, updating examples, and predicting examples.
    """

    # --------------------------------------------------------------------------
    # Initialization
    # --------------------------------------------------------------------------

    def __init__(self, args, word_dict, state_dict=None):
        # Book-keeping.
        self.args = args
        self.word_dict = word_dict
        self.args.vocab_size = len(word_dict)
        self.updates = 0

        self.network = Sequence2Sequence(self.args, self.word_dict)
        if state_dict:
            self.network.load_state_dict(state_dict)

    def activate_fp16(self):
        if not hasattr(self, 'optimizer'):
            self.network.half()  # for testing only
            return
        try:
            global amp
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        # https://github.com/NVIDIA/apex/issues/227
        assert self.optimizer is not None
        self.network, self.optimizer = amp.initialize(self.network,
                                                      self.optimizer,
                                                      opt_level=self.args.fp16_opt_level)

    def init_optimizer(self, optim_state=None, sched_state=None):
        def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
            def lr_lambda(current_step: int):
                if current_step < num_warmup_steps:
                    return float(current_step) / float(max(1.0, num_warmup_steps))
                return 1.0

            return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.network.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {"params": [p for n, p in self.network.named_parameters()
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0},
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)
        self.scheduler = get_constant_schedule_with_warmup(self.optimizer, self.args.warmup_steps)

        if optim_state:
            self.optimizer.load_state_dict(optim_state)
            if self.args.device.type == 'cuda':
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.to(self.args.device)
        if sched_state:
            self.scheduler.load_state_dict(sched_state)

    # --------------------------------------------------------------------------
    # Learning
    # --------------------------------------------------------------------------

    def update(self, ex):
        """Forward a batch of examples; step the optimizer to update weights."""
        if not self.optimizer:
            raise RuntimeError('No optimizer set.')

        # Train mode
        self.network.train()

        source_map, alignment = None, None
        if self.args.copy_attn:
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)
        target_rep = ex['target_rep'].to(self.args.device)
        target_len = ex['target_len'].to(self.args.device)

        # Run forward
        ml_loss, loss_per_token = self.network(source=source_rep,
                                               source_len=source_len,
                                               target=target_rep,
                                               target_len=target_len,
                                               src_map=source_map,
                                               alignment=alignment)

        loss = ml_loss.mean() if self.args.n_gpu > 1 else ml_loss
        if self.args.fp16:
            global amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            clip_grad_norm_(amp.master_params(self.optimizer), self.args.grad_clipping)
        else:
            loss.backward()
            clip_grad_norm_(self.network.parameters(), self.args.grad_clipping)

        self.updates += 1
        self.optimizer.step()
        self.scheduler.step()  # Update learning rate schedule
        self.optimizer.zero_grad()

        loss_per_token = loss_per_token.mean() if self.args.n_gpu > 1 else loss_per_token
        loss_per_token = loss_per_token.item()
        loss_per_token = 10 if loss_per_token > 10 else loss_per_token
        perplexity = math.exp(loss_per_token)

        return {
            'ml_loss': loss.item(),
            'perplexity': perplexity
        }

    # --------------------------------------------------------------------------
    # Prediction
    # --------------------------------------------------------------------------

    def predict(self, ex, replace_unk=False):
        """Forward a batch of examples only to get predictions.
        Args:
            ex: the batch examples
            replace_unk: replace `unk` tokens while generating predictions;
                the raw source passage needed for replacement is taken from ex['source']
        Output:
            predictions: #batch predicted sequences
        """

        def convert_text_to_string(text):
            """ Converts a sequence of tokens (string) in a single string. """
            out_string = text.replace(" ##", "").strip()
            return out_string

        self.network.eval()

        source_map, alignment = None, None
        blank, fill = None, None
        if self.args.copy_attn:
            source_map = make_src_map(ex['src_map']).to(self.args.device)
            alignment = align(ex['alignment']).to(self.args.device)
            blank, fill = collapse_copy_scores(self.word_dict, ex['src_vocab'])

        source_rep = ex['source_rep'].to(self.args.device)
        source_len = ex['source_len'].to(self.args.device)

        decoder_out = self.network(source=source_rep,
                                   source_len=source_len,
                                   target=None,
                                   target_len=None,
                                   src_map=source_map,
                                   alignment=alignment,
                                   max_len=self.args.max_tgt_len,
                                   tgt_dict=self.word_dict,
                                   blank=blank, fill=fill,
                                   source_vocab=ex['src_vocab'])

        dec_probs = torch.exp(decoder_out['dec_log_probs'])
        predictions, scores = tens2sen_score(decoder_out['predictions'], dec_probs,
                                             self.word_dict, ex['src_vocab'])
        if replace_unk:
            for i in range(len(predictions)):
                enc_dec_attn = decoder_out['attentions'][i]
                if self.args.model_type == 'transformer':
                    # tgt_len x num_heads x src_len
                    assert enc_dec_attn.dim() == 3
                    enc_dec_attn = enc_dec_attn.mean(1)
                predictions[i] = replace_unknown(predictions[i], enc_dec_attn,
                                                 src_raw=ex['source'][i].tokens)

        for bidx in range(ex['batch_size']):
            for i in range(len(predictions[bidx])):
                if predictions[bidx][i] == constants.KP_SEP:
                    scores[bidx][i] = constants.KP_SEP
                elif predictions[bidx][i] == constants.PRESENT_EOS:
                    scores[bidx][i] = constants.PRESENT_EOS
                else:
                    assert isinstance(scores[bidx][i], float)
                    scores[bidx][i] = str(scores[bidx][i])

        predictions = [' '.join(item) for item in predictions]
        scores = [' '.join(item) for item in scores]

        present_kps = []
        absent_kps = []
        present_kp_scores = []
        absent_kp_scores = []
        for bidx in range(ex['batch_size']):
            keyphrases = predictions[bidx].split(constants.PRESENT_EOS)
            kp_scores = scores[bidx].split(constants.PRESENT_EOS)
            pkps = (' %s ' % constants.KP_SEP).join(keyphrases[:-1])
            pkp_scores = (' %s ' % constants.KP_SEP).join(kp_scores[:-1])
            akps = keyphrases[-1]
            akp_scores = kp_scores[-1]

            pre_kps = []
            pre_kp_scores = []
            for pkp, pkp_s in zip(pkps.split(constants.KP_SEP),
                                  pkp_scores.split(constants.KP_SEP)):
                pkp = pkp.strip()
                if pkp:
                    pre_kps.append(convert_text_to_string(pkp))
                    t_scores = [float(i) for i in pkp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    pre_kp_scores.append(_score)

            present_kps.append(pre_kps)
            present_kp_scores.append(pre_kp_scores)

            abs_kps = []
            abs_kp_scores = []
            for akp, akp_s in zip(akps.split(constants.KP_SEP),
                                  akp_scores.split(constants.KP_SEP)):
                akp = akp.strip()
                if akp:
                    abs_kps.append(convert_text_to_string(akp))
                    t_scores = [float(i) for i in akp_s.strip().split()]
                    _score = np.prod(t_scores) / len(t_scores)
                    abs_kp_scores.append(_score)

            absent_kps.append(abs_kps)
            absent_kp_scores.append(abs_kp_scores)

        return {
            'present_kps': present_kps,
            'absent_kps': absent_kps,
            'present_kp_scores': present_kp_scores,
            'absent_kp_scores': absent_kp_scores
        }

    # --------------------------------------------------------------------------
    # Saving and loading
    # --------------------------------------------------------------------------

    def save(self, filename):
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        state_dict = copy.copy(network.state_dict())
        params = {
            'state_dict': state_dict,
            'word_dict': self.word_dict,
            'args': self.args,
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    def checkpoint(self, filename, epoch):
        network = self.network.module if hasattr(self.network, "module") \
            else self.network
        params = {
            'state_dict': network.state_dict(),
            'word_dict': self.word_dict,
            'args': self.args,
            'epoch': epoch,
            'updates': self.updates,
            'optim_dict': self.optimizer.state_dict(),
            'sched_dict': self.scheduler.state_dict(),
        }
        try:
            torch.save(params, filename)
        except BaseException:
            logger.warning('WARN: Saving failed... continuing anyway.')

    @staticmethod
    def load(filename, new_args=None):
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        args = saved_params['args']
        if new_args:
            args = override_model_args(args, new_args)
        return Seq2seqKpGen(args, word_dict, state_dict)

    @staticmethod
    def load_checkpoint(filename):
        logger.info('Loading model %s' % filename)
        saved_params = torch.load(
            filename, map_location=lambda storage, loc: storage
        )
        word_dict = saved_params['word_dict']
        state_dict = saved_params['state_dict']
        epoch = saved_params['epoch']
        updates = saved_params['updates']
        optim_dict = saved_params['optim_dict']
        sched_dict = saved_params['sched_dict']
        args = saved_params['args']
        model = Seq2seqKpGen(args, word_dict, state_dict)
        model.updates = updates
        model.init_optimizer(optim_dict, sched_dict)
        return model, epoch

    # --------------------------------------------------------------------------
    # Runtime
    # --------------------------------------------------------------------------

    def to(self, device):
        self.network = self.network.to(device)

    def parallelize(self):
        self.network = torch.nn.DataParallel(self.network)
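
# A minimal sketch of the weight-decay grouping used in init_optimizer() above: parameters
# whose names contain "bias" or "LayerNorm.weight" are excluded from decay. TinyNet and
# the 0.01 decay value are placeholders, not parts of the Seq2seqKpGen model.
import torch.nn as nn
from torch.optim import AdamW

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.LayerNorm = nn.LayerNorm(16)

network = TinyNet()
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in network.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in network.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-4)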
Ejemplo n.º 23
0
class Distiller:
    def __init__(self, params, dataloader, student, teacher, device):
        # Initializing Distiller
        self.params = params
        self.dump_path = params["dump_path"]
        self.student = student
        self.teacher = teacher
        self.device = device
        self.dataloader = dataloader

        self.temperature = params["temperature"]
        assert self.temperature > 0.0

        self.alpha_ce = params["alpha_ce"]
        self.alpha_mlm = params["alpha_mlm"]
        self.alpha_mse = params["alpha_mse"]
        self.alpha_cos = params["alpha_cos"]

        self.mlm_mask_prop = params["mlm_mask_prop"]
        assert 0.0 <= self.mlm_mask_prop <= 1.0

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0

        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        #  Initializing model optimizer
        assert params["gradient_accumulation_steps"] >= 1
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params["gradient_accumulation_steps"] *
                params["n_epoch"]) + 1)

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                params["weight_decay"],
            },
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                0.0,
            },
        ]

        self.optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=params["learning_rate"],
            eps=params["adam_epsilon"],
            betas=(0.9, 0.98),
        )

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params["warmup_prop"])
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps,
        )

    def train(self):
        """
        The real training loop.
        """
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params["n_epoch"]):
            iter_bar = tqdm(self.dataloader, desc="-Iter")
            for batch in iter_bar:
                # batch = tuple(t.to(device) for t in batch)
                b_input_ids = batch["input_ids"].to(self.device)
                b_labels = batch["labels"].to(self.device)

                b_bool_attn_mask = batch["input_ids"] != 0
                # Tensor.to() is not in-place, so the result must be reassigned
                b_bool_attn_mask = b_bool_attn_mask.to(self.device)

                self.step(
                    input_ids=b_input_ids,
                    attention_mask=b_bool_attn_mask,
                    lm_labels=b_labels,
                )

                iter_bar.update()
            iter_bar.close()
            self.end_epoch()

        self.save_checkpoint(checkpoint_name="pytorch_model.bin")
        print("Training is finished")

    def step(self, input_ids, attention_mask, lm_labels):
        s_output = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )  # (bs, seq_length, voc_size)
        s_logits, s_hidden_states = s_output["logits"], s_output[
            "hidden_states"]
        with torch.no_grad():
            t_output = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            t_logits, t_hidden_states = t_output["logits"], t_output[
                "hidden_states"]

        assert s_logits.size() == t_logits.size()

        mask = ((lm_labels > -1).unsqueeze(-1).expand_as(s_logits)
                )  # (bs, seq_length, voc_size)
        # or  mask = attention_mask.unsqueeze(-1).expand_as(s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(
            s_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(
            t_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = (self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature, dim=-1),
        ) * (self.temperature)**2)
        loss = self.alpha_ce * loss_ce

        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)),
                                        lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm

        if self.alpha_mse > 0.0:
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(
                    0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]  # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_hidden_states)  # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(
                s_hidden_states, mask)  # (bs * seq_length * dim)
            s_hidden_states_slct = s_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(
                t_hidden_states, mask)  # (bs * seq_length * dim)
            t_hidden_states_slct = t_hidden_states_slct.view(
                -1, dim)  # (bs * seq_length, dim)

            target = s_hidden_states_slct.new(
                s_hidden_states_slct.size(0)).fill_(1)  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct,
                                            t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.0:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if torch.isnan(loss).any():
            print("NaN detected")
            exit()

        if self.params["gradient_accumulation_steps"] > 1:
            loss = loss / self.params["gradient_accumulation_steps"]

        loss.backward()
        self.iter()
        if self.n_iter % self.params["gradient_accumulation_steps"] == 0:
            torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                           self.params["max_grad_norm"])
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving.
        """

        self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the current state. Only by the master process.
        """
        mdl_to_save = (self.student.module
                       if hasattr(self.student, "module") else self.student)
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
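
# A minimal sketch of the temperature-scaled KL term computed in step() above. Random
# logits stand in for the masked student/teacher outputs; the temperature is an assumption.
import torch
import torch.nn.functional as F
from torch import nn

temperature = 2.0
ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
s_logits_slct = torch.randn(8, 100)  # (bs * seq_length, voc_size) after masking
t_logits_slct = torch.randn(8, 100)

loss_ce = ce_loss_fct(
    F.log_softmax(s_logits_slct / temperature, dim=-1),
    F.softmax(t_logits_slct / temperature, dim=-1),
) * (temperature ** 2)  # rescale so gradient magnitudes stay comparable across temperatures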
Ejemplo n.º 24
0
class TD3Agent(AgentBase):
    """
    Twin Delayed Deep Deterministic (TD3) Policy Gradient.

    In short, it is a slightly modified/improved version of DDPG. Compared to the DDPG in this package,
    which uses Gaussian noise, this TD3 uses an Ornstein–Uhlenbeck process for its exploration noise.
    """

    name = "TD3"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 noise_scale: float = 0.2,
                 noise_sigma: float = 0.1,
                 **kwargs):
        """
        Parameters:
            state_size (int): Number of input dimensions.
            action_size (int): Number of output dimensions
            noise_scale (float): Added noise amplitude. Default: 0.2.
            noise_sigma (float): Added noise variance. Default: 0.1.

        Keyword parameters:
            hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (128, 128).
            actor_lr (float): Learning rate for the actor (policy). Default: 0.003.
            critic_lr (float): Learning rate for the critic (value function). Default: 0.003.
            gamma (float): Discount value. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.02.
            actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hidden_layers`.
            critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hidden_layers`.
            max_grad_norm_actor (float): Maximum norm value for actor gradient. Default: 100.
            max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100.
            batch_size (int): Number of samples used in learning. Default: 64.
            buffer_size (int): Maximum number of samples to store. Default: 1e6.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            update_freq (int): Number of steps between each learning step. Default 1.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            action_min (float): Minimum returned action value. Default: -1.
            action_max (float): Maximum returned action value. Default: 1.
            action_scale (float): Multiplier value for action. Default: 1.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(
            kwargs, "device", DEVICE)  # Default device is CUDA if available

        # Reason sequence initiation.
        self.state_size = state_size
        self.action_size = action_size

        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        self.actor = ActorBody(state_size,
                               action_size,
                               hidden_layers=hidden_layers).to(self.device)
        self.critic = DoubleCritic(state_size,
                                   action_size,
                                   CriticBody,
                                   hidden_layers=hidden_layers).to(self.device)
        self.target_actor = ActorBody(state_size,
                                      action_size,
                                      hidden_layers=hidden_layers).to(
                                          self.device)
        self.target_critic = DoubleCritic(state_size,
                                          action_size,
                                          CriticBody,
                                          hidden_layers=hidden_layers).to(
                                              self.device)

        # Noise sequence initiation
        # self.noise = GaussianNoise(shape=(action_size,), mu=1e-8, sigma=noise_sigma, scale=noise_scale, device=device)
        self.noise = OUProcess(shape=action_size,
                               scale=noise_scale,
                               sigma=noise_sigma,
                               device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-3))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-3))
        self.actor_optimizer = AdamW(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = AdamW(self.critic.parameters(), lr=critic_lr)
        self.max_grad_norm_actor: float = float(
            kwargs.get("max_grad_norm_actor", 100))
        self.max_grad_norm_critic: float = float(
            kwargs.get("max_grad_norm_critic", 100))
        self.action_min = float(self._register_param(kwargs, 'action_min',
                                                     -1.))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1.))
        self.action_scale = float(
            self._register_param(kwargs, 'action_scale', 1.))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = ReplayBuffer(self.batch_size, self.buffer_size)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.update_policy_freq = int(
            self._register_param(kwargs, 'update_policy_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.noise_reset_freq = int(
            self._register_param(kwargs, 'noise_reset_freq', 10000))

        # Breath, my child.
        self.reset_agent()
        self.iteration = 0
        self._loss_actor = 0.
        self._loss_critic = 0.

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def reset_agent(self) -> None:
        self.actor.reset_parameters()
        self.critic.reset_parameters()
        self.target_actor.reset_parameters()
        self.target_critic.reset_parameters()

    def act(self,
            state,
            epsilon: float = 0.0,
            training_mode=True) -> List[float]:
        """
        Agent acting on observations.

        When training_mode is True (default), exploration noise is added to each action.
        """
        # Epsilon greedy
        if self._rng.random() < epsilon:
            rnd_actions = torch.rand(self.action_size) * (
                self.action_max - self.action_min) + self.action_min  # uniform in [action_min, action_max]
            return rnd_actions.tolist()

        with torch.no_grad():
            state = to_tensor(state).float().to(self.device)
            action = self.actor(state)
            if training_mode:
                action += self.noise.sample()
            return (self.action_scale * torch.clamp(action, self.action_min,
                                                    self.action_max)).tolist()

    def target_act(self, staten, noise: float = 0.0):
        with torch.no_grad():
            staten = to_tensor(staten).float().to(self.device)
            action = self.target_actor(staten) + noise * self.noise.sample()
            return torch.clamp(action, self.action_min,
                               self.action_max).cpu().numpy().astype(
                                   np.float32)

    def step(self, state, action, reward, next_state, done):
        self.iteration += 1
        self.buffer.add(state=state,
                        action=action,
                        reward=reward,
                        next_state=next_state,
                        done=done)

        if (self.iteration % self.noise_reset_freq) == 0:
            self.noise.reset_states()

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) <= self.batch_size:
            return

        if not (self.iteration % self.update_freq) or not (
                self.iteration % self.update_policy_freq):
            for _ in range(self.number_updates):
                # Note: Inside this there's a delayed policy update.
                #       Every `update_policy_freq` it will learn `number_updates` times.
                self.learn(self.buffer.sample())

    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(
            self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(
            self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)

        if (self.iteration % self.update_freq) == 0:
            self._update_value_function(states, actions, rewards, next_states,
                                        dones)

        if (self.iteration % self.update_policy_freq) == 0:
            self._update_policy(states)

            soft_update(self.target_actor, self.actor, self.tau)
            soft_update(self.target_critic, self.critic, self.tau)

    def _update_value_function(self, states, actions, rewards, next_states,
                               dones):
        # critic loss
        next_actions = self.target_actor.act(next_states)
        Q_target_next = torch.min(
            *self.target_critic.act(next_states, next_actions))
        Q_target = rewards + (self.gamma * Q_target_next * (1 - dones))
        Q1_expected, Q2_expected = self.critic(states, actions)
        loss_critic = mse_loss(Q1_expected, Q_target) + mse_loss(
            Q2_expected, Q_target)

        # Minimize the loss
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

    def _update_policy(self, states):
        # Compute actor loss
        pred_actions = self.actor(states)
        loss_actor = -self.critic(states, pred_actions)[0].mean()
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict()
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
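A minimal usage sketch for the agent above, assuming a continuous-control environment with the classic Gym API (the `gym` import, the environment name and the episode length are assumptions; only `TD3Agent.act` and `TD3Agent.step` come from the example).

import gym

env = gym.make("Pendulum-v1")
agent = TD3Agent(state_size=env.observation_space.shape[0],
                 action_size=env.action_space.shape[0])

state = env.reset()
for _ in range(10_000):
    action = agent.act(state)                       # noisy action clipped to [action_min, action_max]
    next_state, reward, done, _ = env.step(action)  # classic (pre-0.26) Gym step API assumed
    agent.step(state, action, reward, next_state, done)
    state = env.reset() if done else next_state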
Ejemplo n.º 25
0
def train(args):
    # torch.multiprocessing.set_sharing_strategy('file_system')
    # too many barriers / one node data parallel and multiple node DDP
    os.environ['MASTER_ADDR'] = args["master_addr"]
    os.environ['MASTER_PORT'] = args["master_port"]
    os.environ['TOKENIZERS_PARALLELISM'] = "true"
    torch.backends.cudnn.benchmark = True
    rank = args["nr"]
    gpus = args["gpus_per_node"]
    if args["cpu"]:
        assert args["world_size"] == 1
        device = torch.device("cpu")
        barrier = get_barrier(False)
    else:
        dist.init_process_group(args["dist_backend"], rank=rank, world_size=args["world_size"])
        device = torch.device('cuda:0')  # Unique only on individual node.
        torch.cuda.set_device(device)
        barrier = get_barrier(True)

    set_seeds(args["seed"])
    mconf = model_config.to_dict()
    config = dict(md_config=md_config, sm_config=sm_config)[mconf.pop("model_size")]
    tokenizer = get_tokenizer(mconf.pop("tokenizer_name"))
    config.vocab_size = len(tokenizer) + 22
    config.tokenizer_length = 1024
    config.tokenizer_length = config.tokenizer_length - config.num_highway_cls_tokens
    config.max_position_embeddings = config.max_position_embeddings + config.num_highway_cls_tokens

    collate_fn = get_collate_fn(config.num_highway_cls_tokens, tokenizer.pad_token_id)

    model = FastFormerForFusedELECTRAPretraining(config, tokenizer=tokenizer, **mconf).to(device)
    print("Trainable Params = %s" % (numel(model) / 1_000_000))
    if args["pretrained_model"] is not None:
        model.load_state_dict(torch.load(args["pretrained_model"], map_location={'cuda:%d' % 0: 'cuda:%d' % 0}))
    model.data_parallel = True
    # Take model to local rank
    if args["cpu"]:
        ddp_model = model
    else:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        ddp_model = DDP(model, device_ids=[0], find_unused_parameters=True)
    all_params = list(filter(lambda p: p.requires_grad, ddp_model.parameters()))
    optc = optimizer_config.to_dict()
    optimizer = AdamW(all_params, lr=optc["lr"], eps=optc["eps"], weight_decay=optc["weight_decay"], betas=(optc["beta_1"], optc["beta_2"]))
    optimizer.zero_grad()
    scaler = GradScaler()

    model_save_dir = args["model_save_dir"]
    model_save_name = args["model_save_name"]
    if rank == 0:
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
    assert os.path.exists(model_save_dir)
    barrier()
    print("Optimizer Created for Rank = %s" % rank)
    shuffle_dataset = args["shuffle_dataset"]
    sampling_fraction = optc["sampling_fraction"]
    if not args["validate_only"] and not args["test_only"]:
        train_loader = build_dataloader(args["train_dataset"], shuffle_dataset, sampling_fraction, config, collate_fn, tokenizer, world_size=args["world_size"], num_workers=args["num_workers"])

    print("Data Loaded for Rank = %s" % rank)
    validate_every_steps = args["validate_every_steps"]
    log_every_steps = args["log_every_steps"]
    save_every_steps = args["save_every_steps"]
    scheduler = optimization.get_constant_schedule_with_warmup(optimizer, optc["warmup_steps"])
    gradient_clipping = optc["gradient_clipping"]
    _ = model.train()
    barrier()

    start_time = time.time()
    batch_times = []
    model_times = []
    full_times = []
    print("Start Training for Rank = %s" % rank)
    for step, batch in enumerate(train_loader):
        model.zero_grad()
        optimizer.zero_grad()
        if step == 0:
            print("First Batch Training for Rank = %s" % rank)
        # if step <= 39:
        #     continue
        gen_batch_time = time.time() - start_time
        batch_times.append(gen_batch_time)
        if (step + 1) % save_every_steps == 0:
            if rank == 0:
                torch.save(ddp_model.state_dict(), os.path.join(model_save_dir, model_save_name))
            barrier()
        if (step + 1) % validate_every_steps == 0:
            if rank == 0:
                val_results = LargeValidator(args["validation_dataset"], ddp_model, config, device, tokenizer)()
                print("Rank = %s, steps = %s, Val = %s" % (rank, step, val_results))
            barrier()
        record_accuracy = False
        if (step + 1) % log_every_steps == 0:
            record_accuracy = True

        batch["record_accuracy"] = record_accuracy
        labels = batch["label_mlm_input_ids"] if "label_mlm_input_ids" in batch else batch["input_ids"]
        labels = labels.to(device)
        model_start_time = time.time()
        if args["cpu"]:
            output = ddp_model(**batch, labels=labels)
            output = {
                key: [item[key] for item in output]
                for key in functools.reduce(lambda x, y: x.union(y),
                                            (set(dicts.keys()) for dicts in output))
            }
            output = {k: torch.mean(v) for k, v in output.items()}
            loss = output["loss"]
            loss_dict = output["loss_dict"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, gradient_clipping)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        else:
            with autocast():
                output = ddp_model(**batch, labels=labels)
                output = {
                    key: [item[key] for item in output]
                    for key in functools.reduce(lambda x, y: x.union(y),
                                                (set(dicts.keys()) for dicts in output))
                }
                output = {k: torch.mean(v) for k, v in output.items()}
                loss = output["loss"]
                loss_dict = output["loss_dict"]
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(all_params, gradient_clipping)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
        model_end_time = time.time() - model_start_time
        model_times.append(model_end_time)
        full_time = time.time() - start_time
        full_times.append(full_time)
        start_time = time.time()
        if (step + 1) % log_every_steps == 0:
            print("Rank = %s, steps = %s, batch_size = %s, Loss = %s, Accuracy = %s" % (rank, step, batch["input_ids"].size(), loss_dict, output["accuracy_hist"]))
            print("Batch time = %s, Model Time = %s, Full time = %s" % (np.mean(batch_times), np.mean(model_times), np.mean(full_times)))
            batch_times = []
            model_times = []
            full_times = []
            clean_memory()
            barrier()



    # Take inputs to local_rank

    # TODO: validate on multigpu, sort the val datasets alphabetically and let the gpu with rank == dataset rank in sort pick up the dataset. GPUs with rank > len(datasetDict) stay idle.
    # TODO: select one dataset and make full batch from it, this way rebalancing can be easy.
    # TODO: dataset rebalancing.
    # TODO: save model only in local_rank == 0 process
    # TODO: Check if all initialised model weights are same??
    # I've been tracking an ema of sample training loss during training and using that to guide weighted data sampling (rather than the typical uniform sampling). Seems to help with a variety of real world datasets where the bulk of the data is often very similar and easy to learn but certain subpopulations are much more challenging.

    pass
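The closing comment above suggests sampling training data in proportion to an exponential moving average of per-sample loss instead of uniformly. One hedged way to realise that with stock PyTorch is a `WeightedRandomSampler` whose weights are refreshed from the loss EMA; `train_dataset`, `collate_fn`, the batch size and the decay value are assumptions, not part of the original script.

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

ema_loss = torch.ones(len(train_dataset))  # one sampling weight per training sample

def update_ema(indices, per_sample_loss, decay=0.9):
    # Called after each forward pass with the dataset indices of the batch.
    ema_loss[indices] = decay * ema_loss[indices] + (1 - decay) * per_sample_loss.detach().cpu()

def make_loader():
    # Rebuild the sampler each epoch so it sees the latest loss-based weights.
    sampler = WeightedRandomSampler(weights=ema_loss, num_samples=len(train_dataset), replacement=True)
    return DataLoader(train_dataset, batch_size=32, sampler=sampler, collate_fn=collate_fn)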
Ejemplo n.º 26
0
def train(args, train_dataset, model):
    """Train the model for `args.num_train_epochs` epochs over the training set."""
    logger.debug('start')

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    print('train_batch_size %d' % args.train_batch_size)

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    # Prepare optimizer and schedule (linear warmup and decay)
    # Parameters that are excluded from weight decay
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    bert_param_optimizer = list(model.bert.named_parameters())
    crf_param_optimizer = list(model.crf.named_parameters())
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in bert_param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay,
            'lr': args.bert_lr,
        },
        {
            'params': [p for n, p in bert_param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
            'lr': args.bert_lr,
        },
        {
            'params': [p for n, p in crf_param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay,
            'lr': args.crf_lr,
        },
        {
            'params': [p for n, p in crf_param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
            'lr': args.crf_lr,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters)

    # args.warmup_steps = int(t_total * args.warmup_proportion)
    # scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
    #                                             num_training_steps=t_total)
    # scheduler.step()

    # Train!
    logger.info("***** Running training *****")

    global_step = 0
    for epoch in range(int(args.num_train_epochs)):

        for step, batch_data in enumerate(train_dataloader):
            # set model to training mode
            model.train()

            batch_data = tuple(t.to(args.device) for t in batch_data)
            batch_input_ids, batch_input_mask, batch_segment_ids, batch_label_ids = batch_data

            optimizer.zero_grad()

            outputs = model(input_ids=batch_input_ids,
                            attention_mask=batch_input_mask,
                            token_type_ids=batch_segment_ids,
                            labels=batch_label_ids)

            loss, scores = outputs[:2]

            loss.backward()
            optimizer.step()
            if step % 5 == 0:
                print('epoch: {} | step: {} | loss: {}'.format(
                    epoch, step, loss.item()))

            global_step += 1

    torch.save(model.state_dict(), args.modelfile_finetuned)

    return global_step
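The grouped-parameter construction above is a common way to exclude biases and LayerNorm weights from weight decay. A compact, behaviour-equivalent helper is sketched below under the same assumptions (a `model` exposing `bert`/`crf` sub-modules and an `args` namespace carrying the learning rates and decay).

def grouped_parameters(named_params, weight_decay, lr,
                       no_decay=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    named_params = list(named_params)
    return [
        {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay, 'lr': lr},
        {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': lr},
    ]

optimizer = AdamW(
    grouped_parameters(model.bert.named_parameters(), args.weight_decay, args.bert_lr)
    + grouped_parameters(model.crf.named_parameters(), args.weight_decay, args.crf_lr)
)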
Ejemplo n.º 27
0
def train():
    """ Train the model using the parameters defined in the config file """
    print('Initialising ...')
    cfg = TrainConfig()
    checkpoint_folder = 'checkpoints/{}/'.format(cfg.experiment_name)

    if not os.path.exists(checkpoint_folder):
        os.makedirs(checkpoint_folder)

    tb_folder = 'tb/{}/'.format(cfg.experiment_name)
    if not os.path.exists(tb_folder):
        os.makedirs(tb_folder)

    writer = SummaryWriter(logdir=tb_folder, flush_secs=30)
    model = ParrotModel().cuda().train()
    optimiser = AdamW(model.parameters(),
                      lr=cfg.initial_lr,
                      weight_decay=cfg.weight_decay)

    train_dataset = ParrotDataset(cfg.train_labels, cfg.mp3_folder)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.batch_size,
                              num_workers=cfg.workers,
                              collate_fn=parrot_collate_function,
                              pin_memory=True)

    val_dataset = ParrotDataset(cfg.val_labels, cfg.mp3_folder)
    val_loader = DataLoader(val_dataset,
                            batch_size=cfg.batch_size,
                            num_workers=cfg.workers,
                            collate_fn=parrot_collate_function,
                            shuffle=False,
                            pin_memory=True)

    epochs = cfg.epochs
    init_loss, step = 0., 0
    avg_loss = AverageMeter()
    print('Starting training')
    for epoch in range(epochs):
        loader_length = len(train_loader)
        epoch_start = time.time()

        for batch_idx, batch in enumerate(train_loader):
            optimiser.zero_grad()

            # VRAM control by skipping long examples
            if batch['spectrograms'].shape[-1] > cfg.max_time:
                continue

            # inference
            target = batch['targets'].cuda()
            model_input = batch['spectrograms'].cuda()
            model_output = model(model_input)

            # loss
            input_lengths = batch['input_lengths'].cuda()
            target_lengths = batch['target_lengths'].cuda()
            loss = ctc_loss(model_output, target, input_lengths,
                            target_lengths)
            loss.backward()

            if epoch == 0 and batch_idx == 0:
                init_loss = loss.item()  # keep a plain float, not a tensor tied to the graph

            # logging
            elapsed = time.time() - epoch_start
            progress = batch_idx / loader_length
            est = datetime.timedelta(
                seconds=int(elapsed / progress)) if progress > 0.001 else '-'
            avg_loss.update(loss.item())
            suffix = '\tloss {:.4f}/{:.4f}\tETA [{}/{}]'.format(
                avg_loss.avg, init_loss,
                datetime.timedelta(seconds=int(elapsed)), est)
            printProgressBar(batch_idx,
                             loader_length,
                             suffix=suffix,
                             prefix='Epoch [{}/{}]\tStep [{}/{}]'.format(
                                 epoch, epochs, batch_idx, loader_length))

            writer.add_scalar('Steps/train_loss', loss.item(), step)

            # saving the model
            if step % cfg.checkpoint_every == 0:
                test_name = '{}/test_epoch{}.mp3'.format(
                    checkpoint_folder, epoch)
                test_mp3_file(cfg.test_mp3, model, test_name)
                checkpoint_name = '{}/epoch_{}.pth'.format(
                    checkpoint_folder, epoch)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'epoch': epoch,
                        'batch_idx': loader_length,
                        'step': step,
                        'optimiser': optimiser.state_dict()
                    }, checkpoint_name)

            # validating
            if step % cfg.val_every == 0:
                val(model, val_loader, writer, step)
                model = model.train()

            step += 1
            optimiser.step()

        # end of epoch
        print('')
        writer.add_scalar('Epochs/train_loss', avg_loss.avg, epoch)
        avg_loss.reset()
        test_name = '{}/test_epoch{}.mp3'.format(checkpoint_folder, epoch)
        test_mp3_file(cfg.test_mp3, model, test_name)
        checkpoint_name = '{}/epoch_{}.pth'.format(checkpoint_folder, epoch)
        torch.save(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'batch_idx': loader_length,
                'step': step,
                'optimiser': optimiser.state_dict()
            }, checkpoint_name)

    # finished training
    writer.close()
    print('Training finished :)')
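The CTC loss used above is shape-sensitive: PyTorch's `nn.CTCLoss` expects log-probabilities of shape `(T, N, C)` together with per-sample input and target lengths. A minimal sketch with dummy tensors (the sizes below are arbitrary assumptions):

import torch
import torch.nn as nn

T, N, C, S = 50, 4, 28, 10                               # time steps, batch, classes, max target length
log_probs = torch.randn(T, N, C).log_softmax(dim=2)      # model output after log-softmax
targets = torch.randint(1, C, (N, S), dtype=torch.long)  # class 0 reserved for the CTC blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(5, S + 1, (N,), dtype=torch.long)

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)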
Ejemplo n.º 28
0
class Trainer:

    def __init__(self, eval_df, train_df, max_length, batch_size, n_class, name_model):
        self.model = sentiment_analysis(n_class, name_model)
        self.tokenizer = self.model.tokenizer
        self.eval_df = eval_df
        self.train_df = train_df
        self.max_length = max_length
        self.batch_size = batch_size
        self.eval_dataloader = create_data_loader(self.eval_df, self.tokenizer, max_len=self.max_length, batch_size=self.batch_size)
        self.train_dataloader = create_data_loader(self.train_df, self.tokenizer, max_len=self.max_length, batch_size=self.batch_size)
        self.optimizer = AdamW(self.model.parameters(), lr=2e-5)
        self.loss_fn = nn.CrossEntropyLoss()
    


    def train_epoch(self):
        losses = []
        train_correct_predictions = []
        self.model.train()
        for data in self.train_dataloader:
            input_ids = data['input_ids']
            attention_mask = data['attention_mask']
            targets = data['targets']
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.max(outputs, dim=-1)[1]
            # Calculate metrics; store plain floats so the graph is not kept alive
            loss = self.loss_fn(outputs, targets)
            losses.append(loss.item())
            for i in range(len(preds)):
                train_correct_predictions.append(1 if preds[i] == targets[i] else 0)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
        return sum(train_correct_predictions) / float(len(train_correct_predictions)), sum(losses) / float(len(losses))



    def eval_model(self):
        self.model.eval()
        losses_eval = []
        correct_predictions = []

        with torch.no_grad():
            for d in self.eval_dataloader:
                input_ids = d['input_ids']
                attention_mask = d['attention_mask']
                targets = d['targets']
                outputs = self.model(input_ids=input_ids,
                                     attention_mask=attention_mask)
                _, preds = torch.max(outputs, dim=1)
                loss = self.loss_fn(outputs, targets)
                losses_eval.append(loss.item())
                for i in range(len(preds)):
                    correct_predictions.append(1 if preds[i] == targets[i] else 0)

        return sum(correct_predictions) / float(len(correct_predictions)), sum(losses_eval) / float(len(losses_eval))



    def train(self, EPOCHS):

        best_accuracy = 0
        history = defaultdict(list)
        total_steps = len(self.train_dataloader) * EPOCHS
        print(total_steps)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=0, num_training_steps=total_steps)

        for number_epochs in range(EPOCHS):
            train_acc, train_loss = self.train_epoch()

            print(f'Train loss {train_loss} accuracy {train_acc}')

            val_acc, val_loss = self.eval_model()

            print(f'eval loss {val_loss} accuracy {val_acc}')

            history['train_acc'].append(train_acc)
            history['train_loss'].append(train_loss)
            history['val_acc'].append(val_acc)
            history['val_loss'].append(val_loss)

            if val_acc > best_accuracy:
                self.model.save('best_model_state')
                best_accuracy = val_acc
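A condensed sketch of the optimizer/scheduler pairing this trainer relies on: AdamW stepped once per batch, with `get_linear_schedule_with_warmup` from `transformers` sized to the total number of training steps. `model`, `train_dataloader`, `compute_loss` and `epochs` are placeholders, and whether AdamW comes from `torch.optim` or `transformers` does not change the pattern.

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

for _ in range(epochs):
    for batch in train_dataloader:
        loss = compute_loss(model, batch)  # placeholder for the forward pass + loss
        loss.backward()
        optimizer.step()
        scheduler.step()   # the linear-warmup schedule advances once per optimizer step
        optimizer.zero_grad()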
Ejemplo n.º 29
0
        # x_df = x_df.style.background_gradient(cmap='Greys', axis=None, subset=slice(0,10))

        placeholders_[0][0].write(x_df)

        y_df = pd.DataFrame(data=y.detach().numpy())
        y_df = y_df.style.background_gradient(cmap='Greys', axis=None)
        placeholders_[1][0].write(y_df)
        output = net(x.flatten()).reshape((3, 4))
        loss = criterion(output, y)

        out_df = pd.DataFrame(data=output.detach().numpy())
        out_df = out_df.style.background_gradient(cmap='Greys', axis=None)
        placeholders_[2][0].write(out_df)
        print(f'Loss: {loss.detach()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        params = net.parameters()
        # print(list(enumerate(params)))
        for i, param in enumerate(params):
            if i == 0:
                p_df = pd.DataFrame(data=param.reshape(-1, 3).detach().numpy())
                p_df = p_df.style.background_gradient(cmap='Greys', axis=None)
                placeholders[i][0].write(p_df)
            else:
                placeholders[i][0].write(param.detach().numpy())
        # print(params)
        # exit()
        # placeholders[0][0].write()
    def train(self,
              model,
              epochs_num=1,
              train_dataset=None,
              validation_dataset=None,
              data_collator=None,
              parent_information=None,
              lr=0.01,
              batch_size=64,
              weight_decay=0.01,
              betas=(0.9, 0.999),
              evaluate_steps=40,
              has_parent=True,
              verbose=False):
        '''
        Train the given model on the provided dataset. Evaluation is run on the validation set every
        `evaluate_steps` training steps, and at the end of each epoch.

        Args:
          model: instantiated model to train
          epochs_num: Number of epochs to train
          train_dataset: Train dataset
          validation_dataset: Validation dataset
          data_collator: A data collator function used by the DataLoader to collate batches
          parent_information:
          lr: Learning rate to use in the optimizer
          batch_size: Batch size to use
          weight_decay: Optimizer weight decay
          betas: Betas used in the optimizer
          evaluate_steps: Number of training steps between evaluations on the validation set
          has_parent: If True and the batch provides parent labels, they are passed to the model
          verbose: If True, the training loss and additional f1 scores are printed every step
        Returns:
          A tuple of (validation predictions, child validation labels, parent validation labels,
          train losses, validation losses, validation f1 scores, validation precisions,
          validation recalls) collected during training.
        '''

        self.model = model

        # Prints additional loss and metrics information during training if set to true
        self.verbose = verbose

        # Set timers
        start = time.time()
        remaining_time = 0

        # Get dataloader
        self.data_collator = data_collator
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      collate_fn=self.data_collator,
                                      shuffle=True)
        # Default optimizer
        optimizer = AdamW(model.parameters(),
                          lr=lr,
                          weight_decay=weight_decay,
                          betas=betas)

        mb = master_bar(range(epochs_num))
        pb = progress_bar(train_dataloader, parent=mb)
        for epoch in mb:
            for i_batch, sample_batched in enumerate(pb):
                self.model.train()

                # Get input
                x = sample_batched[0].to(self.device)

                #if i_batch == 0:
                #print()
                #print(x.size())

                # Get targets (labels)
                target = sample_batched[1].float().to(self.device)

                if has_parent and (len(sample_batched) == 3):
                    parent_labels = sample_batched[2].float().to(self.device)

                    if self.device == 'cuda':
                        self.model.cuda(0)
                        x = x.cuda(0)
                        parent_labels = parent_labels.cuda(0)
                        target = target.cuda(0)
                    else:
                        self.model.cpu()
                        x = x.cpu()
                        parent_labels = parent_labels.cpu()
                        target = target.cpu()

                    # Pass input to model
                    output = self.model(x, parent_labels)
                else:
                    if self.device == 'cuda':
                        self.model.cuda(0)
                        x = x.cuda(0)
                        target = target.cuda(0)
                    else:
                        self.model.cpu()
                        x = x.cpu()
                        target = target.cpu()

                    # Pass input to model
                    output = self.model(x)

                # Loss
                train_loss = self.criterion(output, target)

                if self.verbose:
                    print(f'train_loss: {train_loss}')

                # Do backward, do step and zero gradients
                train_loss.backward()
                optimizer.step()
                model.zero_grad()
                optimizer.zero_grad()

                # Evaluate
                if (i_batch > 0) and (i_batch % evaluate_steps) == 0:
                    #print('\nevaluating...')
                    _ = self.evaluate(self.model, validation_dataset)

                self.train_losses.append(train_loss.item())

            # Run evaluation at the end of each epoch and return validation outputs
            #print('\nEnd of epoch evaluation results:')
            validation_outputs = self.evaluate(model, validation_dataset)
            y_hat_validation, validation_labels_child, validation_labels_parent = validation_outputs

            # Print out progress stats
            end = time.time()
            remaining_time = remaining_time * 0.90 + (
                (end - start) * (epochs_num - epoch + 1) / (epoch + 1)) * 0.1
            remaining_time_corrected = remaining_time / (1 -
                                                         (0.9**(epoch + 1)))
            epoch_str = "last epoch finished: " + str(epoch + 1)
            progress_str = "progress: " + str(
                (epoch + 1) * 100 / epochs_num) + "%"
            time_str = "time: " + str(remaining_time_corrected / 60) + " mins"
            sys.stdout.write("\r" + epoch_str + " -- " + progress_str +
                             " -- " + time_str)
            sys.stdout.flush()

            self.epochs.append(epoch)

        print("\n" + "Training completed. Total training time: " +
              str(round((end - start) / 60, 2)) + " mins")
        return (y_hat_validation, validation_labels_child,
                validation_labels_parent, self.train_losses,
                self.validation_losses, self.f1_scores_validations,
                self.precisions_validations, self.recalls_validations)
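The docstring above promises mean f1/precision/recall over all labels. A hedged sketch of how such macro-averaged metrics can be computed from multi-label logits follows; scikit-learn and the 0.5 threshold are assumptions, since the original `evaluate` method is not shown here.

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

def mean_label_metrics(logits: np.ndarray, labels: np.ndarray, threshold: float = 0.5):
    """Macro-averaged F1/precision/recall for multi-label predictions."""
    probs = 1.0 / (1.0 + np.exp(-logits))       # sigmoid over per-label logits
    preds = (probs >= threshold).astype(int)    # binarise each label independently
    return (f1_score(labels, preds, average='macro', zero_division=0),
            precision_score(labels, preds, average='macro', zero_division=0),
            recall_score(labels, preds, average='macro', zero_division=0))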