Example no. 1
def forward_step(data_iterator, model, args, timers):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    batch = get_batch(data_iterator, args, timers)
    if batch is None:
        return None
    tokens, lm_labels, attention_mask, position_ids, loss_mask = batch
    timers('batch generator').stop()
    # Forward model.
    if args.eval_hf:
        output, _ = model(tokens)
    else:
        output = model(tokens, position_ids, attention_mask)

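    # LM evaluation accumulates the masked token-level cross entropy; cloze evaluation instead counts correct predictions at the masked positions.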
    if not args.cloze_eval:
        # losses = torch.nn.CrossEntropyLoss(reduce=False)(
        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                                  lm_labels.contiguous())
        loss_mask = loss_mask.contiguous()
        loss_mask = loss_mask.view(-1)
        lm_loss = torch.sum(losses.view(-1) * loss_mask.float())
    else:
        outputs = torch.argmax(output, -1).contiguous().view(-1)
        acc = (outputs == lm_labels.contiguous().view(-1)).float()
        loss_mask = loss_mask.contiguous().view(-1).float()
        lm_loss = torch.sum(acc * loss_mask)

    return lm_loss
Example no. 2
def forward_step(data_iterator, model, args, timers):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    # Forward model.
    output, *other_losses = model(tokens, position_ids, attention_mask)

    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)

    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

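    # Collect any auxiliary losses returned by the model (e.g. MoE load-balancing terms) and add them to the LM loss.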
    moe_losses = []
    for moe_loss in other_losses:
        if moe_loss is not None:
            moe_losses.append(moe_loss)

    #print(f"Moe Losses: {moe_losses}, actual loss {loss}")
    moe_loss = sum(moe_losses)
    #if torch.distributed.get_rank() == 0:
    #    print(f"Moe Loss {moe_loss}")
    loss = loss + moe_loss

    return loss
Example no. 3
def seq2seq_forward_step(data, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    if timers is not None:
        timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)
    if timers is not None:
        timers('batch generator').stop()
    # Forward model.
    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    # logits, loss_mask = logits[:, args.src_seq_length:], loss_mask[:, args.src_seq_length:]
    # target_ids = target_ids[:, args.src_seq_length:]
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                              labels)
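    # Label smoothing mixes the token cross entropy with the mean negative log-probability over the vocabulary (the uniform-target loss).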
    if args.label_smoothing > 0.0:
        epsilon = args.label_smoothing
        smooth_loss = -torch.nn.functional.log_softmax(logits,
                                                       dim=-1).mean(dim=-1)
        losses = (1 - epsilon) * losses + epsilon * smooth_loss
    loss_mask = loss_mask.reshape(-1)
    # Normalize by the number of tokens that contribute to the loss
    loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum()
    return loss, mems, 'bert'
Example no. 4
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
    data = next(data_iterator) if data_iterator else None
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)
    timers('batch generator').stop()

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    # Forward model.
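    # When args.nonautoregressive is set, the model returns a second (non-autoregressive) head whose masked cross entropy is added to the standard LM loss.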
    if args.nonautoregressive:
        logits, na_logits, *mems = model(tokens, position_ids, attention_mask,
                                         *mems)
        losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                                  labels)
        loss_mask = loss_mask.view(-1)
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

        na_losses = mpu.vocab_parallel_cross_entropy(
            na_logits.contiguous().float(), labels)
        na_loss = torch.sum(na_losses.view(-1) * loss_mask) / loss_mask.sum()
        loss = loss + na_loss
    else:
        logits, *mems = model(tokens, position_ids, attention_mask, *mems)
        losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                                  labels)
        loss_mask = loss_mask.view(-1)
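        # Sum the masked losses; normalize by the token count only when the mask is non-empty, to avoid dividing by zero.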
        loss = torch.sum(losses.view(-1) * loss_mask)
        if loss_mask.sum().item() > 0:
            loss = loss / loss_mask.sum()

    return loss, mems, mode
Example no. 5
def forward_step(data_iterator, model, args, timers):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    # Forward model.
    output = model(tokens, position_ids, attention_mask)
    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    return loss
Example no. 6
def forward_step(data_iterator, model, args, timers):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    #if torch.distributed.get_rank() == 0:
    #    print("AAAAAAAAAAAAAAAAAA")
    tokens, types, loss_mask, lm_labels, padding_mask, attention_mask, position_ids, clicklabels, hrslabels = get_batch(
        data_iterator, args, timers)
    #if torch.distributed.get_rank() == 0:
    #    print("BBBBBBBBBBBBBBBBBB")
    timers('batch generator').stop()

    # Forward model.
    output, hrs_scores, click_scores, *other_losses = model(tokens, position_ids, attention_mask, types)
    #pooled_output = torch.squeeze(output[:,0,:])
    
    losses = mpu.vocab_parallel_cross_entropy(
        output.contiguous().float(), lm_labels.contiguous())
    loss_mask = loss_mask.contiguous()
    lm_loss = torch.sum(
        losses.view(-1) * loss_mask.view(-1).float()) / loss_mask.sum()
    loss = lm_loss

    moe_losses = []
    for moe_loss in other_losses:
        if moe_loss is not None:
            moe_losses.append(moe_loss)      

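    # Pairwise ranking losses on the HRS and click heads; the trailing 4 is presumably the number of candidates compared per group.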
    hrs_loss = PairwiseHRSLoss(hrs_scores.contiguous(), hrslabels.contiguous(), 4)
    click_loss = PairwiseClickLoss(click_scores.contiguous(), clicklabels.contiguous(), 4)

    #print(f"Moe Losses: {moe_losses}, actual loss {loss}")
    moe_loss = sum(moe_losses)
    #if torch.distributed.get_rank() == 0:
    #    print(f"Moe Loss {moe_loss}")
    loss = loss + moe_loss * 0.1 + hrs_loss + click_loss

    return loss, hrs_scores, hrslabels, lm_loss, hrs_loss, click_loss
Example no. 7
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator, args, timers)
    # global last_tokens
    # last_tokens = tokens.tolist()
    # if last_tokens is not None:
    #     for i in range(len(tokens)):
    #         if tokens[i][0] != 0 and tokens[i][0] != last_tokens[i] + 1:
    #             breakpoint()
    # last_tokens = tokens[:, -1].tolist()
    timers('batch generator').stop()

    # Forward model.
    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    return loss, mems
Example no. 8
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = GPT2Tokenizer(
        os.path.join(args.tokenizer_path, 'vocab.json'),
        os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # load data
    test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
    # Set an arbitrary positive integer since the optimizer and the scheduler will not be used when doing eval.
    args.train_iters = 1

    # Model
    model, _, _ = setup_model_and_optimizer(args)

    device = torch.cuda.current_device()

    # give a time stamp to the model
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir,
                               "{}-{}".format(args.model_name, cur_time))

    if torch.distributed.get_rank() == 0:
        os.makedirs(results_dir, exist_ok=True)

    model.eval()
    all_sids = []
    all_cids = []
    all_losses = []
    with torch.no_grad():
        for batch, no_model_batch in tqdm(
                test_dataloader,
                desc="Evaluating",
                disable=(torch.distributed.get_rank() != 0)):
            for k in batch:
                batch[k] = batch[k].to(device)
            for k in no_model_batch:
                no_model_batch[k] = no_model_batch[k].to(device)

            output = model(**batch)
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), no_model_batch["labels"])
            loss_mask = no_model_batch["loss_mask"]
            loss = torch.sum(losses * loss_mask,
                             dim=-1) / loss_mask.sum(dim=-1)

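            # All-gather the per-sample losses, sentence ids, and candidate ids across the data-parallel group so rank 0 can score every candidate.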
            loss_tensor_list = [
                torch.zeros_like(loss).to(device)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(loss_tensor_list,
                                         loss.data,
                                         group=mpu.get_data_parallel_group())
            all_losses.extend(loss_tensor_list)

            sids = no_model_batch["sids"]
            sid_tensor_list = [
                torch.zeros_like(sids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(sid_tensor_list,
                                         sids.data,
                                         group=mpu.get_data_parallel_group())
            all_sids.extend(sid_tensor_list)

            cids = no_model_batch["cids"]
            cid_tensor_list = [
                torch.zeros_like(cids)
                for _ in range(mpu.get_data_parallel_world_size())
            ]
            torch.distributed.all_gather(cid_tensor_list,
                                         cids.data,
                                         group=mpu.get_data_parallel_group())
            all_cids.extend(cid_tensor_list)

    if torch.distributed.get_rank() == 0:
        all_losses = torch.stack(all_losses).view(-1).cpu().detach().numpy()
        all_sids = torch.stack(all_sids).view(-1).cpu().detach().numpy()
        all_cids = torch.stack(all_cids).view(-1).cpu().detach().numpy()

        truth_labels = test_dataset.truth_labels
        preds = [[] for _ in truth_labels]

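        # Group the candidate losses by sample id and take the lowest-loss candidate as the zero-shot prediction.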
        for sid, cid, loss in zip(all_sids, all_cids, all_losses):
            preds[sid].append((cid, loss))

        preds = [min(p, key=lambda x: x[1])[0] for p in preds if len(p) > 0]

        yprint("Acc: {}".format(
            sum([int(p == l)
                 for p, l in zip(preds, truth_labels)]) / len(truth_labels)))
        with open(os.path.join(results_dir, "zero-shot_result.txt"), "w") as f:
            f.write("Acc: {}\n".format(
                sum([int(p == l) for p, l in zip(preds, truth_labels)]) /
                len(truth_labels)))

    torch.distributed.barrier()
Example no. 9
def forward_step(data_iterator, model, args, timers, mems):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    timers('data loader').start()
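    # Seed a per-rank RNG from the current iteration so the choice between the multi-task and LM iterators is reproducible.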
    rand = random.Random(args.iteration * mpu.get_data_parallel_world_size() +
                         mpu.get_data_parallel_rank())
    if data_iterator[1] and rand.random() < args.multi_task_ratio:
        data = next(data_iterator[1]) if data_iterator[1] else None
        data["mode"] = "multi-task"
    else:
        data = next(data_iterator[0]) if data_iterator[0] else None
    # print_rank_0("data iterator")
    timers('data loader').stop()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)
    timers('batch generator').stop()

    # print_rank_0("get batch")

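    # Debug helper (not invoked in this step): decodes one batch element, showing masked positions and the generated pieces after the separator.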
    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        sep = attention_mask.item() if torch.numel(
            attention_mask) == 1 else attention_mask[batch_id].item()
        text, last_segment = "", []
        for i, token_id in enumerate(tokens[batch_id, :sep].tolist()):
            token = tokenizer.IdToToken(token_id)
            if token.startswith('[MASK') or token.endswith('MASK]'):
                if last_segment:
                    text += tokenizer.DecodeIds(last_segment)
                    last_segment = []
                text += f" [{position_ids_[batch_id, i].item()}, {token}]"
            else:
                last_segment.append(token_id)
        if last_segment:
            text += tokenizer.DecodeIds(last_segment)
        print(text.encode('utf-8'))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if tokenizer.IdToToken(
                    tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(
                        tokenizer.DecodeIds(
                            tokens[batch_id,
                                   last_index:i].tolist()).encode('utf-8'),
                        "|",
                        tokenizer.DecodeIds(
                            labels[batch_id,
                                   last_index:i].tolist()).encode('utf-8'),
                        position_ids_[batch_id, last_index:i].tolist(),
                        block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(
                tokenizer.DecodeIds(
                    tokens[batch_id,
                           last_index:].tolist()).encode('utf-8'), "|",
                tokenizer.DecodeIds(
                    labels[batch_id, last_index:].tolist()).encode('utf-8'),
                position_ids_[batch_id, last_index:].tolist(),
                block_position_ids[batch_id, last_index:].tolist())

    if data is not None and "mode" in data:
        mode = data['mode']
    else:
        mode = 'bert'

    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                              labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask)
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()

    return loss, mems, mode
Example no. 10
def main():
    """Main training program."""

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Timer.
    timers = Timers()

    # Arguments.
    args = get_args()

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # get the tokenizer
    tokenizer = GPT2Tokenizer(
        os.path.join(args.tokenizer_path, 'vocab.json'),
        os.path.join(args.tokenizer_path, 'chinese_vocab.model'))

    # load train data
    if args.do_train:
        train_dataloader, _ = load_data(args, 'train', tokenizer, 1)
        dev_dataloader, dev_dataset = load_data(args, 'dev', tokenizer, 1)

        with open(args.deepspeed_config, "r") as f:
            deepspeed_conf = json.load(f)

        epoch = args.epoch
        grad_acc = deepspeed_conf["gradient_accumulation_steps"]
        args.train_iters = len(train_dataloader) * epoch / grad_acc

        # Model, optimizer, and learning rate.
        # TODO: maybe need to reinitialize optimizer
    elif args.do_eval:
        # Set an arbitrary positive integer since the optimizer and the scheduler will not be used when doing eval.
        args.train_iters = 1

    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    device = torch.cuda.current_device()

    # give a time stamp to the model
    cur_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
    results_dir = os.path.join(args.results_dir,
                               "{}-{}".format(args.model_name, cur_time))
    os.makedirs(results_dir, exist_ok=True)

    if args.do_train and torch.distributed.get_rank() == 0:

        with open(os.path.join(results_dir, "train_log.txt"), "w") as f:
            f.write("Train losses:\n")

        with open(os.path.join(results_dir, "dev_log.txt"), "w") as f:
            f.write("Dev accs:\n")

    torch.distributed.barrier()

    if args.do_train:
        cand_ids = torch.tensor(dev_dataset.cand_ids).to(device)
        total_loss, logging_loss, best_acc = 0.0, 0.0, 0.0
        global_step, total_step, best_step = 0, 0, 0

        for e in range(epoch):
            model.train()
            for batch, no_model_batch in tqdm(
                    train_dataloader,
                    disable=(torch.distributed.get_rank() != 0)):
                for k in batch:
                    batch[k] = batch[k].to(device)
                for k in no_model_batch:
                    no_model_batch[k] = no_model_batch[k].to(device)

                output = model(**batch)
                # pool the output logits over the positions selected by the loss mask (the last token)
                output = torch.sum(
                    output * no_model_batch["loss_mask"].unsqueeze(-1),
                    1) / torch.sum(no_model_batch["loss_mask"],
                                   -1).unsqueeze(-1)
                # get the label of the last token
                labels = no_model_batch["labels"].float()
                labels = (torch.sum(labels * no_model_batch["loss_mask"], 1) /
                          torch.sum(no_model_batch["loss_mask"], -1)).long()
                # cross_entropy loss
                losses = mpu.vocab_parallel_cross_entropy(
                    output.unsqueeze(1).contiguous().float(),
                    labels.unsqueeze(1))
                loss = torch.mean(losses)

                model.backward(loss)
                model.step()

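                # Average the loss across data-parallel ranks so the logged value reflects the global batch.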
                torch.distributed.all_reduce(
                    loss.data, group=mpu.get_data_parallel_group())
                loss.data = loss.data / mpu.get_data_parallel_world_size()
                total_loss += loss.item() / grad_acc

                if total_step % grad_acc == 0:
                    global_step += 1
                    if global_step != 0 and global_step % args.log_interval == 0:
                        # logging
                        if torch.distributed.get_rank() == 0:
                            train_log = "Epoch {}, global step {}, total step {}, train lm loss: {}".format(
                                e, global_step, epoch * len(train_dataloader),
                                (total_loss - logging_loss) /
                                args.log_interval)
                            yprint(train_log)
                            with open(
                                    os.path.join(results_dir, "train_log.txt"),
                                    "a") as f:
                                f.write(train_log + "\n")

                        logging_loss = total_loss

                    if global_step != 0 and global_step % args.eval_interval == 0:
                        # evaluate on the dev
                        acc, _, _ = evaluate(args,
                                             model,
                                             dev_dataloader,
                                             cand_ids,
                                             device,
                                             mode="dev")
                        dev_results_dir = os.path.join(
                            results_dir, "dev_step-{}".format(global_step))

                        if acc > best_acc:
                            best_acc = acc
                            best_step = global_step

                        if torch.distributed.get_rank() == 0:
                            # we will only write the log file once
                            dev_log = "Epoch: {}, Global step: {}, Acc: {}".format(
                                e, global_step, acc)
                            yprint(dev_log)
                            os.makedirs(dev_results_dir, exist_ok=True)
                            with open(
                                    os.path.join(dev_results_dir,
                                                 "dev_result.txt"), "w") as f:
                                f.write(dev_log + "\n")
                            with open(os.path.join(results_dir, "dev_log.txt"),
                                      "a") as f:
                                f.write(dev_log + "\n")

                        torch.distributed.barrier()

                        args.save = dev_results_dir
                        save_checkpoint(global_step, model, optimizer,
                                        lr_scheduler, args)

                total_step += 1

        with open(os.path.join(dev_results_dir, "dev_log.txt"), "a") as f:
            f.write("Best acc: {} Best step: {}\n".format(best_acc, best_step))

    if args.do_eval:
        # evaluate on the test
        test_dataloader, test_dataset = load_data(args, 'test', tokenizer, 1)
        cand_ids = torch.tensor(test_dataset.cand_ids).to(device)

        if args.do_train:
            # if training was done, evaluate the checkpoint with the best acc on the dev set.
            eval_ckpt_path = os.path.join(results_dir,
                                          "dev_step-{}".format(best_step))
            args.load = eval_ckpt_path
        else:
            # if only doing eval, evaluate the checkpoint specified by the user.
            args.load = args.eval_ckpt_path

        load_checkpoint(model=model,
                        optimizer=None,
                        lr_scheduler=None,
                        args=args)
        acc, _, _ = evaluate(args,
                             model,
                             test_dataloader,
                             cand_ids,
                             device,
                             mode="test")

        if torch.distributed.get_rank() == 0:
            eval_log = "Checkpoint from {}: Acc: {}".format(args.load, acc)
            yprint(eval_log)
            with open(os.path.join(results_dir, "eval_log"), "w") as f:
                f.write(eval_log + "\n")

        torch.distributed.barrier()
Example no. 11
def lm_forward_step(data, model, args, timers, mems, eval_metric=None):
    """Forward step."""

    # Get the batch.
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)

    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        output_tokens = []
        sep = attention_mask[batch_id].item()
        for i, token in enumerate(tokens[batch_id, :sep].tolist()):
            if global_tokenizer is not None:
                token = global_tokenizer.IdToToken(token)
                if token.startswith('[MASK'):
                    token = f"[{position_ids_[batch_id, i].item()}, {token}]"
                if token.startswith('##') and len(
                        output_tokens) > 0 and not output_tokens[-1].endswith(
                            ']'):
                    output_tokens[-1] += token[2:]
                else:
                    output_tokens.append(token)
            else:
                output_tokens.append(str(token))
        print(" ".join(output_tokens))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if global_tokenizer.IdToToken(
                    tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(
                        global_tokenizer.DecodeIds(
                            tokens[batch_id, last_index:i].tolist()), "|",
                        global_tokenizer.DecodeIds(
                            labels[batch_id, last_index:i].tolist()))
                    print(position_ids_[batch_id, last_index:i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(
                global_tokenizer.DecodeIds(tokens[batch_id,
                                                  last_index:].tolist()), "|",
                global_tokenizer.DecodeIds(labels[batch_id,
                                                  last_index:].tolist()))
            print(position_ids_[batch_id, last_index:].tolist(),
                  block_position_ids[batch_id, last_index:].tolist())

    # Forward model.
    if args.continuous_prompt:
        prompt_pos = data["prompt_pos"].long().cuda()
        logits, *mems = model(tokens,
                              position_ids,
                              attention_mask,
                              *mems,
                              prompt_pos=prompt_pos)
    else:
        logits, *mems = model(tokens, position_ids, attention_mask, *mems)

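    # eval_metric None: mean masked cross entropy; 'loss': unnormalized sum (kept comparable across runs); 'accuracy'/'classify': per-sequence exact match over the masked positions.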
    if eval_metric is None or eval_metric == 'loss':
        losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                                  labels)
        loss_mask = loss_mask.view(-1)
        # The loss is not normalized for fair comparison
        loss = torch.sum(losses.view(-1) * loss_mask)
        if eval_metric is None:
            loss = loss / loss_mask.sum()
        return loss, mems, 'bert'
    elif eval_metric == 'accuracy' or eval_metric == 'classify':
        outputs = torch.argmax(logits, -1)
        correct = (outputs == labels).float()
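        # Positions outside the loss mask are forced to count as correct so the product below yields a per-sequence exact-match score.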
        correct[(1 - loss_mask).bool()] = 1
        correct = correct.prod(-1)
        if eval_metric == 'accuracy':
            correct = correct.sum()
        return correct, mems, 'bert'
    else:
        raise NotImplementedError(
            "Metric {} not implemented".format(eval_metric))
Example no. 12
    def train(self):
        self.io_fac.logging("Model Loaded from {} ".format(
            self.model_weight_path))
        self.steps = 0
        self.running_loss = 0
        self.company_accuracy = 0
        self.io_fac.logging("START TRAINING!!!")
        for epoch in range(self.start_epoch, self.start_epoch + self.epoch):
            self.net.train()
            if epoch != self.start_epoch:
                self.optim_fac.lr_step(epoch, self.lr_decay)
            epoch_start = time.time()
            idx = 0
            for optim in self.optim_fac.optimizers:
                state = optim.state_dict()["param_groups"][0]
                state = [
                    "{} : {}\n".format(k, v) for k, v in state.items()
                    if k != "params"
                ]
                self.io_fac.logging(state)

            # for i,(imgs,labels) in enumerate(self.train_loader.sample_loader):
            fps_start = time.time()
            ########################### if using a tfrecord file as data ####################################
            #for data in tqdm(self.train_loader.sample_loader):
            #    imgs = data["image"].cuda(self.gpu,non_blocking=True)
            #    labels = torch.squeeze(data["label"],dim=1).cuda(self.gpu,non_blocking=True).long()
            ################################ use images from folders ##########################
            for images, label in tqdm(self.train_loader.sample_loader):
                imgs = images.cuda(self.gpu, non_blocking=True)
                labels = label.cuda(self.gpu, non_blocking=True).long()
                ##################################################
                embeddings = self.net(imgs)
                logist, gather_label = self.header(embeddings, labels)
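                # The classification header returns class logits (presumably sharded across ranks) and the matching labels, so the model-parallel cross entropy is used; the sum is averaged over the global batch below.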
                mpu_loss = mpu.vocab_parallel_cross_entropy(
                    logist.contiguous().float(), gather_label)
                loss = torch.sum(
                    mpu_loss.view(-1)) / (self.world_size * self.batch_size)
                loss_val = loss.data.detach()
                self.optim_fac.reset()
                if self.use_fp16:
                    with amp.scale_loss(
                            loss, self.optim_fac.optimizers) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                self.optim_fac.step()
                self.running_loss += loss_val
                fps_time_cost = time.time() - fps_start
                self.steps += 1
                idx += 1
                if self.steps % self.board_loss_every == 0:
                    self.visual_disp(epoch, idx, loss_val, imgs, embeddings,
                                     labels, fps_time_cost)
                if self.steps % self.evaluate_every == 0:
                    self.online_val()
                if self.steps % self.save_every == 0:
                    self.save_state(self.company_accuracy, True, extra='test')
                fps_start = time.time()
            self.epoch_time.update(time.time() - epoch_start)
            print('epoch: ', epoch, "time:", self.epoch_time)