Exemplo n.º 1
0
    def predict(self, args, eval_features, C_eval_true, topk=10, verbose=True):
        '''Prediction interface.

        Scores every example in ``eval_features`` in order, keeps the ``topk``
        highest-scoring clusters of each example in a sparse prediction matrix,
        and evaluates that matrix against ``C_eval_true`` with
        ``rf_linear.Metrics.generate``.

        Args:
            args: run configuration; ``eval_batch_size`` and ``only_topk`` are read.
            eval_features: non-empty sequence of feature objects exposing
                ``input_ids``, ``input_mask``, ``segment_ids``, ``output_ids``
                and ``output_mask``.
            C_eval_true: ground-truth matrix handed to ``Metrics.generate``.
            topk: number of clusters kept per example; capped at the width of
                the model output so ``Tensor.topk`` cannot fail when it is larger.
            verbose: when True, log precision and recall.

        Returns:
            ``(eval_loss, eval_metrics, C_eval_pred)``: the example-weighted mean
            of ``self.criterion`` over all batches (accumulated from the
            criterion's outputs, so a tensor when the criterion returns one), the
            metrics object, and a sorted ``(num_examples, self.num_clusters)``
            CSR matrix holding the kept scores of each example.

        Raises:
            ValueError: if ``eval_features`` is None or empty.
        '''
        if eval_features is None or len(eval_features) == 0:
            raise ValueError("eval_features must be a non-empty sequence")

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.eval_batch_size)
        field_names = ('input_ids', 'input_mask', 'segment_ids', 'output_ids', 'output_mask')
        eval_data = TensorDataset(*[
            torch.tensor([getattr(f, name) for f in eval_features], dtype=torch.long)
            for name in field_names
        ])

        # Run prediction for full data, in order, so row i of the prediction
        # matrix corresponds to eval_features[i]
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        self.model.eval()
        total_loss = 0.
        # Kept as an int: it is the row offset of the next batch. A float offset
        # turns every row index into float32 (inexact beyond 2**24 rows).
        total_example = 0
        rows, cols, vals = [], [], []
        for batch in eval_dataloader:
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, output_ids, output_mask = batch
            cur_batch_size = input_ids.size(0)

            with torch.no_grad():
                c_pred = self.model(input_ids, segment_ids, input_mask)
                c_true = data_utils.repack_output(output_ids, output_mask, self.num_clusters, device)
                loss = self.criterion(c_pred, c_true)
            total_loss += cur_batch_size * loss

            # get topk prediction rows, cols, vals
            k = min(topk, c_pred.size(1))
            cpred_topk_vals, cpred_topk_cols = c_pred.topk(k, dim=1)
            cpred_topk_rows = torch.arange(total_example, total_example + cur_batch_size, dtype=torch.long)
            cpred_topk_rows = cpred_topk_rows.view(cur_batch_size, 1).expand_as(cpred_topk_cols)
            total_example += cur_batch_size

            # append (row/col/val triplets for the COO-style CSR constructor below)
            rows.extend(cpred_topk_rows.flatten().tolist())
            cols.extend(cpred_topk_cols.cpu().flatten().tolist())
            vals.extend(cpred_topk_vals.cpu().flatten().tolist())

        eval_loss = total_loss / total_example
        C_eval_pred = smat.csr_matrix((vals, (rows, cols)), shape=(total_example, self.num_clusters))
        C_eval_pred = rf_util.smat_util.sorted_csr(C_eval_pred, only_topk=None)

        # evaluation
        eval_metrics = rf_linear.Metrics.generate(C_eval_true, C_eval_pred, topk=args.only_topk)
        if verbose:
            logger.info('| matcher_eval_prec {}'.format(' '.join("{:4.2f}".format(100*v) for v in eval_metrics.prec)))
            logger.info('| matcher_eval_recl {}'.format(' '.join("{:4.2f}".format(100*v) for v in eval_metrics.recall)))
            logger.info('-' * 89)
        return eval_loss, eval_metrics, C_eval_pred
Exemplo n.º 2
0
    def train(self, args, trn_features, eval_features=None, C_eval=None):
        """Fine-tune ``self.model`` on ``trn_features``.

        Builds a BertAdam optimizer (or apex FusedAdam wrapped in FP16_Optimizer
        when ``args.fp16``) and runs ``args.num_train_epochs`` epochs over the
        training features. When ``args.stop_by_dev`` is set, the model is
        evaluated with ``self.predict`` every ``args.eval_interval`` steps. It is
        saved via ``self.save(args)`` whenever the mean dev precision improves.

        Args:
            args: run configuration (batch size, learning rate, warmup,
                gradient accumulation, fp16/loss scale, log/eval intervals, ...).
            trn_features: training feature objects exposing ``input_ids``,
                ``input_mask``, ``segment_ids``, ``output_ids`` and ``output_mask``.
            eval_features: dev-set features; needed when ``args.stop_by_dev``.
            C_eval: ground-truth dev matrix forwarded to ``self.predict``.

        Returns:
            self
        """
        # Prepare optimizer
        num_train_optimization_steps = int(
            len(trn_features) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps //= torch.distributed.get_world_size()

        param_optimizer = list(self.model.named_parameters())
        # biases and LayerNorm parameters are excluded from weight decay
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(
                warmup=args.warmup_proportion,
                t_total=num_train_optimization_steps)
        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

        # Start Batch Training
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(trn_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        field_names = ('input_ids', 'input_mask', 'segment_ids', 'output_ids', 'output_mask')
        train_data = TensorDataset(*[
            torch.tensor([getattr(f, name) for f in trn_features], dtype=torch.long)
            for name in field_names
        ])
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        self.model.train()
        global_step = 0
        total_run_time = 0.0
        best_matcher_prec = -1
        # Epochs are numbered 1..num_train_epochs inclusive. The step budget
        # computed above assumes all num_train_epochs epochs actually run.
        for epoch in range(1, args.num_train_epochs + 1):
            tr_loss = 0
            nb_tr_steps = 0
            # wall-clock start of the current logging window (for ms/batch)
            log_start_time = time.time()
            for step, batch in enumerate(train_dataloader):
                start_time = time.time()
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, output_ids, output_mask = batch
                c_pred = self.model(input_ids, segment_ids, input_mask)
                c_true = data_utils.repack_output(output_ids, output_mask,
                                                  self.num_clusters, device)
                loss = self.criterion(c_pred, c_true)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_steps += 1
                total_run_time += time.time() - start_time
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                # print training log
                if step % args.log_interval == 0 and step > 0:
                    # time since the previous log, not just this batch
                    elapsed = time.time() - log_start_time
                    cur_loss = tr_loss / nb_tr_steps
                    logger.info(
                        "| epoch {:3d} | {:4d}/{:4d} batches | ms/batch {:5.4f} | train_loss {:e}"
                        .format(epoch, step, len(train_dataloader),
                                elapsed * 1000 / args.log_interval, cur_loss))
                    log_start_time = time.time()

                # eval on dev set and save best model
                if step % args.eval_interval == 0 and step > 0 and args.stop_by_dev:
                    eval_loss, eval_metrics, _ = self.predict(
                        args,
                        eval_features,
                        C_eval,
                        topk=args.only_topk,
                        verbose=False)
                    logger.info('-' * 89)
                    logger.info(
                        '| epoch {:3d} evaluation | time: {:5.4f}s | eval_loss {:e}'
                        .format(epoch, total_run_time, eval_loss))
                    logger.info('| matcher_eval_prec {}'.format(' '.join(
                        "{:4.2f}".format(100 * v) for v in eval_metrics.prec)))
                    logger.info('| matcher_eval_recl {}'.format(' '.join(
                        "{:4.2f}".format(100 * v)
                        for v in eval_metrics.recall)))

                    avg_matcher_prec = np.mean(eval_metrics.prec)
                    if avg_matcher_prec > best_matcher_prec:
                        logger.info(
                            '| **** saving model at global_step {} ****'.
                            format(global_step))
                        best_matcher_prec = avg_matcher_prec
                        self.save(args)
                    logger.info('-' * 89)
                    # after model.eval() in predict, reset model.train()
                    self.model.train()
                    # don't bill evaluation time to the next logging window
                    log_start_time = time.time()

        return self
Exemplo n.º 3
0
  def train(self, args, trn_features, eval_features=None, C_eval=None):
    """Train the model.

    Fine-tunes ``self.model`` with AdamW and a linear warmup schedule for
    ``args.num_train_epochs`` epochs, or until ``args.max_steps`` optimizer
    steps when that is positive. On the main process (``local_rank`` -1 or 0),
    it logs every ``args.logging_steps`` steps. Every ``args.save_steps`` steps
    it evaluates with ``self.predict`` and calls ``self.save_model(args)``
    whenever the mean dev precision improves.

    Side effects: overwrites ``args.train_batch_size`` (and
    ``args.num_train_epochs`` when ``max_steps > 0``). It may also replace
    ``self.model`` with an apex-initialised, DataParallel or
    DistributedDataParallel wrapper.

    Args:
      args: run configuration (batch sizes, optimizer/schedule settings,
        fp16, device, model_type, logging/save intervals, ...).
      trn_features: training feature objects exposing ``input_ids``,
        ``attention_mask``, ``token_type_ids``, ``output_ids``, ``output_mask``.
      eval_features: dev-set features; needed when ``args.save_steps > 0``.
      C_eval: ground-truth dev matrix forwarded to ``self.predict``.

    Returns:
      self
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    field_names = ('input_ids', 'attention_mask', 'token_type_ids', 'output_ids', 'output_mask')
    train_data = TensorDataset(*[
      torch.tensor([getattr(f, name) for f in trn_features], dtype=torch.long)
      for name in field_names
    ])
    train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
      t_total = args.max_steps
      args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
      t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer (no weight decay on biases and LayerNorm weights)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
      {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
      {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    if args.fp16:
      try:
        from apex import amp
      except ImportError:
        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
      self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
      self.model = torch.nn.DataParallel(self.model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
      self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank],
                                                             output_device=args.local_rank,
                                                             find_unused_parameters=True)

    # Start Batch Training
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(trn_features))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    total_run_time = 0.0
    best_matcher_prec = -1

    self.model.zero_grad()
    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
    # wall-clock start of the current logging window (for ms/batch)
    log_start_time = time.time()
    # Epochs are numbered 1..num_train_epochs inclusive, so all of them run
    for epoch in range(1, int(args.num_train_epochs) + 1):
      for step, batch in enumerate(train_dataloader):
        self.model.train()
        start_time = time.time()
        batch = tuple(t.to(args.device) for t in batch)
        input_ids, attention_mask, token_type_ids, output_ids, output_mask = batch
        model_inputs = {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': None}
        if args.model_type != 'distilbert':
          # XLM and RoBERTa don't use segment_ids; DistilBERT is not passed token_type_ids at all
          model_inputs['token_type_ids'] = token_type_ids if args.model_type in ['bert', 'xlnet'] else None

        outputs = self.model(**model_inputs)
        c_pred = outputs[0]  # if labels=None, then output[0] = logits
        c_true = data_utils.repack_output(output_ids, output_mask,
                                          self.num_clusters, args.device)
        loss = self.criterion(c_pred, c_true)

        if args.n_gpu > 1:
          loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
          loss = loss / args.gradient_accumulation_steps

        if args.fp16:
          with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        else:
          loss.backward()

        tr_loss += loss.item()
        total_run_time += time.time() - start_time
        if (step + 1) % args.gradient_accumulation_steps == 0:
          if args.fp16:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
          else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)

          optimizer.step()
          scheduler.step()  # Update learning rate schedule
          optimizer.zero_grad()
          global_step += 1

          if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
            # print training log; elapsed covers the whole logging window
            elapsed = time.time() - log_start_time
            cur_loss = (tr_loss - logging_loss) / args.logging_steps
            cur_lr = scheduler.get_lr()[0]
            logger.info("| [{:4d}/{:4d}][{:6d}/{:6d}] | {:4d}/{:4d} batches | ms/batch {:5.4f} | train_loss {:6e} | lr {:.6e}".format(
              int(epoch), int(args.num_train_epochs),
              int(global_step), int(t_total),
              int(step), len(train_dataloader),
              elapsed * 1000. / args.logging_steps, cur_loss, cur_lr))
            logging_loss = tr_loss
            log_start_time = time.time()

          if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
            # eval on dev set and save best model
            eval_loss, eval_metrics, _ = self.predict(args, eval_features, C_eval, topk=args.only_topk, verbose=False)
            logger.info('-' * 89)
            logger.info('| epoch {:3d} step {:6d} evaluation | time: {:5.4f}s | eval_loss {:e}'.format(
              epoch, global_step, total_run_time, eval_loss))
            logger.info('| matcher_eval_prec {}'.format(' '.join("{:4.2f}".format(100*v) for v in eval_metrics.prec)))
            logger.info('| matcher_eval_recl {}'.format(' '.join("{:4.2f}".format(100*v) for v in eval_metrics.recall)))

            avg_matcher_prec = np.mean(eval_metrics.prec)
            if avg_matcher_prec > best_matcher_prec:
              logger.info('| **** saving model at global_step {} ****'.format(global_step))
              best_matcher_prec = avg_matcher_prec
              self.save_model(args)
            logger.info('-' * 89)
            # don't bill evaluation time to the next logging window
            log_start_time = time.time()

        # stop as soon as the step budget (== scheduler's t_total) is spent
        if args.max_steps > 0 and global_step >= args.max_steps:
          break
      # leave the epoch loop as well once max_steps is reached
      if args.max_steps > 0 and global_step >= args.max_steps:
        break

    return self
Exemplo n.º 4
0
  def predict(self, args, eval_features, C_eval_true, topk=10, verbose=True):
    """Prediction interface.

    Scores every example in ``eval_features`` in order, keeps the ``topk``
    highest-scoring clusters of each example in a sparse prediction matrix,
    and evaluates it against ``C_eval_true`` with ``rf_linear.Metrics.generate``.
    Overwrites ``args.eval_batch_size`` as a side effect.

    Args:
      args: run configuration; reads ``per_gpu_eval_batch_size``, ``n_gpu``,
        ``device``, ``model_type`` and ``only_topk``.
      eval_features: non-empty sequence of feature objects exposing
        ``input_ids``, ``attention_mask``, ``token_type_ids``, ``output_ids``
        and ``output_mask``.
      C_eval_true: ground-truth matrix handed to ``Metrics.generate``.
      topk: clusters kept per example; capped at the model output width so
        ``Tensor.topk`` cannot fail when it is larger.
      verbose: when True, log precision and recall.

    Returns:
      ``(eval_loss, eval_metrics, C_eval_pred)``: the example-weighted mean of
      ``self.criterion`` over all batches, the metrics object, and a sorted
      ``(num_examples, self.num_clusters)`` CSR matrix of the kept scores.

    Raises:
      ValueError: if ``eval_features`` is None or empty.
    """
    if eval_features is None or len(eval_features) == 0:
      raise ValueError("eval_features must be a non-empty sequence")
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    field_names = ('input_ids', 'attention_mask', 'token_type_ids', 'output_ids', 'output_mask')
    eval_data = TensorDataset(*[
      torch.tensor([getattr(f, name) for f in eval_features], dtype=torch.long)
      for name in field_names
    ])

    # Row i of the prediction matrix must describe eval_features[i], so every
    # example is visited exactly once and in order. A DistributedSampler would
    # shuffle and shard the data, leaving a shorter, permuted prediction matrix
    # that no longer lines up with C_eval_true.
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_features))
    logger.info("  Batch size = %d", args.eval_batch_size)

    self.model.eval()
    total_loss = 0.
    # int row offset of the next batch (a float offset yields float32 row indices)
    total_example = 0
    rows, cols, vals = [], [], []
    for batch in eval_dataloader:
      input_ids, attention_mask, token_type_ids, output_ids, output_mask = (t.to(args.device) for t in batch)
      cur_batch_size = input_ids.size(0)
      model_inputs = {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': None}
      if args.model_type != 'distilbert':
        # XLM and RoBERTa don't use segment_ids; DistilBERT is not passed token_type_ids at all
        model_inputs['token_type_ids'] = token_type_ids if args.model_type in ['bert', 'xlnet'] else None

      with torch.no_grad():
        outputs = self.model(**model_inputs)
        c_pred = outputs[0]
        c_true = data_utils.repack_output(output_ids, output_mask,
                                          self.num_clusters, args.device)
        loss = self.criterion(c_pred, c_true)
        total_loss += cur_batch_size * loss

      # get topk prediction rows, cols, vals
      k = min(topk, c_pred.size(1))
      cpred_topk_vals, cpred_topk_cols = c_pred.topk(k, dim=1)
      cpred_topk_rows = torch.arange(total_example, total_example + cur_batch_size, dtype=torch.long)
      cpred_topk_rows = cpred_topk_rows.view(cur_batch_size, 1).expand_as(cpred_topk_cols)
      total_example += cur_batch_size

      # append
      rows.extend(cpred_topk_rows.flatten().tolist())
      cols.extend(cpred_topk_cols.cpu().flatten().tolist())
      vals.extend(cpred_topk_vals.cpu().flatten().tolist())

    eval_loss = total_loss / total_example
    C_eval_pred = smat.csr_matrix((vals, (rows, cols)), shape=(total_example, self.num_clusters))
    C_eval_pred = rf_util.smat_util.sorted_csr(C_eval_pred, only_topk=None)

    # evaluation
    eval_metrics = rf_linear.Metrics.generate(C_eval_true, C_eval_pred, topk=args.only_topk)
    if verbose:
      logger.info('| matcher_eval_prec {}'.format(' '.join("{:4.2f}".format(100*v) for v in eval_metrics.prec)))
      logger.info('| matcher_eval_recl {}'.format(' '.join("{:4.2f}".format(100*v) for v in eval_metrics.recall)))
      logger.info('-' * 89)

    return eval_loss, eval_metrics, C_eval_pred
Exemplo n.º 5
0
  def predict(self, args, eval_features, C_eval_true, topk=10, get_hidden=False):
    """Prediction interface.

    Scores every example in ``eval_features`` in order, keeps the ``topk``
    highest-scoring clusters of each example in a sparse prediction matrix,
    and evaluates it against ``C_eval_true`` with ``rf_linear.Metrics.generate``.
    Optionally also returns one pooled embedding per example, computed from
    the last hidden layer.

    Side effects: overwrites ``args.eval_batch_size``. When ``args.n_gpu > 1``,
    it wraps ``self.model`` in ``torch.nn.DataParallel`` (once).

    Args:
      args: run configuration; reads ``per_gpu_eval_batch_size``, ``n_gpu``,
        ``device``, ``model_type`` and ``only_topk``.
      eval_features: non-empty sequence of feature objects exposing
        ``input_ids``, ``attention_mask``, ``token_type_ids``, ``output_ids``
        and ``output_mask``.
      C_eval_true: ground-truth matrix handed to ``Metrics.generate``.
      topk: clusters kept per example; capped at the model output width.
      get_hidden: also compute pooled embeddings; requires
        ``self.config.output_hidden_states`` and a ``bert``, ``roberta`` or
        ``xlnet`` model.

    Returns:
      ``(eval_loss, eval_metrics, C_eval_pred, eval_embeddings)``.
      ``eval_embeddings`` is the row-wise concatenation (numpy) of the pooled
      outputs, or None when ``get_hidden`` is False.

    Raises:
      ValueError: if ``eval_features`` is None or empty. Also raised when
        ``get_hidden`` is requested while ``self.config.output_hidden_states``
        is falsy, because the hidden states would not be returned.
      NotImplementedError: if ``get_hidden`` is set for an unsupported
        ``args.model_type``.
    """
    if eval_features is None or len(eval_features) == 0:
      raise ValueError("eval_features must be a non-empty sequence")
    if get_hidden and not self.config.output_hidden_states:
      raise ValueError("get_hidden=True requires config.output_hidden_states=True")

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    field_names = ('input_ids', 'attention_mask', 'token_type_ids', 'output_ids', 'output_mask')
    eval_data = TensorDataset(*[
      torch.tensor([getattr(f, name) for f in eval_features], dtype=torch.long)
      for name in field_names
    ])

    # Row i of the prediction matrix (and of the embeddings) must describe
    # eval_features[i], so every example is visited exactly once and in order.
    # A DistributedSampler would shuffle and shard the data and break that.
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=4)

    # multi-gpu eval
    if args.n_gpu > 1 and not isinstance(self.model, torch.nn.DataParallel):
      self.model = torch.nn.DataParallel(self.model)
    # The pooling submodules live on the wrapped model (exposed as .module by
    # DataParallel); fall back to the model itself when it is not wrapped.
    base_model = getattr(self.model, 'module', self.model)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_features))
    logger.info("  Batch size = %d", args.eval_batch_size)

    self.model.eval()
    total_loss = 0.
    # int row offset of the next batch (a float offset yields float32 row indices)
    total_example = 0
    rows, cols, vals = [], [], []
    all_pooled_output = []
    for batch in eval_dataloader:
      input_ids, attention_mask, token_type_ids, output_ids, output_mask = (t.to(args.device) for t in batch)
      cur_batch_size = input_ids.size(0)
      model_inputs = {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': None}
      if args.model_type != 'distilbert':
        # XLM and RoBERTa don't use segment_ids; DistilBERT is not passed token_type_ids at all
        model_inputs['token_type_ids'] = token_type_ids if args.model_type in ['bert', 'xlnet'] else None

      with torch.no_grad():
        # forward
        outputs = self.model(**model_inputs)
        c_pred = outputs[0]

        # get pooled_output, which is the [CLS] embedding for the document
        if get_hidden:
          last_hidden = outputs[1][-1]
          if args.model_type == 'bert':
            pooled_output = base_model.bert.pooler(last_hidden)
            pooled_output = base_model.dropout(pooled_output)
          elif args.model_type == 'roberta':
            head = base_model.classifier
            pooled_output = head.dropout(last_hidden[:, 0, :])
            pooled_output = torch.tanh(head.dense(pooled_output))
            pooled_output = head.dropout(pooled_output)
          elif args.model_type == 'xlnet':
            pooled_output = base_model.sequence_summary(last_hidden)
          else:
            raise NotImplementedError("unknown args.model_type {}".format(args.model_type))
          all_pooled_output.append(pooled_output.cpu().numpy())

        # get ground true cluster ids
        c_true = data_utils.repack_output(output_ids, output_mask,
                                          self.num_clusters, args.device)
        loss = self.criterion(c_pred, c_true)
        total_loss += cur_batch_size * loss

      # get topk prediction rows, cols, vals
      k = min(topk, c_pred.size(1))
      cpred_topk_vals, cpred_topk_cols = c_pred.topk(k, dim=1)
      cpred_topk_rows = torch.arange(total_example, total_example + cur_batch_size, dtype=torch.long)
      cpred_topk_rows = cpred_topk_rows.view(cur_batch_size, 1).expand_as(cpred_topk_cols)
      total_example += cur_batch_size

      # append
      rows.extend(cpred_topk_rows.flatten().tolist())
      cols.extend(cpred_topk_cols.cpu().flatten().tolist())
      vals.extend(cpred_topk_vals.cpu().flatten().tolist())

    eval_loss = total_loss / total_example
    C_eval_pred = smat.csr_matrix((vals, (rows, cols)), shape=(total_example, self.num_clusters))
    C_eval_pred = rf_util.smat_util.sorted_csr(C_eval_pred, only_topk=None)

    # evaluation
    eval_metrics = rf_linear.Metrics.generate(C_eval_true, C_eval_pred, topk=args.only_topk)
    eval_embeddings = np.concatenate(all_pooled_output, axis=0) if get_hidden else None
    return eval_loss, eval_metrics, C_eval_pred, eval_embeddings