def train(rank,
          args,
          tokenizer,
          train_dataset,
          test_dataset,
          model_s,
          model_t,
          params_to_tune,
          head_importance=None,
          loss_num=-1,
          tune_iter=0):
    """Run one distillation stage of the student ``model_s`` against ``model_t``.

    The loss is the mean squared difference between the hidden states of the
    student and the (index-aligned) hidden states of the teacher, both taken
    from the last element of the model outputs.

    Every 50th step that is also an optimizer step, the averaged losses are
    logged and compared with the last accepted loss: a loss below 1.5x the
    accepted one is saved to ``last_good_state.pth`` in ``args.output_dir``,
    anything else restores the student from that file with a fresh optimizer.

    Args:
        rank: process rank; a negative value means single-process training,
            otherwise ``torch.distributed`` is used for synchronisation.
        args: settings object; the attributes read here are ``device``,
            ``n_gpu``, ``output_dir``, ``num_train_epochs``,
            ``per_gpu_train_batch_size``, ``total_train_batch_size``,
            ``total_train_batch_size_for_tune``, ``learning_rate``,
            ``learning_rate_for_tune`` and ``loss_weight_alpha``.
        tokenizer: passed through to ``save_model`` for epoch checkpoints.
        train_dataset: training data; for a ``BertModelEMB`` student its
            ``q_dataset`` and ``c_dataset`` are iterated one after another.
        test_dataset: data passed to ``evaluate`` after each epoch.
        model_s: student model, updated in place.
        model_t: teacher model, only run under ``torch.no_grad()``.
        params_to_tune: parameters of ``model_s`` (matched by identity) that
            are handed to the optimizer.
        head_importance: optional tensor updated in place with a running
            average of the absolute gradient of an all-ones head mask.
        loss_num: when positive, only the first ``loss_num`` hidden-state
            pairs contribute to the loss, the stage lasts 0.25 epoch and no
            checkpoint/evaluation is done at the end of the epoch.
        tune_iter: when positive, the ``*_for_tune`` settings are preferred,
            the per-layer losses are weighted by ``loss_weight_alpha ** i``
            and the learning rate decays linearly after warm-up.
    """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if rank in [-1, 0]:
        printlog("Train stage: ", train_count)
        printlog(model_s)

    if head_importance is not None:
        #trainable all-ones mask; only its gradient is consumed below
        head_mask = torch.ones(*list(head_importance.shape)).to(args.device)
        head_mask.requires_grad_(requires_grad=True)
    else:
        head_mask = None

    num_train_epochs = args.num_train_epochs
    if loss_num > 0:
        num_train_epochs = 0.25  #short train for incremental loss

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size

    #get total batch size; tune iterations may override the default one
    if tune_iter > 0 and args.total_train_batch_size_for_tune:
        total_train_batch_size = args.total_train_batch_size_for_tune
    else:
        total_train_batch_size = args.total_train_batch_size
    #a total batch smaller than one synchronous step would give 0 here and
    #break every later division/modulo, so accumulate at least one step
    gradient_accumulation_steps = max(
        1, total_train_batch_size // train_batch_size)

    if tune_iter > 0 and args.learning_rate_for_tune:
        learning_rate = args.learning_rate_for_tune
    else:
        learning_rate = args.learning_rate

    if check_model_type(model_s, BertModelEMB):
        #use 2 datasets for embedding question and context separately
        if rank in [-1, 0]:
            printlog("dataset_q size", len(train_dataset.q_dataset))
            printlog("dataset_c size", len(train_dataset.c_dataset))
        datasets = [train_dataset.q_dataset, train_dataset.c_dataset]
    else:
        if rank in [-1, 0]:
            printlog("dataset size", len(train_dataset))
        datasets = [train_dataset]

    if rank > -1:
        #for distributed train use sampler that takes only part of samples for each process
        train_dataloaders = [
            DataLoader(dataset,
                       sampler=torch.utils.data.distributed.DistributedSampler(
                           dataset, rank=rank),
                       batch_size=per_gpu_train_batch_size)
            for dataset in datasets
        ]
    else:
        train_dataloaders = [
            DataLoader(dataset,
                       sampler=RandomSampler(dataset),
                       batch_size=train_batch_size,
                       num_workers=4) for dataset in datasets
        ]

    steps_per_epoch = sum(len(d) for d in train_dataloaders)
    steps_total = int(steps_per_epoch // gradient_accumulation_steps *
                      num_train_epochs)

    # Prepare optimizer and scheduler
    #select the student parameters that are (by identity) in params_to_tune
    tune_ids = {id(p) for p in params_to_tune}
    named_params = [(n, p) for n, p in model_s.named_parameters()
                    if id(p) in tune_ids]

    if rank in [-1, 0]:
        for n, p in named_params:
            printlog('param for tune', n)

    def new_optimizer():
        return AdamW([p for n, p in named_params],
                     lr=learning_rate,
                     eps=1e-08,
                     weight_decay=0.0)

    def lr_lambda(current_step):
        #steps_total can be 0 for very short stages; avoid dividing by zero
        p = float(current_step) / float(max(steps_total, 1))
        warmup = 0.01
        if p < warmup:
            return p / warmup
        p = (p - warmup) / (1 - warmup)
        return 1 if tune_iter == 0 else max(1 - p, 0)

    def new_scheduler(opt, first_step):
        #a fresh LambdaLR counts from 0, so shift it to continue the schedule
        return LambdaLR(opt, lambda step: lr_lambda(step + first_step))

    optimizer = new_optimizer()
    scheduler = new_scheduler(optimizer, 0)

    if rank in [-1, 0]:
        printlog("epoches", num_train_epochs)
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size)
        printlog("n_gpu", args.n_gpu)
        printlog("world_size", world_size)
        printlog("gradient_accumulation_steps", gradient_accumulation_steps)
        printlog("total train batch size",
                 train_batch_size * gradient_accumulation_steps)
        printlog("steps_total", steps_total)

    restore_count = 0
    if rank in [-1, 0]:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
    restore_file = os.path.join(args.output_dir, 'last_good_state.pth')
    restore_loss = None

    losses_list = []

    global_step = 0
    for epoch in range(math.ceil(num_train_epochs)):
        switch_to_train(rank, model_t)
        switch_to_train(rank, model_s)
        model_s.zero_grad()
        utils.sync_models(rank, model_s)

        time_last = time.time()
        for train_dataloader in train_dataloaders:
            printlog("rank", rank, "len(train_dataloader)",
                     len(train_dataloader))
            if rank > -1:
                train_dataloader.sampler.set_epoch(epoch)

            if len(train_dataloaders) > 1:
                # reset last loss to avoid restore due to dataset changing
                printlog("rank", rank, "reset restore_loss")
                restore_loss = None

            for step, batch in enumerate(train_dataloader):
                #fractional epoch; lets a stage stop in the middle of an epoch
                epoch_fp = epoch + step / len(train_dataloader)
                if epoch_fp > num_train_epochs:
                    break

                inputs = {
                    'input_ids': batch[0].to(args.device),
                    'attention_mask': batch[1].to(args.device),
                    'token_type_ids': batch[2].to(args.device)
                }

                outputs_s = model_s(**inputs,
                                    head_mask=head_mask,
                                    output_hidden_states=True)
                losses = []

                with torch.no_grad():
                    outputs_t = model_t(**inputs, output_hidden_states=True)

                #hidden states are expected as the last element of the outputs
                out_s, out_t = outputs_s[-1], outputs_t[-1]

                assert len(
                    out_s
                ) == model_s.config.num_hidden_layers + 1, "can not find hidden states in student model outputs"
                assert len(
                    out_t
                ) == model_t.config.num_hidden_layers + 1, "can not find hidden states in teacher model outputs"
                if len(out_s) != len(out_t):
                    #the student and teacher outputs are not aligned. try to find teacher output for each student output
                    n_s, n_t = len(out_s), len(out_t)
                    out_t = [
                        out_t[(i * (n_t - 1)) // (n_s - 1)] for i in range(n_s)
                    ]
                assert len(out_s) == len(
                    out_t
                ), "can not align number of outputs between student and teacher"
                assert all(
                    s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)
                ), "output shapes for student and teacher are not the same"

                out_pairs = list(zip(out_s, out_t))
                if loss_num > 0:
                    out_pairs = out_pairs[:loss_num]

                losses += [(s - t.detach()).pow(2).mean()
                           for s, t in out_pairs]

                losses_list.append([l.item() for l in losses])

                if tune_iter == 0:
                    loss = sum(losses) / len(losses)
                else:
                    weights = [
                        args.loss_weight_alpha**i for i in range(len(losses))
                    ]
                    losses_w = [w * l for w, l in zip(weights, losses)]
                    loss = sum(losses_w) / sum(weights)

                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()
                #release references to the activations before the next forward
                del out_s
                del out_t
                del outputs_s
                del outputs_t

                if head_importance is not None:
                    #collect gradient statistics to find most valuable heads
                    head_mask.grad.detach_()
                    head_importance += (head_mask.grad.abs().detach() -
                                        head_importance) * 0.001
                    head_mask.grad.zero_()

                if (step + 1) % gradient_accumulation_steps == 0:
                    global_step += 1

                    #sync gradients before calc step
                    utils.sync_grads(rank, named_params, global_step == 1)

                    torch.nn.utils.clip_grad_norm_(
                        [p for n, p in named_params], 1)
                    optimizer.step()
                    scheduler.step()

                    model_s.zero_grad()

                    if (step + 1) % 50 == 0:
                        str_out = "{} ep {:.2f} lrp {:.2f} rc {:02}".format(
                            train_count, epoch_fp,
                            np.log10(scheduler.get_last_lr()[0]),
                            restore_count)
                        ll = np.array(losses_list).mean(0)

                        if rank > -1:
                            #sync indicators
                            llt = torch.tensor(ll).to(args.device)
                            torch.distributed.all_reduce(
                                llt, op=torch.distributed.ReduceOp.SUM)
                            ll = llt.cpu().numpy() / float(world_size)

                        loss = ll.mean()
                        str_out += " loss {:.4f}".format(loss)
                        losses_txt = ["{:.3f}".format(l) for l in ll]
                        if tune_iter > 0:
                            losses_txt = [
                                "{:.2f}x".format(w) + lt
                                for w, lt in zip(weights, losses_txt)
                            ]
                        str_out += " ll " + " ".join(losses_txt)

                        if time_last:
                            dt_iter = (time.time() -
                                       time_last) / len(losses_list)
                            dt_ep = dt_iter * steps_per_epoch
                            str_out += " it {:.1f}s".format(dt_iter)
                            str_out += " ep {:.1f}m".format(dt_ep / (60))
                            str_out += " eta {:.1f}h".format(
                                dt_ep * (num_train_epochs - epoch_fp) /
                                (60 * 60))
                        losses_list = []
                        time_last = time.time()
                        if rank in [-1, 0]:
                            logger.info(str_out)

                        if rank > -1:
                            #sync losses so every rank takes the same save/restore branch
                            loss_tensor = torch.tensor([loss],
                                                       device=args.device)
                            torch.distributed.all_reduce(
                                loss_tensor, op=torch.distributed.ReduceOp.SUM)
                            loss = loss_tensor.item() / world_size

                        if restore_loss is None or loss < restore_loss * 1.5:
                            #good result lets save it
                            restore_loss = loss

                            if rank in [-1, 0]:
                                torch.save(
                                    {
                                        'model_state_dict':
                                        model_s.state_dict(),
                                        'optimizer_state_dict':
                                        optimizer.state_dict()
                                    }, restore_file)
                            if rank > -1:
                                #the file must be complete before any rank may read it
                                torch.distributed.barrier()
                        else:
                            #bad result lets restore
                            restore_count += 1
                            logger.info(
                                "rank {} restore #{} from {} with {} loss".
                                format(rank, restore_count, restore_file,
                                       restore_loss))
                            #map to this process device instead of the device
                            #the checkpoint was saved from
                            checkpoint = torch.load(restore_file,
                                                    map_location=args.device)
                            model_s.load_state_dict(
                                checkpoint['model_state_dict'])
                            #the optimizer state is deliberately dropped, not restored
                            optimizer = new_optimizer()
                            #rebind the schedule to the new optimizer, otherwise
                            #it would keep driving the discarded one
                            scheduler = new_scheduler(optimizer, global_step)
                            switch_to_train(rank, model_s)

        if loss_num <= 0:
            #full stages only: checkpoint and evaluate after every epoch
            if rank in [-1, 0]:
                check_point_name = 'checkpoint-{:02}'.format(train_count)
                save_model(args, model_s, tokenizer, check_point_name)
                check_point_name = check_point_name + '-{:02}'.format(epoch +
                                                                      1)
                switch_to_eval(rank, model_s)
                result_s = evaluate(args, model_s, test_dataset)
                for k, v in result_s.items():
                    logger.info("{} {} {}".format(check_point_name, k, v))
            if rank > -1:
                torch.distributed.barrier()

    if rank in [-1, 0]:
        if os.path.exists(restore_file):
            os.remove(restore_file)
def process(rank, args, port):
    """Run the pack-and-distill pipeline in one (possibly distributed) process.

    Steps, in order: set up the device / process group and seeds, load the
    tokenizer and the teacher and student models, pack the student (unless its
    config already has ``pack_cfg``), replace all ``torch.nn.Dropout`` modules
    by identity modules, load the train/dev datasets, evaluate the teacher,
    run the staged ``train`` calls, and finally (rank -1 or 0 only) save,
    evaluate, merge and export the student to ``model.onnx``.

    Args:
        rank: process rank; a negative value selects single-process mode,
            otherwise an NCCL process group of ``args.n_gpu`` ranks is created.
        args: settings object; ``args.device`` is assigned here.
        port: value written to the ``MASTER_PORT`` environment variable.
    """
    import ast  # local import: only needed to parse the ``pack_emb`` flag

    #init multiprocess
    if rank < 0:
        args.device = torch.device("cpu" if args.n_gpu < 1 else "cuda")
    else:
        # create default process group
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(port)
        torch.distributed.init_process_group("nccl",
                                             rank=rank,
                                             world_size=args.n_gpu)
        args.device = torch.device("cuda:{}".format(rank))
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed_all(args.seed)

    #set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if rank > 0:
        #wait while 0 process load models
        torch.distributed.barrier()

    #the tokenizer comes from the student path, so log that path
    printlog("rank", rank, "load tokenizer", args.model_student)
    tokenizer = BertTokenizer.from_pretrained(args.model_student)

    config = AutoConfig.from_pretrained(args.model_student)

    if hasattr(config, 'pack_cfg') and 'base_class_name' in config.pack_cfg:
        #get model class from pack_cfg
        base_class_name = config.pack_cfg['base_class_name']
        printlog("rank", rank, "base_class_name to pack", base_class_name)
        Model = globals()[base_class_name]
    else:
        #get model class from architectures field of config
        if config.architectures:
            assert len(
                config.architectures
            ) == 1, "only single model is supported but {} has {}".format(
                args.model_student, config.architectures)
            Model = globals()[config.architectures[0]]
        else:
            Model = BertForQuestionAnswering

    printlog(
        "rank", rank,
        "load teacher {} model from {}".format(Model.__name__,
                                               args.model_teacher))
    model_t = Model.from_pretrained(args.model_teacher)

    printlog(
        "rank", rank,
        "load student {} model from {}".format(Model.__name__,
                                               args.model_student))
    model_s = BertPacked(Model).from_pretrained(args.model_student)

    if rank == 0:
        #release other process waiting
        torch.distributed.barrier()

    if rank > -1:
        #sync processes
        torch.distributed.barrier()

    params_packed = []
    if hasattr(model_s.config, 'pack_cfg'):
        logger.warning("rank {} !!!model already packed!!!".format(rank))
        logger.warning(
            "rank {} !!!just continue distill the already packed model!!!".
            format(rank))
    else:
        #args.pack_cfg is a "key:value,key:value" string
        pack_cfg = dict(t.split(':') for t in args.pack_cfg.split(','))
        #parse the flag as a Python literal; eval() on a command line string
        #would execute arbitrary expressions
        pack_cfg['pack_emb'] = bool(ast.literal_eval(pack_cfg['pack_emb']))
        printlog("rank", rank, "pack model by", pack_cfg)
        params_packed = model_s.pack_(pack_cfg)

    model_s.to(args.device)
    model_t.to(args.device)

    utils.sync_models(rank, model_s)
    if rank in [-1, 0]:
        save_model(args, model_s, tokenizer)

    def wrap_dropout(net):
        """Replace every ``torch.nn.Dropout`` child in ``net`` by an identity module."""

        #remove dropout
        class PASS(torch.nn.Module):
            def __init__(self, dropout):
                super().__init__()
                #the original module is kept only for __repr__
                self.dropout = dropout
                self.dropout_enable = False

            def forward(self, x):
                return x

            def __repr__(self):
                return "PASS( dropout_enable {} for {} )".format(
                    self.dropout_enable, self.dropout.__repr__())

        #collect first, then replace, so the module tree is not changed while iterated
        dropout_list = [(n, m, nn, mm) for n, m in net.named_modules()
                        for nn, mm in m._modules.items()
                        if isinstance(mm, torch.nn.Dropout)]
        for n, m, nn, mm in dropout_list:
            m._modules[nn] = PASS(mm)
            logger.info('rank {} {}.{} Dropout in warped by PASS'.format(
                rank, n, nn))

    def params_for_full_tune():
        """Return the encoder parameters plus the linear embedding ones, if any."""
        encoder_params = [
            p for n, p in model_s.named_parameters() if 'encoder.' in n
        ]
        emb_params = [
            p for n, p in model_s.named_parameters()
            if 'embedding' in n and 'linear' in n
        ]
        if emb_params:
            #if has linear then LayerNorm was trained
            emb_params += [
                p for n, p in model_s.named_parameters()
                if 'embedding' in n and 'LayerNorm' in n
            ]
        return emb_params + encoder_params

    logger.info('rank {} warp dropout for teacher model'.format(rank))
    wrap_dropout(model_t)

    logger.info('rank {} warp dropout for student model'.format(rank))
    wrap_dropout(model_s)

    #calculate current number of heads in student model
    bert_s = model_s.get_bert()
    n_layers, n_heads = bert_s.config.num_hidden_layers, bert_s.config.num_attention_heads
    #NOTE(review): the attribute is tested on bert_s.config but read from
    #model_s.config - presumably the same config object; confirm
    if hasattr(bert_s.config, 'pruned_heads'):
        pruned_nums = [len(v) for v in model_s.config.pruned_heads.values()]
        if pruned_nums:
            n_heads -= min(pruned_nums)

    #load train and evaluation datasets
    if check_model_type(model_s, BertModelEMB):
        train_dataset = create_squad_qcemb_dataset(rank, args.device,
                                                   args.squad_train_data,
                                                   tokenizer,
                                                   args.max_seq_length_q,
                                                   args.max_seq_length_c)
        test_dataset = create_squad_qcemb_dataset(rank, args.device,
                                                  args.squad_dev_data,
                                                  tokenizer,
                                                  args.max_seq_length_q,
                                                  args.max_seq_length_c)
    else:
        train_dataset = create_squad_qa_dataset(rank, args.device,
                                                args.squad_train_data,
                                                tokenizer,
                                                args.max_seq_length_q,
                                                args.max_seq_length_c)
        test_dataset = create_squad_qa_dataset(rank, args.device,
                                               args.squad_dev_data, tokenizer,
                                               args.max_seq_length_q,
                                               args.max_seq_length_c)

    if rank in [-1, 0]:
        #teacher metrics are kept to be reported next to the student ones at the end
        switch_to_eval(rank, model_t)
        result_t = evaluate(args, model_t, test_dataset)
        for k, v in result_t.items():
            logger.info("{} teacher {}".format(k, v))
    if rank > -1:
        torch.distributed.barrier()

    #packed parameters that belong to the embeddings
    packed_ids = {id(p) for p in params_packed}
    params_emb = [
        p for n, p in model_s.named_parameters()
        if id(p) in packed_ids and 'embedding' in n
    ]

    if params_emb:
        params_inp = [
            p for n, p in model_s.named_parameters() if 'input_transform' in n
        ]

        #tune embeddings transformation
        params_tune = params_emb + params_inp
        loss_num = 1
        train(rank,
              args,
              tokenizer,
              train_dataset,
              test_dataset,
              model_s,
              model_t,
              params_tune,
              head_importance=None,
              loss_num=loss_num)

        #iterative add bert encoder blocks
        encoder = model_s.get_bert().encoder
        for l, t in zip(encoder.layer, encoder.output_transforms):
            params_tune.extend(l.parameters())
            params_tune.extend(t.parameters())
            loss_num += 1
            train(rank,
                  args,
                  tokenizer,
                  train_dataset,
                  test_dataset,
                  model_s,
                  model_t,
                  params_tune,
                  head_importance=None,
                  loss_num=loss_num)

    if params_packed:
        #on the first stage the FF block only reduced and tuned
        #the number of self attention heads is the same

        #check that head prune is needed and run second train to tune the rest heads
        pack_head_num = int(
            model_s.config.pack_cfg.get('num_attention_heads', n_heads))
        pack_heads_flag = (pack_head_num < n_heads)
        head_importance = torch.zeros(n_layers, n_heads).to(
            args.device) if pack_heads_flag else None

        params_ff = [
            p for n, p in model_s.named_parameters()
            if 'encoder.' in n and 'attention.' not in n
        ]

        train(rank,
              args,
              tokenizer,
              train_dataset,
              test_dataset,
              model_s,
              model_t,
              params_packed + params_ff,
              head_importance=head_importance)

        if head_importance is not None and rank > -1:
            #every rank must prune the same heads, so combine the statistics
            torch.distributed.all_reduce(head_importance.data,
                                         op=torch.distributed.ReduceOp.SUM)

        if pack_heads_flag:
            #reduce number of heads before move to the second stage and tune all model
            if rank in [-1, 0]:
                logger.info('head_importance')
                logger.info(head_importance)
                logger.info('heads_to_prune')

            #prune heads: keep the pack_head_num most important heads per layer
            heads_to_prune = {}
            for l in range(n_layers):
                imp = head_importance[l].tolist()
                idx = list(sorted(range(n_heads), key=lambda x: imp[x]))
                heads_to_prune[l] = idx[:-pack_head_num]
                if rank in [-1, 0]:
                    logger.info("layer {} heads_to_prune {}".format(
                        l, heads_to_prune[l]))
            model_s.prune_heads(heads_to_prune)
            utils.sync_models(rank, model_s)

        train(rank, args, tokenizer, train_dataset, test_dataset, model_s,
              model_t, params_for_full_tune())

    #final tune
    train(rank,
          args,
          tokenizer,
          train_dataset,
          test_dataset,
          model_s,
          model_t,
          params_for_full_tune(),
          tune_iter=1)

    if rank in [-1, 0]:
        save_model(args, model_s, tokenizer)

        logger.info('Evaluate student model')
        logger.info('Model for evaluation')
        logger.info(model_s)
        switch_to_eval(rank, model_s)
        result_s = evaluate(args, model_s, test_dataset)
        for k, v in result_s.items():
            logger.info("{} student {} teacher {}".format(k, v, result_t[k]))

        #merge some linear transformations into filters
        model_s.merge_()

        logger.info("student model")
        logger.info(model_s)
        result_s = evaluate(args, model_s, test_dataset)
        for k, v in result_s.items():
            logger.info(
                "{} student {} after some operations are merged".format(k, v))

        #save to onnx
        if check_model_type(model_s, BertModelEMB):
            output_names = ['embedding']
        else:
            output_names = ['output_s', 'output_e']
        #four all-zero dummy inputs, one per entry of input_names below
        inputs = tuple(
            torch.zeros(args.max_seq_length_q, dtype=torch.long)
            for _ in range(4))
        inputs = tuple(t.unsqueeze(0).to(args.device) for t in inputs)
        torch.onnx.export(model_s,
                          inputs,
                          os.path.join(args.output_dir, "model.onnx"),
                          verbose=False,
                          input_names=[
                              'input_ids', 'attention_mask', 'token_type_ids',
                              'position_ids'
                          ],
                          output_names=output_names)
Exemple #3
0
def process(rank, args, port):
    """Tune (and, when ``QUANTIZATION`` is set, quantize) the embedding model.

    The function sets up the device / process group, loads the tokenizer and
    models from ``args.model_student`` / ``args.model_teacher``, builds the
    train and dev datasets, optionally creates an NNCF compressed model and
    tunes its FQ parameters, tunes the whole model, and finally (rank -1 or 0
    only) saves the model and tokenizer, exports ``model.onnx`` and logs the
    evaluation result.

    Args:
        rank: process rank; a negative value selects single-process mode,
            otherwise an NCCL process group of ``args.n_gpu`` ranks is created.
        args: settings object; ``args.device`` is assigned here.
        port: value written to the ``MASTER_PORT`` environment variable.
    """
    use_dist = rank > -1
    is_main = rank in [0, -1]

    def sync():
        #barrier for distributed runs, no-op for a single process
        if use_dist:
            torch.distributed.barrier()

    def log_metrics(template, metrics):
        #one log line per metric, formatted as template.format(name, value)
        for name, value in metrics.items():
            logger.info(template.format(name, value))

    def load_qc_dataset(data_path):
        return create_squad_qcemb_dataset(rank, args.device, data_path,
                                          tokenizer, args.max_seq_length_q,
                                          args.max_seq_length_c)

    #init multiprocess
    if use_dist:
        # create default process group
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(port)
        torch.distributed.init_process_group("nccl",
                                             rank=rank,
                                             world_size=args.n_gpu)
        args.device = torch.device("cuda:{}".format(rank))
        torch.cuda.set_device(rank)
    else:
        device_name = "cpu" if args.n_gpu < 1 else "cuda"
        args.device = torch.device(device_name)

    if rank > 0:
        #non-zero ranks wait here until rank 0 has loaded the models
        torch.distributed.barrier()

    printlog("rank", rank, "load tokenizer", args.model_student)
    tokenizer = BertTokenizer.from_pretrained(args.model_student)

    printlog("rank", rank, "load model", args.model_student)
    student_cfg = AutoConfig.from_pretrained(args.model_student)
    #a packed checkpoint needs the packed wrapper class to be loaded
    if student_cfg.architectures and 'BertBasedClassPacked' in student_cfg.architectures:
        student_cls = BertPacked(BertModelEMB)
    else:
        student_cls = BertModelEMB
    model = student_cls.from_pretrained(args.model_student).to(args.device)

    #the teacher is only loaded when supervision is requested
    model_t = None
    if args.supervision_weight > 0:
        model_t = BertModelEMB.from_pretrained(args.model_teacher)
        model_t = model_t.to(args.device)

    if rank == 0:
        #rank 0 joins the barrier the other ranks are waiting on
        torch.distributed.barrier()

    #create train and evaluate datasets
    train_dataset_qc = load_qc_dataset(args.squad_train_data)
    test_dataset_qc = load_qc_dataset(args.squad_dev_data)

    #lets sync after data loaded
    sync()

    model_controller = None
    if QUANTIZATION:

        if hasattr(model, 'merge_'):
            #if model is packed, then merge some linear transformations before quantization
            model.merge_()

        if is_main:
            #evaluate before quantization
            model.eval()
            log_metrics("original {} - {}",
                        evaluate(args, model, test_dataset_qc))
        sync()

        nncf_config = nncf.NNCFConfig.from_json(args.nncf_config)

        class SquadInitializingDataloader(
                nncf.initialization.InitializingDataLoader):
            def get_inputs(self, batch):
                #no positional arguments; the module-level get_inputs supplies the rest
                return [], get_inputs(batch, args.device)

        #range initialization reads batches of the context dataset
        c_dataset = train_dataset_qc.c_dataset
        train_dataloader = DataLoader(c_dataset,
                                      sampler=RandomSampler(c_dataset),
                                      batch_size=args.per_gpu_train_batch_size)

        init_range = nncf.initialization.QuantizationRangeInitArgs(
            SquadInitializingDataloader(train_dataloader))
        nncf_config.register_extra_structs([init_range])
        model_controller, model = nncf.create_compressed_model(
            model, nncf_config, dump_graphs=True)
        if use_dist:
            model_controller.distributed()
            utils.sync_models(rank, model)

        if is_main:
            #evaluate pure initialized int8 model
            model.eval()
            log_metrics("int8 {} - {}",
                        evaluate(args, model, test_dataset_qc))

        #lets sync after quantization
        sync()

        #tune FQ parameters only
        train(rank,
              args,
              model,
              model_t,
              train_dataset_qc,
              test_dataset_qc,
              fq_tune_only=True,
              model_controller=model_controller)

    #tune whole quantized model
    train(rank,
          args,
          model,
          model_t,
          train_dataset_qc,
          test_dataset_qc,
          fq_tune_only=False,
          model_controller=model_controller)

    if not is_main:
        return

    #save and evaluate result
    os.makedirs(args.output_dir, exist_ok=True)
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    model.eval()

    onnx_path = os.path.join(args.output_dir, "model.onnx")
    onnx_input_names = [
        'input_ids', 'attention_mask', 'token_type_ids', 'position_ids'
    ]
    with torch.no_grad():
        #all-zero sample inputs, one per input name, for onnx generation
        dummy_inputs = tuple(
            torch.zeros((1, args.max_seq_length_c),
                        dtype=torch.long,
                        device=args.device) for _ in range(4))
        torch.onnx.export(model,
                          dummy_inputs,
                          onnx_path,
                          verbose=False,
                          enable_onnx_checker=False,
                          opset_version=10,
                          input_names=onnx_input_names,
                          output_names=['embedding'])

    # Evaluate final model
    result = evaluate(args, model, test_dataset_qc)
    log_metrics("{} - {}", result)
    logger.info("checkpoint final result {}".format(result))
Exemple #4
0
def train(rank, args, model, model_t, train_dataset_qc, test_dataset_qc,
          fq_tune_only, model_controller):
    """Train the question/context embedding model.

    For every question batch the model embeds the questions and their
    positive contexts. Loss terms are collected from two optional sources:

    * when ``model_t`` is given, the mean squared difference between student
      and teacher hidden states (plus an optional L1/L2 term on the final
      embeddings, selected by ``emb_loss`` in ``args.loss_cfg``), each scaled
      by ``args.supervision_weight``;
    * unless ``fq_tune_only`` is set, ``triplet_num`` softplus triplet terms
      computed against negative contexts. Negatives are picked by hard
      negative mining (HNM) over randomly drawn contexts when
      ``args.hnm_num > 0``, otherwise they are drawn at random.

    After each epoch the process with rank -1/0 evaluates the model on
    ``test_dataset_qc`` and logs the metrics.

    Args:
        rank: process rank; a negative value means single-process training.
        args: run configuration. Fields read here: ``per_gpu_train_batch_size``,
            ``total_train_batch_size``, ``num_train_epochs``, ``freeze_list``,
            ``learning_rate``, ``loss_cfg``, ``supervision_weight``,
            ``hnm_num``, ``hnm_batch_size``, ``hnm_hist_num``,
            ``hnm_hist_alpha``, ``device`` and ``n_gpu``. ``args.hnm_num`` may
            be increased in place so that it divides evenly between replicas.
        model: the model being trained (student).
        model_t: optional teacher model; ``None`` disables the teacher terms.
        train_dataset_qc: object exposing ``q_dataset`` (iterated by the data
            loader) and ``c_dataset`` (indexed with a tensor of context ids).
        test_dataset_qc: dataset handed to ``evaluate`` after every epoch.
        fq_tune_only: when true, train for one epoch without gradient
            accumulation and without triplet terms; also forwarded to
            ``utils.make_param_groups`` to select the FQ parameters.
        model_controller: forwarded unchanged to ``utils.make_param_groups``.

    Raises:
        ValueError: if ``triplet_num`` in ``args.loss_cfg`` is greater than 2.
        RuntimeError: if the configuration produces no loss term at all.
        Exception: if ``emb_loss`` in ``args.loss_cfg`` is not supported.
    """
    global train_count
    train_count += 1

    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if rank in [-1, 0]:
        printlog("Train model", train_count)
        printlog(model)

    q_dataset = train_dataset_qc.q_dataset

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size

    if fq_tune_only:
        gradient_accumulation_steps = 1
        num_train_epochs = 1
    else:
        gradient_accumulation_steps = args.total_train_batch_size // train_batch_size
        if gradient_accumulation_steps < 1:
            # the requested total batch is smaller than a single step over all
            # replicas; fall back to one step instead of dividing by zero below
            logger.warning(
                "rank {} total_train_batch_size {} is smaller than the batch of one step {}, gradient accumulation is disabled"
                .format(rank, args.total_train_batch_size, train_batch_size))
            gradient_accumulation_steps = 1
        num_train_epochs = args.num_train_epochs

    if rank < 0:
        #single process take all
        q_sampler = RandomSampler(q_dataset)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=train_batch_size,
                                  num_workers=4)
    else:
        #special sampler that divide samples between processes
        q_sampler = torch.utils.data.distributed.DistributedSampler(q_dataset,
                                                                    rank=rank)
        q_dataloader = DataLoader(q_dataset,
                                  sampler=q_sampler,
                                  batch_size=per_gpu_train_batch_size)

    steps_total = int(
        len(q_dataloader) // gradient_accumulation_steps * num_train_epochs)

    # Prepare optimizer and schedule
    named_params, groups = utils.make_param_groups(
        rank,
        model,
        args.
        freeze_list,  #list or str with subnames to define frozen parameters
        args.learning_rate,  #learning rate for no FQ parameters
        0.01,  # learning rate for FQ parameters
        fq_tune_only,  #true if only FQ parameters will be optimized
        model_controller)

    optimizer = AdamW(groups, eps=1e-08, lr=args.learning_rate, weight_decay=0)

    def lr_lambda(current_step):
        # linear decay of the learning-rate factor from 1 to 0; the guards keep
        # it finite when there are no steps and non-negative past the last step
        p = float(current_step) / float(max(1, steps_total))
        return max(0.0, 1 - p)

    scheduler = LambdaLR(optimizer, lr_lambda)

    if rank in [-1, 0]:
        for n, p in named_params:
            printlog('param for tune', n)
        printlog("fq_tune_only", fq_tune_only)
        printlog("dataset size", len(q_dataset))
        printlog("epoches", num_train_epochs)
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size)
        printlog("n_gpu", args.n_gpu)
        printlog("world_size", world_size)
        printlog("gradient_accumulation_steps", gradient_accumulation_steps)
        printlog("total train batch size",
                 train_batch_size * gradient_accumulation_steps)
        printlog("steps_total", steps_total)

    global_step = 1
    model.zero_grad()
    indicators = collections.defaultdict(list)

    softplus = torch.nn.Softplus()

    # "key:value,key:value" string; an empty value means "use the defaults"
    loss_cfg = dict(t.split(':')
                    for t in args.loss_cfg.split(',')) if args.loss_cfg else dict()
    emb_loss = loss_cfg.get('emb_loss', '')

    # number of triplet terms per negative; none while only FQ parameters are tuned
    triplet_num = int(loss_cfg.get('triplet_num', 1))
    if fq_tune_only:
        triplet_num = 0
    if triplet_num > 2:
        raise ValueError(
            'triplet_num={} is unsupported, only 0, 1 or 2 triplet terms are implemented'
            .format(triplet_num))

    def align_and_loss_outputs(out_s, out_t):
        # per-layer mean squared difference between student and teacher outputs
        if len(out_s) != len(out_t):
            #the student and teacher outputs are not aligned. try to find teacher output for each student output
            n_s, n_t = len(out_s), len(out_t)
            out_t = [
                out_t[(i * (n_t - 1)) // (n_s - 1)]
                for i in range(n_s)
            ]
        assert len(out_s) == len(
            out_t
        ), "can not align number of outputs between student and teacher"
        assert all(
            s[0] == s[1]
            for s in zip(out_s[0].shape, out_t[0].shape)
        ), "output shapes for student and teacher are not the same"
        return [(s - t.detach()).pow(2).mean()
                for s, t in zip(out_s, out_t)]

    # history of HNM candidates: global_step -> (context ids, context embeddings)
    hnm_hist = {}

    score = None  # last HNM score matrix, kept only to log its shape
    time_last = None  # time of the previous log record, None before the first one

    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        if model_t is not None:
            model_t.train()
        if rank > -1:
            #set epoch to make different samples division betwen process for different epoches
            q_sampler.set_epoch(epoch)

        utils.sync_models(rank, model)
        for step, q_batch in enumerate(q_dataloader):
            # fractional epoch; lets a non-integer num_train_epochs stop mid-epoch
            epoch_fp = epoch + step / len(q_dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            context_ids_pos = q_batch[3]
            q_inputs = get_inputs(q_batch, args.device)
            q_outputs = model(**q_inputs,
                              output_hidden_states=(model_t is not None))
            q_vec = q_outputs[0]

            #get positive embeddings
            c_batch = train_dataset_qc.c_dataset[context_ids_pos.detach().data]
            c_inputs = get_inputs(c_batch, args.device)
            c_outputs = model(**c_inputs,
                              output_hidden_states=(model_t is not None))
            c_vec_pos = c_outputs[0]

            if model_t is not None:
                q_emb_s, q_hidden_s = q_outputs
                c_emb_s, c_hidden_s = c_outputs
                with torch.no_grad():
                    q_emb_t, q_hidden_t = model_t(**q_inputs,
                                                  output_hidden_states=True)
                    c_emb_t, c_hidden_t = model_t(**c_inputs,
                                                  output_hidden_states=True)

                l_q = align_and_loss_outputs(q_hidden_s, q_hidden_t)
                l_c = align_and_loss_outputs(c_hidden_s, c_hidden_t)

                if emb_loss == 'L2':
                    l_q.append((q_emb_s - q_emb_t.detach()).pow(2).mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).pow(2).mean())
                elif emb_loss == 'L1':
                    l_q.append((q_emb_s - q_emb_t.detach()).abs().mean())
                    l_c.append((c_emb_s - c_emb_t.detach()).abs().mean())
                elif emb_loss.lower() not in ['', 'none', '0', 'disable']:
                    raise Exception(
                        'emb_loss={} is unsupported'.format(emb_loss))

                losses.extend(
                    [args.supervision_weight * term for term in l_c + l_q])

            if triplet_num > 0:
                #disable grad to select negatives
                with torch.no_grad():
                    hnm_scores = []
                    hnm_idxs = []

                    #check that current step has no HNM conext vector
                    if global_step not in hnm_hist and args.hnm_num > 0:
                        #generate the new one

                        if world_size > 1 and (args.hnm_num % world_size) != 0:
                            #aligh hnm_num per each replica
                            hnm_plus = world_size - (args.hnm_num % world_size)
                            args.hnm_num += hnm_plus
                            logger.warning(
                                "rank {} args.hnm_num increased by {} from {} to {} to be the same after division by {} replicas."
                                .format(rank, hnm_plus,
                                        args.hnm_num - hnm_plus, args.hnm_num,
                                        world_size))

                        # generate random contexts to calc embedding
                        context_ids_all = torch.randint(
                            low=0,
                            high=len(train_dataset_qc.c_dataset),
                            size=[args.hnm_num])

                        if rank < 0:  #single process take all
                            context_ids = context_ids_all
                        else:
                            #broadcast one sigle indicies to all processes
                            context_ids_all = context_ids_all.to(args.device)
                            torch.distributed.broadcast(context_ids_all, 0)
                            context_ids_all = context_ids_all.cpu()

                            #each process take only small part to calc embedding
                            part_start = ((rank + 0) * args.hnm_num) // world_size
                            part_end = ((rank + 1) * args.hnm_num) // world_size
                            context_ids = context_ids_all[part_start:part_end]

                        batch_size = min(args.hnm_batch_size,
                                         context_ids.shape[0])

                        # embed the candidate contexts chunk by chunk
                        start, end = 0, batch_size
                        hnm_outputs = []
                        while end > start:
                            idx = context_ids.detach()[start:end]
                            c_batch = train_dataset_qc.c_dataset[idx]
                            inputs = get_inputs(c_batch, args.device)
                            outputs = model(**inputs,
                                            output_hidden_states=False)
                            hnm_outputs.append(outputs[0])
                            start, end = end, min(end + batch_size,
                                                  context_ids.shape[0])

                        context_emb = torch.cat(hnm_outputs, dim=0)

                        if rank < 0:
                            # single process calculated all
                            context_emb_all = context_emb
                        else:
                            context_emb_list = [
                                torch.zeros_like(context_emb)
                                for _ in range(world_size)
                            ]
                            torch.distributed.all_gather(
                                context_emb_list, context_emb)
                            context_emb_all = torch.cat(context_emb_list,
                                                        dim=0)

                        hnm_hist[global_step] = (context_ids_all,
                                                 context_emb_all)

                        #check history size and crop the oldest one
                        if len(hnm_hist) > args.hnm_hist_num:
                            del hnm_hist[min(hnm_hist.keys())]

                    #calc HNM scores for current question batch
                    for hist_step, (c_idx, c_vec) in hnm_hist.items():
                        # entries from earlier steps are scaled by hnm_hist_alpha per step of age
                        w = args.hnm_hist_alpha**(global_step - hist_step)
                        d = q_vec[:, None, :] - c_vec[None, :, :]
                        hnm_scores.append(-d.norm(2, dim=-1) * w)
                        hnm_idxs.append(c_idx)

                    if hnm_scores:
                        #choose the hardest negative if we have scores
                        score = torch.cat(hnm_scores, dim=-1).cpu()
                        idx = torch.cat(hnm_idxs, dim=-1)
                        pos_mask = (context_ids_pos[:,
                                                    None] == idx[None, :]).to(
                                                        dtype=score.dtype,
                                                        device=score.device)
                        #make positive context with small score to avoid chose it as hard neg
                        score = (1 - pos_mask) * score + pos_mask * score.min()
                        hn_idx = score.argmax(dim=1, keepdim=True)

                        context_ids_neg = idx[hn_idx]
                    else:
                        #just random selection in case of no scores for HNM
                        size = (context_ids_pos.shape[0], 1)
                        context_ids_neg = torch.randint(
                            0,
                            len(train_dataset_qc.c_dataset) - 1, size)
                        # skip over the positive id so that it is never drawn as a negative
                        shift = (context_ids_neg >= context_ids_pos[:, None])
                        context_ids_neg = context_ids_neg + shift.to(
                            dtype=context_ids_neg.dtype)

                d_pos = (q_vec - c_vec_pos).norm(2, dim=-1)
                # get negative embeddings and calc losses
                for neg_index in range(context_ids_neg.shape[1]):
                    ids = context_ids_neg[:, neg_index]
                    c_batch = train_dataset_qc.c_dataset[ids.detach()]
                    inputs = get_inputs(c_batch, args.device)

                    outputs = model(**inputs, output_hidden_states=False)
                    c_vec_neg = outputs[0]

                    for triplet_index in range(triplet_num):
                        if triplet_index == 0:
                            # question to negative context distance
                            d_neg = (q_vec - c_vec_neg).norm(2, dim=-1)
                        else:
                            # triplet_index == 1: positive to negative context distance
                            d_neg = (c_vec_pos - c_vec_neg).norm(2, dim=-1)

                        d_diff = d_pos - d_neg

                        indicators['dd' + str(triplet_index)].append(
                            [v.mean().item() for v in (d_pos, d_neg, d_diff)])

                        losses.append(softplus(d_diff))

                        del d_neg
                del d_pos

                #average over batch
                losses = [term.mean() for term in losses]

            if not losses:
                raise RuntimeError(
                    "no loss terms to optimize: triplet loss is disabled and there is no teacher model"
                )

            loss = sum(losses) / len(losses)
            (loss / gradient_accumulation_steps).backward()

            indicators['loss'].append(loss.item())
            indicators['ll'].append([term.item() for term in losses])

            del loss

            if (step + 1) % gradient_accumulation_steps == 0:

                utils.sync_grads(rank,
                                 named_params,
                                 report_no_grad_params=(global_step == 1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if global_step % 10 == 0:
                    # Log metrics
                    lrp = [
                        '{:.2f}'.format(i)
                        for i in np.log10(scheduler.get_last_lr())
                    ]

                    str_out = "{} ep {:.2f} lrp {}".format(
                        train_count, epoch_fp, " ".join(lrp))

                    for k, v in indicators.items():
                        v = np.array(v)
                        if len(v.shape) == 1:
                            v = v[:, None]

                        if rank > -1:
                            #sync indicators
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(
                                vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(
                            k,
                            " ".join(["{:.3f}".format(t) for t in v.mean(0)]))

                    if score is not None:
                        str_out += " SS {}".format(list(score.shape))

                    if time_last is not None:
                        #estimate processing times from the records collected since the last log
                        dt_iter = (time.time() - time_last) / len(
                            indicators['loss'])
                        dt_ep = dt_iter * len(q_dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(
                            dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    indicators = collections.defaultdict(list)
                    if rank in [-1, 0]:
                        logger.info(str_out)

        if rank in [-1, 0]:
            check_point_name = 'checkpoint-{:02}'.format(train_count)
            check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
            result_s = evaluate(args, model.eval(), test_dataset_qc)
            for k, v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank > -1:
            # other ranks wait here while rank 0 evaluates
            torch.distributed.barrier()
    def train(self, epoch_start):
        """Run the training loop from ``epoch_start`` to ``args.num_train_epochs``.

        Each batch yields a noise signal and a clean signal. They are mixed by
        ``self.mix_signals``, the mix is passed through the model, and
        ``self.loss`` compares the model outputs with the clean signal and its
        encoded form. Gradients are accumulated over
        ``self.gradient_accumulation_steps`` batches before one optimizer step.
        Every ``args.logacc`` optimizer steps the indicators are averaged,
        checked by ``CheckLossRaise`` and logged. A checkpoint is saved and
        evaluated at the end of every epoch.

        Args:
            epoch_start: index of the first epoch to run.
        """
        global_step = 0
        self.check_loss_raise = CheckLossRaise()
        epoch_stop = math.ceil(self.args.num_train_epochs)

        for epoch in range(epoch_start, epoch_stop):
            self.indicators = collections.defaultdict(list)

            utils.sync_models(self.rank, self.model)

            self.model.train()

            # start every epoch without leftover gradients
            self.model.zero_grad()
            accumulated = 0

            if self.rank > -1:
                # a new epoch number changes how samples are divided between processes
                self.dataloader.sampler.set_epoch(epoch)

            for step, batch in enumerate(self.dataloader):
                # fractional epoch, allows to stop in the middle of the last epoch
                epoch_fp = epoch + step / len(self.dataloader)
                if epoch_fp > self.args.num_train_epochs:
                    break

                moved = [t.to(self.args.device) for t in batch]
                x_noise, x_clean = moved

                # augment and mix signals
                x_clean, x_noise, x = self.mix_signals(x_clean, x_noise)

                # forward pass
                y_clean, Y_clean, _ = self.model(x)

                # spectrum of the clean input signal, left-padded by the window tail
                tail_size = self.model.wnd_length - self.model.hop_length
                X_clean = self.model.encode(
                    torch.nn.functional.pad(x_clean, (tail_size, 0)))

                # crop target and model output to align them to each other
                sample_ahead = self.model.get_sample_ahead()
                spectre_ahead = self.model.ahead
                if sample_ahead > 0:
                    x = x[:, :-sample_ahead]
                    x_clean = x_clean[:, :-sample_ahead]
                    y_clean = y_clean[:, sample_ahead:]
                if spectre_ahead > 0:
                    Y_clean = Y_clean[:, :, :, spectre_ahead:]
                    X_clean = X_clean[:, :, :, :-spectre_ahead]

                loss = self.loss(epoch_fp, y_clean, Y_clean, x_clean, X_clean)
                self.indicators['loss'].append(loss.item())

                # calculate and accumulate gradients
                loss.backward()
                accumulated += 1

                # step the optimizer only once enough gradients are accumulated
                if accumulated >= self.gradient_accumulation_steps:
                    utils.sync_grads(self.rank, self.named_params,
                                     global_step == 0, accumulated)
                    self.optimizer.step()
                    self.scheduler.step()  # update learning rate schedule
                    global_step += 1

                    self.model.zero_grad()
                    accumulated = 0

                    # make logs only after several optimizer steps
                    if global_step % self.args.logacc == 0:
                        # average indicators over GPUs and iterations
                        self.aver_indicators()

                        # watch for a sudden raise of negsisdr; per the original
                        # note a high raise makes the checker restore the model
                        # parameters and reset the optimizer
                        self.check_loss_raise.check(
                            self.indicators_mean["negsisdr"],
                            self.named_params,
                            self.optimizer
                        )

                        self.log_indicators(epoch_fp)

                        self.indicators = collections.defaultdict(list)

            self.save_and_eval_checkpoint(epoch + 1)
Exemple #6
0
def train(rank, args, model, model_t, train_dataset_qa, test_dataset_qa, scale_tune):
    """Train the model on ``train_dataset_qa``, optionally distilling from a teacher.

    The first model output is used as the task loss. When ``model_t`` is given
    two more kinds of terms are added: a knowledge distillation loss between
    the first two remaining outputs of student and teacher (weighted by
    ``args.kd_weight``) and the mean squared difference between their hidden
    states (weighted by ``args.supervision_weight``). After every epoch the
    process with rank -1/0 evaluates the model on ``test_dataset_qa`` and logs
    the metrics.

    Args:
        rank: process rank; a negative value means single-process training.
        args: run configuration. Fields read here: ``per_gpu_train_batch_size``,
            ``total_train_batch_size``, ``num_train_epochs``, ``freeze_list``,
            ``learning_rate``, ``loss_cfg``, ``kd_weight``,
            ``supervision_weight``, ``device`` and ``n_gpu``.
        model: the model being trained (student).
        model_t: optional teacher model; ``None`` disables distillation.
        train_dataset_qa: training dataset.
        test_dataset_qa: dataset handed to ``evaluate`` after every epoch.
        scale_tune: when true, only parameters whose name contains ``.scale``
            are optimized, for one epoch and without gradient accumulation.
    """
    global train_count
    train_count += 1
    world_size = 1 if rank < 0 else torch.distributed.get_world_size()

    if rank in [-1, 0]:
        printlog("Train model",train_count)
        printlog(model)

    per_gpu_train_batch_size = args.per_gpu_train_batch_size
    train_batch_size = per_gpu_train_batch_size * world_size
    gradient_accumulation_steps = args.total_train_batch_size // train_batch_size
    num_train_epochs = args.num_train_epochs

    if scale_tune:
        gradient_accumulation_steps = 1
        num_train_epochs = 1
    elif gradient_accumulation_steps < 1:
        # the requested total batch is smaller than a single step over all
        # replicas; fall back to one step instead of dividing by zero below
        logger.warning(
            "rank {} total_train_batch_size {} is smaller than the batch of one step {}, gradient accumulation is disabled".format(
                rank, args.total_train_batch_size, train_batch_size))
        gradient_accumulation_steps = 1

    if rank < 0:
        #single process take all samples
        sampler = RandomSampler(train_dataset_qa)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=train_batch_size, num_workers=4)
    else:
        #special sampler that divide samples beween processes
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset_qa, rank=rank)
        dataloader = DataLoader(train_dataset_qa, sampler=sampler, batch_size=per_gpu_train_batch_size)

    steps_total = int(len(dataloader) // gradient_accumulation_steps * num_train_epochs)

    # Prepare optimizer and schedule
    freeze_list = args.freeze_list or []
    if isinstance(freeze_list, str):
        freeze_list = freeze_list.split(',')
    # Drop empty entries and the "none" placeholder: an empty string is a
    # substring of every parameter name and would freeze the whole model.
    freeze_list = [fn.strip() for fn in freeze_list]
    freeze_list = [fn for fn in freeze_list if fn and fn.lower() != "none"]

    named_params = []
    for n, p in model.named_parameters():
        if any(fn in n for fn in freeze_list):
            if rank in [-1, 0]:
                logger.warning("rank {} {} param is frozen and excluded from tune".format(rank,n))
            continue
        named_params.append( (n, p) )

    # split parameters to scale and the rest
    named_params_scale = [(n, p) for n, p in named_params if '.scale' in n]
    named_params_rest = [(n, p) for n, p in named_params if '.scale' not in n]

    if scale_tune:
        #keep only scale parameters
        named_params = named_params_scale
        named_params_rest = []

    # scale parameters get a fixed learning rate, the rest use the configured one
    groups = []
    if named_params_scale:
        groups.append({'params': [p for n, p in named_params_scale], 'lr': 0.01})
    if named_params_rest:
        groups.append({'params': [p for n, p in named_params_rest],  'lr': args.learning_rate})

    optimizer = AdamW(
        groups,
        eps=1e-08,
        lr=args.learning_rate,
        weight_decay=0)

    def lr_lambda(current_step):
        # linear decay of the learning-rate factor from 1 to 0; the guards keep
        # it finite when there are no steps and non-negative past the last step
        p = float(current_step) / float(max(1, steps_total))
        return max(0.0, 1 - p)

    scheduler = LambdaLR(optimizer, lr_lambda)

    if rank in [-1, 0]:
        for n,p in named_params:
            printlog('param for tune',n)
        printlog("scale_tune", scale_tune )
        printlog("dataset size", len(train_dataset_qa) )
        printlog("epoches", num_train_epochs )
        printlog("per_gpu_train_batch_size", per_gpu_train_batch_size )
        printlog("n_gpu", args.n_gpu )
        printlog("world_size", world_size )
        printlog("gradient_accumulation_steps", gradient_accumulation_steps )
        printlog("total train batch size", train_batch_size * gradient_accumulation_steps )
        printlog("steps_total",steps_total )

    global_step = 0
    model.zero_grad()
    indicators = collections.defaultdict(list)

    softplus = torch.nn.Softplus()

    # "key:value,key:value" string; an empty value means "use the defaults"
    loss_cfg = dict([t.split(':') for t in args.loss_cfg.split(',')]) if args.loss_cfg else dict()

    def align_and_loss_outputs(out_s, out_t):
        # per-layer mean squared difference between student and teacher outputs
        if len(out_s) != len(out_t):
            #the student and teacher outputs are not aligned. try to find teacher output for each student output
            n_s, n_t = len(out_s), len(out_t)
            out_t = [out_t[(i*(n_t-1))//(n_s-1)] for i in range(n_s)]
        assert len(out_s) == len(out_t), "can not align number of outputs between student and teacher"
        assert all(s[0] == s[1] for s in zip(out_s[0].shape, out_t[0].shape)), "output shapes for student and teacher are not the same"
        return [(s - t.detach()).pow(2).mean() for s,t in zip(out_s, out_t)]

    time_last = None  # time of the previous log record, None before the first one

    for epoch in range(math.ceil(num_train_epochs)):
        indicators = collections.defaultdict(list)
        model.train()
        set_output_hidden_states(rank, model, (model_t is not None))
        utils.sync_models(rank, model)
        if model_t is not None:
            set_output_hidden_states(rank, model_t, True)
            model_t.train()
        if rank > -1:
            #set epoch to make different samples division betwen process for different epoches
            sampler.set_epoch(epoch)

        for step, batch in enumerate(dataloader):
            # fractional epoch; lets a non-integer num_train_epochs stop mid-epoch
            epoch_fp = epoch + step/len(dataloader)
            if epoch_fp > num_train_epochs:
                break

            losses = []

            inputs = get_inputs(batch, args.device)
            targets = get_targets(batch, args.device)
            outputs = model(**inputs, **targets, output_hidden_states=(model_t is not None))
            # the first output is the task loss, the rest are the model predictions
            losses.append(outputs[0])
            outputs = outputs[1:]

            if model_t is not None:
                with torch.no_grad():
                    outputs_t = model_t(**inputs, output_hidden_states=True)
                    hidden_t = outputs_t[2]
                    assert isinstance(hidden_t, (tuple,list)), "hidden states output is not detected right"
                    assert len(hidden_t) == model_t.config.num_hidden_layers+1, "hidden states output is not detected right"

                if args.kd_weight>0:
                    # Calculate knowledge distillation loss: cross entropy of the
                    # student log-probabilities against the teacher probabilities
                    T = 1
                    kd_losses = []
                    for logit_s,logit_t in zip(outputs[0:2],outputs_t[0:2]):
                        prob_t = torch.nn.functional.softmax(logit_t.detach() / T, dim=1)
                        logprob_s = torch.nn.functional.log_softmax(logit_s / T, dim=1)
                        kd_losses.append( -(logprob_s * prob_t).mean() * (T * T * prob_t.shape[1]) )
                    losses.append(args.kd_weight*sum(kd_losses)/len(kd_losses))

                hidden_s = outputs[2]
                assert isinstance(hidden_s, (tuple,list)), "hidden states output is not detected right"
                assert len(hidden_s) == model.config.num_hidden_layers+1, "hidden states output is not detected right"

                sw_losses = align_and_loss_outputs(hidden_s,hidden_t)

                losses.extend([args.supervision_weight*term for term in sw_losses])

            #average over batch
            losses = [term.mean() for term in losses]

            loss = sum(losses)/len(losses)
            indicators['loss'].append(loss.item())
            indicators['ll'].append([term.item() for term in losses])

            (loss/gradient_accumulation_steps).backward()

            del loss

            if (step + 1) % gradient_accumulation_steps == 0:
                global_step += 1

                utils.sync_grads(rank, named_params, report_no_grad_params=(global_step==1))
                torch.nn.utils.clip_grad_norm_([p for n, p in named_params], 1)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

                if global_step % 50 == 0:
                    # Log metrics
                    lrp = " ".join(['{:.2f}'.format(t) for t in np.log10(scheduler.get_last_lr())])
                    str_out = "{} ep {:.2f} lrp {}".format(train_count, epoch_fp, lrp)

                    for k,v in indicators.items():
                        v = np.array(v)
                        if len(v.shape)==1:
                            v = v[:,None]

                        if rank>-1:
                            #sync indicators
                            vt = torch.tensor(v).to(args.device)
                            torch.distributed.all_reduce(vt, op=torch.distributed.ReduceOp.SUM)
                            v = vt.cpu().numpy() / float(world_size)

                        str_out += " {} {}".format(k," ".join(["{:.3f}".format(t) for t in v.mean(0)]))

                    if time_last is not None:
                        #estimate processing times from the records collected since the last log
                        dt_iter = (time.time() - time_last) / len(indicators['loss'])
                        dt_ep = dt_iter * len(dataloader)
                        str_out += " it {:.1f}s".format(dt_iter)
                        str_out += " ep {:.1f}m".format(dt_ep / (60))
                        str_out += " eta {:.1f}h".format(dt_ep * (num_train_epochs - epoch_fp) / (60 * 60))
                    time_last = time.time()

                    indicators = collections.defaultdict(list)
                    if rank in [-1, 0]:
                        logger.info(str_out)

        if rank in [-1, 0]:
            check_point_name = 'checkpoint-{:02}'.format(train_count)
            check_point_name = check_point_name + '-{:02}'.format(epoch + 1)
            model.eval()
            set_output_hidden_states(rank, model, False)
            result_s = evaluate(args, model, test_dataset_qa)
            for k,v in result_s.items():
                logger.info("{} {} {}".format(check_point_name, k, v))
        if rank>-1:
            # other ranks wait here while rank 0 evaluates
            torch.distributed.barrier()