def test():
    """Evaluate a saved checkpoint on the ``PalmData`` test split.

    Loads the weights stored under the ``'state_dict'`` key of the checkpoint
    at ``model_name``, runs the model over every test sample with a batch size
    of 1, and prints the per-sample loss/metrics, the running metric averages,
    the thresholded prediction and the ground-truth label.

    Relies on module-level names defined outside this function:
    ``model_name``, ``device``, ``PalmNet``, ``PalmData``, ``wce_loss`` and the
    ``my_*_score`` metric helpers.

    Raises:
        FileNotFoundError: If ``model_name`` does not point to an existing
            file. Evaluating randomly initialised weights would only produce
            misleading numbers, so the run is aborted instead.
    """
    # Build model, data loader and loss function.
    if not os.path.exists(model_name):
        raise FileNotFoundError(
            'Please choose a right model path!! (%s)' % model_name)

    # Keep the model on the same device the inputs are moved to below;
    # a CPU-only model fed with tensors on another device fails at forward.
    model = PalmNet().to(device)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_name, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    print('==>loaded model:', model_name)

    palm_data_test = PalmData(train_mode='test')
    test_loader = torch.utils.data.DataLoader(palm_data_test,
                                              batch_size=1,
                                              num_workers=1)
    num_batches = len(test_loader)

    acc_sum = 0
    f1_sum = 0
    precision_sum = 0
    recall_sum = 0

    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            image = data['img'].to(device)
            label = data['child_cls'].to(device)
            pred = model(image)

            loss = wce_loss(pred, label)
            acc = my_acc_score(pred, label)
            f1 = my_f1_score(pred, label)
            precision = my_precision_score(pred, label)
            recall = my_recall_score(pred, label)

            acc_sum += acc
            f1_sum += f1
            precision_sum += precision
            recall_sum += recall

            seen = i + 1
            print(
                '[%d / %d] The Loss:[%.6f], Acc:[%.6f]'
                'F1:[%.6f] RECALL:[%.6f] PRECISION:[%.6f]' %
                (i, num_batches, loss.item(), acc, f1, recall, precision))
            print('The avg Acc:[%.3f]'
                  'F1:[%.3f] RECALL:[%.3f] PRECISION:[%.3f] is :' %
                  (acc_sum / seen, f1_sum / seen,
                   recall_sum / seen, precision_sum / seen))
            # Copy to host memory first: .numpy() only works on CPU tensors.
            # Outputs greater than 0 are reported as class 1, others as 0.
            print('The pred: ',
                  np.where(pred.detach().cpu().numpy() > 0, 1, 0))
            print('The Gt: ', label.detach().cpu().numpy().astype('int'))
            print('=================================\n')
def train():
    """Train ``PalmNet`` on ``PalmData`` and write checkpoints/TensorBoard logs.

    Runs 200 epochs of Adam with a ``MultiStepLR`` schedule (learning rate
    multiplied by 0.1 at epochs 10, 20, 40, 60 and 80). Per-batch loss and
    metrics, per-epoch averages, the learning rate and parameter histograms
    are sent to the module-level ``writer``. A checkpoint containing
    ``'state_dict'``, ``'opt_state_dict'`` and ``'epoch'`` is saved to
    ``./save_cls_model/`` whenever the epoch's mean loss reaches a new minimum,
    and additionally every 10th epoch.

    Relies on module-level names defined outside this function: ``device``,
    ``writer``, ``PalmNet``, ``PalmData``, ``wce_loss`` and the ``my_*_score``
    metric helpers.
    """
    # Hyperparameters.
    batch_size = 6
    lr = 1e-2
    num_epochs = 200

    # Build model, data loader, optimizer and LR schedule.
    # The model goes to the same device the batches are moved to below.
    model = PalmNet().to(device)

    palm_data_train = PalmData()
    # NOTE(review): no shuffle=True here; confirm PalmData handles sample
    # ordering itself if shuffling is desired.
    train_loader = torch.utils.data.DataLoader(
        palm_data_train, batch_size=batch_size, num_workers=3)
    num_batches = len(train_loader)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 20, 40, 60, 80], gamma=0.1)

    # NOTE: resuming is currently disabled. To resume, load a checkpoint and
    # restore model ('state_dict'), optimizer ('opt_state_dict') and the
    # starting epoch ('epoch') before entering the loop.

    # Make sure the checkpoint directory exists before the first save, so a
    # missing folder does not abort the run after a full epoch of training.
    os.makedirs('./save_cls_model', exist_ok=True)

    min_loss = float('inf')
    for epoch in range(num_epochs):
        # Log weight distributions, skipping parameters whose name contains
        # 'bn'.
        for name, param in model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(name, param, epoch)

        # Read the LR from the optimizer: it is the value actually applied
        # this epoch, whereas scheduler.get_lr() called outside step() can
        # report a wrong value at milestone epochs.
        current_lr = optimizer.param_groups[0]['lr']

        loss_sum = 0
        acc_sum = 0
        f1_sum = 0
        precision_sum = 0
        recall_sum = 0

        model.train()
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()

            image = data['img'].to(device)
            label = data['child_cls'].to(device)

            pred = model(image)
            loss = wce_loss(pred, label)
            loss_value = loss.item()
            acc = my_acc_score(pred, label)
            f1 = my_f1_score(pred, label)
            precision = my_precision_score(pred, label)
            recall = my_recall_score(pred, label)

            global_step = epoch * num_batches + i
            writer.add_scalar('loss', loss_value, global_step=global_step)
            writer.add_scalar('acc', acc, global_step=global_step)
            writer.add_scalar('f1', f1, global_step=global_step)
            writer.add_scalar('precision', precision, global_step=global_step)
            writer.add_scalar('recall', recall, global_step=global_step)

            print('epoch [%d],[%d / %d] The Loss:[%.6f], Acc:[%.6f]'
                  'F1:[%.6f] RECALL:[%.6f] PRECISION:[%.6f] Lr:[%.6f]  ' %
                  (epoch, i, num_batches, loss_value, acc, f1, recall,
                   precision, current_lr))

            loss.backward()
            optimizer.step()

            loss_sum += loss_value
            acc_sum += acc
            f1_sum += f1
            precision_sum += precision
            recall_sum += recall

        epoch_loss = loss_sum / num_batches
        # The 'loos' tag is kept as-is so existing TensorBoard runs stay
        # comparable with new ones.
        writer.add_scalars('avg_loss_f1_acc_precision_recall',
                           {'loos': epoch_loss,
                            'f1': f1_sum / num_batches,
                            'acc': acc_sum / num_batches,
                            'precision': precision_sum / num_batches,
                            'recall': recall_sum / num_batches},
                           global_step=epoch)
        writer.add_scalar('lr', current_lr, global_step=epoch)

        checkpoint = {
            'state_dict': model.state_dict(),
            'opt_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            torch.save(checkpoint,
                       './save_cls_model/min_0219_model_epoch_%d_%.6f.pt'
                       % (epoch, epoch_loss))

        if epoch % 10 == 0:
            torch.save(checkpoint,
                       './save_cls_model/0219model_epoch_%d_%.6f.pt'
                       % (epoch, epoch_loss))

        # Advance the schedule once per epoch, after the optimizer updates and
        # after the checkpoint is written, so the saved optimizer state still
        # carries this epoch's learning rate.
        scheduler.step()