Example #1
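# Imports assumed by all four snippets on this page (the originals omit
# them). FocalLoss, AngleLoss, LossRecord, Loss_1, eval_turn,
# save_checkpoint and dt() are helpers from the surrounding DCL project,
# not standard-library names.
import os
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable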
def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          data_size=448,
          savepoint=5000,
          checkpoint=5000):
    # savepoint: save without evaluation
    # checkpoint: save with evaluation

    eval_train_flag = False
    rec_loss = []
    checkpoint_list = []

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = len(data_loader['train'])
    train_loss_recorder = LossRecord(train_batch_size)

    if savepoint > train_epoch_step:
        savepoint = train_epoch_step
        checkpoint = savepoint

    date_suffix = dt()
    log_file = open(
        os.path.join(
            Config.log_folder,
            'formal_log_r50_dcl_%s_%s.log' % (str(data_size), date_suffix)),
        'a')

    add_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()
    get_ce_sig_loss = nn.BCELoss()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()
    step = 0

    for epoch in range(start_epoch, epoch_num - 1):
        exp_lr_scheduler.step(epoch)
        model.train(True)

        save_grad = []
        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            model.train(True)

            if Config.use_backbone:
                inputs, labels, img_names = data
                inputs = Variable(inputs.cuda())
                # FloatTensor rather than LongTensor: the BCE (sigmoid) loss
                # used below expects float targets.
                labels = Variable(torch.FloatTensor(np.array(labels)).cuda())

            if Config.use_dcl:
                inputs, labels, labels_swap, swap_law, law_index, img_names = data

                inputs = Variable(inputs.cuda())

                # FloatTensor rather than LongTensor ("dy modify"): the BCE
                # (sigmoid) loss used below expects float targets.
                labels = Variable(torch.FloatTensor(np.array(labels)).cuda())

                labels_swap = Variable(
                    torch.LongTensor(np.array(labels_swap)).cuda())
                swap_law = Variable(
                    torch.LongTensor(np.array(swap_law)).float().cuda())

            optimizer.zero_grad()

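            # A short (final) batch falls back to passing every other image
            # as the unswapped reference; a full batch is forwarded with the
            # precomputed swap index from the dataloader.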
            if inputs.size(0) < 2 * train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, law_index)

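            # Hard-coded even indices: assumes a batch laid out as alternating
            # (unswapped, swapped) pairs with batch size 10.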
            idx_unswap = torch.tensor([0, 2, 4, 6, 8], dtype=torch.long).cuda()
            unswap_label = torch.index_select(labels, dim=0, index=idx_unswap)

            # print (inputs.size(0))

            if Config.use_focal_loss:
                ce_loss = get_focal_loss(outputs[0], labels)
            else:

                # ce_loss = get_ce_loss(outputs[0], labels)  # CE classification, batch x 200
                ce_loss = get_ce_sig_loss(
                    outputs[0], unswap_label)  # sigmoid/BCE classification, batch x 200

            if Config.use_Asoftmax:
                fetch_batch = labels.size(0)
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3],
                                                labels[0:fetch_batch:2],
                                                decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3],
                                                labels[0:fetch_batch:2])
                loss += angle_loss

            alpha_ = 1
            loss += ce_loss * alpha_

            beta_ = 0.1
            gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1

            if Config.use_dcl:
                swap_loss = get_ce_loss(
                    outputs[1], labels_swap
                ) * beta_  # adversarial classification, batch x 2
                loss += swap_loss  # e.g. 0.692 * 0.1 = 0.0692
                # Mask L1 loss, batch x 49: L1Loss is the mean absolute
                # element-wise difference between input x and target y.
                law_loss = add_loss(outputs[2], swap_law) * gamma_
                loss += law_loss  # e.g. 0.0683 * 1 = 0.0683

            loss.backward()
            torch.cuda.synchronize()

            optimizer.step()
            torch.cuda.synchronize()

            if Config.use_dcl:
                print(
                    'epoch:{:d}, globalstep: {:-8d},  {:d} / {:d} \n loss=ce_l+swap_l+law_l: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f} '
                    .format(epoch, step, batch_cnt, train_epoch_step,
                            loss.detach().item(),
                            ce_loss.detach().item(),
                            swap_loss.detach().item(),
                            law_loss.detach().item()),
                    flush=True)
            if Config.use_backbone:
                print(
                    'step: {:-8d} / {:d} loss=ce_loss: {:6.4f} = {:6.4f} '
                    .format(step, train_epoch_step,
                            loss.detach().item(),
                            ce_loss.detach().item()),
                    flush=True)
            rec_loss.append(loss.detach().item())

            train_loss_recorder.update(loss.detach().item())

            # evaluation & save
            if step % checkpoint == 0:
                rec_loss = []
                print(32 * '-', flush=True)
                print(
                    'step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'
                    .format(step, train_epoch_step,
                            1.0 * step / train_epoch_step, epoch,
                            train_loss_recorder.get_val()),
                    flush=True)
                print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)

                val_acc = eval_turn(Config, model, data_loader['trainval'],
                                    'val', epoch, log_file)

                # if val_acc >0.9:
                #     checkpoint = 500
                #     savepoint = 500
                # save_path = os.path.join(save_dir, 'weights_%d_%d_%.4f_%.4f.pth'%(epoch, batch_cnt, val_acc1, val_acc3))
                save_path = os.path.join(
                    save_dir,
                    'weights_%d_%d_%.4f.pth' % (epoch, batch_cnt, val_acc))

                torch.cuda.synchronize()
                torch.save(model.state_dict(), save_path)
                print('saved model to %s' % (save_path), flush=True)
                torch.cuda.empty_cache()

            # save only
            elif step % savepoint == 0:
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(
                    save_dir, 'savepoint_weights-%d-%s.pth' % (step, dt()))

                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(), save_path)
                torch.cuda.empty_cache()

    log_file.close()
Example #2
def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          sw,
          data_size=448,
          savepoint=500,
          checkpoint=1000):
    # savepoint: save without evaluation
    # checkpoint: save with evaluation

    best_prec1 = 0.

    step = 0
    eval_train_flag = False
    rec_loss = []
    checkpoint_list = []

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = len(data_loader['train'])
    train_loss_recorder = LossRecord(train_batch_size)

    if savepoint > train_epoch_step:
        savepoint = train_epoch_step
        checkpoint = savepoint

    date_suffix = dt()
    # log_file = open(os.path.join(Config.log_folder, 'formal_log_r50_dcl_%s_%s.log'%(str(data_size), date_suffix)), 'a')

    add_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()
    get_loss1 = Loss_1()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()

    for epoch in range(start_epoch, epoch_num - 1):
        exp_lr_scheduler.step(epoch)
        model.train(True)

        save_grad = []
        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            model.train(True)
            if Config.use_backbone:
                inputs, labels, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())

            if Config.use_dcl:
                if Config.multi:
                    inputs, labels, labels_swap, swap_law, blabels, clabels, tlabels, img_names = data
                else:
                    inputs, labels, labels_swap, swap_law, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())
                labels_swap = Variable(
                    torch.from_numpy(np.array(labels_swap)).cuda())
                swap_law = Variable(
                    torch.from_numpy(np.array(swap_law)).float().cuda())
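                # blabels / clabels / tlabels are auxiliary label sets for the
                # multi-task branch, consumed by get_loss1 further below.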
                if Config.multi:
                    blabels = Variable(
                        torch.from_numpy(np.array(blabels)).cuda())
                    clabels = Variable(
                        torch.from_numpy(np.array(clabels)).cuda())
                    tlabels = Variable(
                        torch.from_numpy(np.array(tlabels)).cuda())

            optimizer.zero_grad()

            # Visualize the input image in TensorBoard
            # sw.add_image('attention_image', inputs[0])

            if inputs.size(0) < 2 * train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, None)
            if Config.multi:
                if Config.use_loss1:
                    b_loss, pro_b = get_loss1(outputs[2], blabels)
                    # Tie the brand probability into the model-type loss
                    t_loss, _ = get_loss1(outputs[4],
                                          tlabels,
                                          brand_prob=pro_b)
                    s_loss, pro_s = get_loss1(outputs[0],
                                              labels,
                                              brand_prob=pro_b)
                    c_loss, _ = get_loss1(outputs[3], clabels)
                    ce_loss = b_loss + t_loss + s_loss + c_loss * 0.2
                else:
                    ce_loss = get_ce_loss(outputs[0], labels) + get_ce_loss(
                        outputs[0], blabels) + get_ce_loss(
                            outputs[0], clabels) + get_ce_loss(
                                outputs[0], tlabels)
            else:
                if Config.use_focal_loss:
                    ce_loss = get_focal_loss(outputs[0], labels)
                else:
                    if Config.use_loss1:
                        # get_loss1 combines the two losses internally
                        ce_loss_1, pro = get_loss1(outputs[0], labels)
                        ce_loss = 0
                    else:
                        ce_loss = get_ce_loss(outputs[0], labels)

            if Config.use_Asoftmax:
                fetch_batch = labels.size(0)
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3],
                                                labels[0:fetch_batch:2],
                                                decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3],
                                                labels[0:fetch_batch:2])
                loss += angle_loss

            loss += ce_loss

            alpha_ = 1
            beta_ = 1
            # gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1
            gamma_ = 0.01
            if Config.use_dcl:
                if Config.use_focal_loss:
                    swap_loss = get_focal_loss(outputs[1], labels_swap) * beta_
                else:
                    if Config.use_loss1:
                        swap_loss, _ = get_loss1(outputs[1],
                                                 labels_swap,
                                                 brand_prob=pro_s)
                    else:
                        swap_loss = get_ce_loss(outputs[1],
                                                labels_swap) * beta_
                loss += swap_loss
                if not Config.no_loc:
                    law_loss = add_loss(outputs[2], swap_law) * gamma_
                    loss += law_loss

            loss.backward()
            torch.cuda.synchronize()

            optimizer.step()
            torch.cuda.synchronize()

            if Config.use_dcl:
                if Config.multi:
                    print(
                        'step: {:-8d} / {:d}  loss: {:6.4f}  ce_loss: {:6.4f} swap_loss: {:6.4f} '
                        .format(step, train_epoch_step,
                                loss.detach().item(),
                                ce_loss.detach().item(),
                                swap_loss.detach().item()),
                        flush=True)
                # if Config.use_loss1:
                #     print(
                #         'step: {:-8d} / {:d}  loss: {:6.4f}  ce_loss: {:6.4f} swap_loss: {:6.4f} '.format(step,train_epoch_step,loss.detach().item(),ce_loss.detach().item(),swap_loss.detach().item()),
                #         flush=True)
                elif Config.no_loc:
                    print(
                        'step: {:-8d} / {:d} loss=ce_loss+swap_loss: {:6.4f} = {:6.4f} + {:6.4f} '
                        .format(step, train_epoch_step,
                                loss.detach().item(),
                                ce_loss.detach().item(),
                                swap_loss.detach().item()),
                        flush=True)
                else:
                    print(
                        'step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f} '
                        .format(step, train_epoch_step,
                                loss.detach().item(),
                                ce_loss.detach().item(),
                                swap_loss.detach().item(),
                                law_loss.detach().item()),
                        flush=True)
            if Config.use_backbone:
                print(
                    'step: {:-8d} / {:d} loss=ce_loss: {:6.4f} = {:6.4f} '
                    .format(step, train_epoch_step,
                            loss.detach().item(),
                            ce_loss.detach().item()),
                    flush=True)
            rec_loss.append(loss.detach().item())

            train_loss_recorder.update(loss.detach().item())

            # evaluation & save
            if step % checkpoint == 0:
                rec_loss = []
                print(32 * '-', flush=True)
                print(
                    'step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'
                    .format(step, train_epoch_step,
                            1.0 * step / train_epoch_step, epoch,
                            train_loss_recorder.get_val()),
                    flush=True)
                print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
                if Config.multi:
                    val_acc_s, val_acc_b, val_acc_c, val_acc_t = eval_turn(
                        Config, model, data_loader['val'], 'val', epoch)
                    is_best = val_acc_s > best_prec1
                    best_prec1 = max(val_acc_s, best_prec1)
                    filename = 'weights_%d_%d_%.4f_%.4f.pth' % (
                        epoch, batch_cnt, val_acc_s, val_acc_b)
                    save_checkpoint(model.state_dict(), is_best, save_dir,
                                    filename)
                    sw.add_scalar("Train_Loss/Total_loss",
                                  loss.detach().item(), epoch)
                    sw.add_scalar("Train_Loss/b_loss",
                                  b_loss.detach().item(), epoch)
                    sw.add_scalar("Train_Loss/t_loss",
                                  t_loss.detach().item(), epoch)
                    sw.add_scalar("Train_Loss/s_loss",
                                  s_loss.detach().item(), epoch)
                    sw.add_scalar("Train_Loss/c_loss",
                                  c_loss.detach().item(), epoch)
                    sw.add_scalar("Accurancy/val_acc_s", val_acc_s, epoch)
                    sw.add_scalar("Accurancy/val_acc_b", val_acc_b, epoch)
                    sw.add_scalar("Accurancy/val_acc_c", val_acc_c, epoch)
                    sw.add_scalar("Accurancy/val_acc_t", val_acc_t, epoch)
                    sw.add_scalar("learning_rate",
                                  exp_lr_scheduler.get_lr()[1], epoch)
                else:
                    val_acc1, val_acc2, val_acc3 = eval_turn(
                        Config, model, data_loader['val'], 'val', epoch)
                    is_best = val_acc1 > best_prec1
                    best_prec1 = max(val_acc1, best_prec1)
                    filename = 'weights_%d_%d_%.4f_%.4f.pth' % (
                        epoch, batch_cnt, val_acc1, val_acc3)
                    save_checkpoint(model.state_dict(), is_best, save_dir,
                                    filename)
                    sw.add_scalar("Train_Loss", loss.detach().item(), epoch)
                    sw.add_scalar("Val_Accurancy", val_acc1, epoch)
                    sw.add_scalar("learning_rate",
                                  exp_lr_scheduler.get_lr()[1], epoch)
                torch.cuda.empty_cache()

            # save only
            elif step % savepoint == 0:
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(
                    save_dir, 'savepoint_weights-%d-%s.pth' % (step, dt()))

                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(), save_path)
                torch.cuda.empty_cache()
Example #3
def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          data_size=448,
          savepoint=500,
          checkpoint=1000):
    # savepoint: save without evaluation
    # checkpoint: save with evaluation

    step = 0
    eval_train_flag = False
    rec_loss = []
    checkpoint_list = []

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = len(data_loader['train'])
    train_loss_recorder = LossRecord(train_batch_size)

    if savepoint > train_epoch_step:
        savepoint = train_epoch_step
        checkpoint = savepoint

    date_suffix = dt()
    log_file = open(
        os.path.join(
            Config.log_folder,
            'formal_log_r50_dcl_%s_%s.log' % (str(data_size), date_suffix)),
        'a')

    add_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()

    for epoch in range(start_epoch, epoch_num - 1):
        exp_lr_scheduler.step(epoch)
        model.train(True)

        save_grad = []
        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            model.train(True)
            if Config.use_backbone:
                inputs, labels, img_names = data
                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())

            if Config.use_dcl:
                inputs, labels, labels_swap, swap_law, img_names = data

                inputs = Variable(inputs.cuda())
                labels = Variable(torch.from_numpy(np.array(labels)).cuda())
                labels_swap = Variable(
                    torch.from_numpy(np.array(labels_swap)).cuda())
                swap_law = Variable(
                    torch.from_numpy(np.array(swap_law)).float().cuda())

            optimizer.zero_grad()

            if inputs.size(0) < 2 * train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, None)

            if Config.use_focal_loss:
                ce_loss = get_focal_loss(outputs[0], labels)
            else:
                ce_loss = get_ce_loss(outputs[0], labels)

            if Config.use_Asoftmax:
                fetch_batch = labels.size(0)
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3],
                                                labels[0:fetch_batch:2],
                                                decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3],
                                                labels[0:fetch_batch:2])
                loss += angle_loss

            loss += ce_loss

            alpha_ = 1
            beta_ = 1
            gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1
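            # DCL objective: an adversarial swap-detection head (swap_loss)
            # plus an L1 "region alignment" term (law_loss) on the predicted
            # swap law.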
            if Config.use_dcl:
                swap_loss = get_ce_loss(outputs[1], labels_swap) * beta_
                loss += swap_loss
                law_loss = add_loss(outputs[2], swap_law) * gamma_
                loss += law_loss

            loss.backward()
            torch.cuda.synchronize()

            optimizer.step()
            torch.cuda.synchronize()

            if Config.use_dcl:
                print(
                    'step: {:-8d} / {:d} loss=ce_loss+swap_loss+law_loss: {:6.4f} = {:6.4f} + {:6.4f} + {:6.4f} '
                    .format(step, train_epoch_step,
                            loss.detach().item(),
                            ce_loss.detach().item(),
                            swap_loss.detach().item(),
                            law_loss.detach().item()),
                    flush=True)
            if Config.use_backbone:
                print(
                    'step: {:-8d} / {:d} loss=ce_loss: {:6.4f} = {:6.4f} '
                    .format(step, train_epoch_step,
                            loss.detach().item(),
                            ce_loss.detach().item()),
                    flush=True)
            rec_loss.append(loss.detach().item())

            train_loss_recorder.update(loss.detach().item())

            # evaluation & save
            if step % checkpoint == 0:
                rec_loss = []
                print(32 * '-', flush=True)
                print(
                    'step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'
                    .format(step, train_epoch_step,
                            1.0 * step / train_epoch_step, epoch,
                            train_loss_recorder.get_val()),
                    flush=True)
                print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
                if eval_train_flag:
                    trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(
                        model, data_loader['trainval'], 'trainval', epoch,
                        log_file)
                    if abs(trainval_acc1 - trainval_acc3) < 0.01:
                        eval_train_flag = False

                val_acc1, val_acc2, val_acc3 = eval_turn(
                    model, data_loader['val'], 'val', epoch, log_file)

                save_path = os.path.join(
                    save_dir, 'weights_%d_%d_%.4f_%.4f.pth' %
                    (epoch, batch_cnt, val_acc1, val_acc3))
                torch.cuda.synchronize()
                torch.save(model.state_dict(), save_path)
                print('saved model to %s' % (save_path), flush=True)
                torch.cuda.empty_cache()

            # save only
            elif step % savepoint == 0:
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(
                    save_dir, 'savepoint_weights-%d-%s.pth' % (step, dt()))

                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(), save_path)
                torch.cuda.empty_cache()

    log_file.close()
Example #4
def train(Config,
          model,
          epoch_num,
          start_epoch,
          optimizer,
          exp_lr_scheduler,
          data_loader,
          save_dir,
          data_size=448,
          savepoint=500,
          checkpoint=1000):
    # savepoint: save without evaluation
    # checkpoint: save with evaluation
    bmy_weight = 1.0  # 1.5  # weight of the brand (bmy) branch in training
    step = 0
    eval_train_flag = False
    rec_loss = []
    checkpoint_list = []

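    # Running statistics: CE losses are accumulated so that batches whose
    # loss is a 3-sigma outlier can be flagged as possibly mislabeled.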
    steps = np.array([], dtype=int)
    train_accs = np.array([], dtype=np.float32)
    test_accs = np.array([], dtype=np.float32)
    ce_losses = np.array([], dtype=np.float32)
    ce_loss_mu = -1
    ce_loss_std = 0.0

    train_batch_size = data_loader['train'].batch_size
    train_epoch_step = len(data_loader['train'])
    train_loss_recorder = LossRecord(train_batch_size)

    if savepoint > train_epoch_step:
        savepoint = train_epoch_step
        checkpoint = savepoint

    date_suffix = dt()
    log_file = open(
        os.path.join(
            Config.log_folder,
            'formal_log_r50_dcl_%s_%s.log' % (str(data_size), date_suffix)),
        'a')

    add_loss = nn.L1Loss()
    get_ce_loss = nn.CrossEntropyLoss()
    get_focal_loss = FocalLoss()
    get_angle_loss = AngleLoss()

    for epoch in range(start_epoch, epoch_num - 1):
        model.train(True)
        save_grad = []
        for batch_cnt, data in enumerate(data_loader['train']):
            step += 1
            loss = 0
            model.train(True)
            if Config.use_backbone:
                inputs, brand_labels, img_names, bmy_labels = data
                inputs = Variable(inputs.cuda())
                brand_labels = Variable(
                    torch.from_numpy(np.array(brand_labels)).cuda())
                bmy_labels = Variable(
                    torch.from_numpy(np.array(bmy_labels)).cuda())

            if Config.use_dcl:
                inputs, brand_labels, brand_labels_swap, swap_law, img_names, bmy_labels = data
                org_brand_labels = brand_labels
                inputs = Variable(inputs.cuda())
                brand_labels = Variable(
                    torch.from_numpy(np.array(brand_labels)).cuda())
                bmy_labels = Variable(
                    torch.from_numpy(np.array(bmy_labels)).cuda())
                brand_labels_swap = Variable(
                    torch.from_numpy(np.array(brand_labels_swap)).cuda())
                swap_law = Variable(
                    torch.from_numpy(np.array(swap_law)).float().cuda())

            optimizer.zero_grad()

            if inputs.size(0) < 2 * train_batch_size:
                outputs = model(inputs, inputs[0:-1:2])
            else:
                outputs = model(inputs, None)

            if Config.use_focal_loss:
                ce_loss_brand = get_focal_loss(outputs[0], brand_labels)
                ce_loss_bmy = get_focal_loss(outputs[-1], bmy_labels)
            else:
                ce_loss_brand = get_ce_loss(outputs[0], brand_labels)
                ce_loss_bmy = get_ce_loss(outputs[-1], bmy_labels)
            ce_loss = ce_loss_brand + bmy_weight * ce_loss_bmy

            if Config.use_Asoftmax:
                fetch_batch = brand_labels.size(0)
                if batch_cnt % (train_epoch_step // 5) == 0:
                    angle_loss = get_angle_loss(outputs[3],
                                                brand_labels[0:fetch_batch:2],
                                                decay=0.9)
                else:
                    angle_loss = get_angle_loss(outputs[3],
                                                brand_labels[0:fetch_batch:2])
                loss += angle_loss

            loss += ce_loss
            ce_loss_val = ce_loss.detach().item()
            ce_losses = np.append(ce_losses, ce_loss_val)

            alpha_ = 1
            beta_ = 1
            gamma_ = 0.01 if Config.dataset == 'STCAR' or Config.dataset == 'AIR' else 1
            if Config.use_dcl:
                swap_loss = get_ce_loss(outputs[1], brand_labels_swap) * beta_
                loss += swap_loss
                law_loss = add_loss(outputs[2], swap_law) * gamma_
                loss += law_loss

            loss.backward()
            torch.cuda.synchronize()
            optimizer.step()
            exp_lr_scheduler.step(epoch)
            torch.cuda.synchronize()

            if Config.use_dcl:
                if ce_loss_mu > 0 and ce_loss_val > ce_loss_mu + 3.0 * ce_loss_std:
                    # Log this batch: it may contain mislabeled samples
                    print('Recording suspicious batch: loss={0}; threshold={1};'.format(
                        ce_loss_val, ce_loss_mu + 3.0 * ce_loss_std))
                    with open(
                            './logs/abnormal_samples_{0}_{1}_{2}.txt'.format(
                                epoch, step, ce_loss_val), 'a+') as fd:
                        error_batch_len = len(img_names)
                        for i in range(error_batch_len):
                            fd.write('{0} <=> {1};\r\n'.format(
                                org_brand_labels[i * 2], img_names[i]))
                print('epoch{}: step: {:-8d} / {:d} loss=ce_loss+'
                      'swap_loss+law_loss: {:6.4f} = {:6.4f} '
                      '+ {:6.4f} + {:6.4f} brand_loss: {:6.4f}'.format(
                          epoch, step % train_epoch_step, train_epoch_step,
                          loss.detach().item(), ce_loss_val,
                          swap_loss.detach().item(),
                          law_loss.detach().item(),
                          ce_loss_brand.detach().item()),
                      flush=True)

            if Config.use_backbone:
                print('epoch{}: step: {:-8d} / {:d} '
                      'loss=ce_loss: {:6.4f} = {:6.4f} '.format(
                          epoch, step % train_epoch_step, train_epoch_step,
                          loss.detach().item(),
                          ce_loss.detach().item()),
                      flush=True)
            rec_loss.append(loss.detach().item())

            train_loss_recorder.update(loss.detach().item())

            # evaluation & save
            if step % checkpoint == 0:
                rec_loss = []
                print(32 * '-', flush=True)
                print(
                    'step: {:d} / {:d} global_step: {:8.2f} train_epoch: {:04d} rec_train_loss: {:6.4f}'
                    .format(step, train_epoch_step,
                            1.0 * step / train_epoch_step, epoch,
                            train_loss_recorder.get_val()),
                    flush=True)
                print('current lr:%s' % exp_lr_scheduler.get_lr(), flush=True)
                '''
                if eval_train_flag:
                    trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(Config, model, data_loader['trainval'], 'trainval', epoch, log_file)
                    if abs(trainval_acc1 - trainval_acc3) < 0.01:
                        eval_train_flag = False
                '''
                print('##### validate dataset #####')
                trainval_acc1, trainval_acc2, trainval_acc3 = eval_turn(
                    Config, model, data_loader['val'], 'val', epoch, log_file)
                print('##### test dataset #####')
                # Test metrics are currently aliased to the validation ones;
                # no separate test split is evaluated here.
                val_acc1, val_acc2, val_acc3 = trainval_acc1, trainval_acc2, trainval_acc3
                steps = np.append(steps, step)
                train_accs = np.append(train_accs, trainval_acc1)
                test_accs = np.append(test_accs, val_acc1)

                save_path = os.path.join(
                    save_dir, 'weights_%d_%d_%.4f_%.4f.pth' %
                    (epoch, batch_cnt, val_acc1, val_acc3))
                torch.cuda.synchronize()
                torch.save(model.state_dict(),
                           save_path,
                           _use_new_zipfile_serialization=False)
                print('saved model to %s' % (save_path), flush=True)
                torch.cuda.empty_cache()
                # Save accuracy statistics and reset the running CE-loss stats
                ce_loss_mu = ce_losses.mean()
                ce_loss_std = ce_losses.std()
                print('Cross entropy loss: mu={0}; std={1}; range:{2}~{3};'.
                      format(ce_loss_mu, ce_loss_std,
                             ce_loss_mu - 3.0 * ce_loss_std,
                             ce_loss_mu + 3.0 * ce_loss_std))
                ce_losses = np.array([], dtype=np.float32)
                if train_accs.shape[0] > 30:
                    np.savetxt('./logs/steps1.txt', (steps, ))
                    np.savetxt('./logs/train_accs1.txt', (train_accs, ))
                    np.savetxt('./logs/test_accs1.txt', (test_accs, ))
                    steps = np.array([], dtype=int)
                    train_accs = np.array([], dtype=np.float32)
                    test_accs = np.array([], dtype=np.float32)

            # save only
            elif step % savepoint == 0:
                train_loss_recorder.update(rec_loss)
                rec_loss = []
                save_path = os.path.join(
                    save_dir, 'savepoint_weights-%d-%s.pth' % (step, dt()))

                checkpoint_list.append(save_path)
                if len(checkpoint_list) == 6:
                    os.remove(checkpoint_list[0])
                    del checkpoint_list[0]
                torch.save(model.state_dict(),
                           save_path,
                           _use_new_zipfile_serialization=False)
                torch.cuda.empty_cache()

    log_file.close()
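
For reference, a minimal sketch of how these train() variants might be driven; the optimizer and scheduler choices and all literal values below are illustrative assumptions, not part of the original snippets.

import torch.optim as optim
from torch.optim import lr_scheduler

def main(Config, model, data_loader):
    # Hypothetical setup: Config, model and data_loader are assumed to come
    # from the DCL project's own config/dataset code. Example #2 additionally
    # expects a TensorBoard SummaryWriter passed as `sw`.
    model = model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    train(Config,
          model,
          epoch_num=60,
          start_epoch=0,
          optimizer=optimizer,
          exp_lr_scheduler=exp_lr_scheduler,
          data_loader=data_loader,
          save_dir='./checkpoints')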