def train(model_in, num_epochs=3, load_model=True, freeze_extractor=True):
    # === dataloader defination ===
    train_batch_size = 2
    valid_batch_size = 1
    test_batch_size = 1
    dataloaders = VSLdataset.create_dataloader_train_valid_test(
        train_batch_size, valid_batch_size, test_batch_size)
    train_dataloader = dataloaders['train']
    valid_dataloader = dataloaders['valid']
    test_dataloader = dataloaders['test']
    # =============================

    # === every n epochs print ===
    valid_epoch_step = 1
    test_epoch_step = 10
    # ============================

    # === got model ===
    save_file = os.path.join('../saved_model',
                             'CLSTM_18_l10_h512_d02_final.pth')
    writer = SummaryWriter(
        '../saved_model/tensorboard_log_18_l10_h512_d02_final')
    if (load_model == True):
        model = load_checkpoint(model_in, save_file)
    else:
        model = model_in
    # =================
    print(model)
    # === freeze some layers against overfitting ===
    if (freeze_extractor == True):
        for layer_id, child in enumerate(model.children()):
            if layer_id < 8:  #layer 9 is the fc
                for param in child.parameters():
                    param.requires_grad = False
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=0.0001)  #lr=0.0000000001)
    else:
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # ==============================================

    loss_function = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     verbose=1,
                                                     patience=2)

    # === runing no gpu ===
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    torch.backends.cudnn.benchmark = True
    # =====================
    # epoch = 1

    for epoch in range(num_epochs):
        # training
        model.train()
        train_loss = 0.0
        print('Train:')

        for index, (data, target) in enumerate(train_dataloader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)

            loss = loss_function(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if index % 500 == 499:  # print every 10 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, index + 1, train_loss / 500))
                writer.add_scalar('Train/Loss', train_loss / 500,
                                  epoch * (len(train_dataloader)) + index + 1)
                writer.flush()
                train_loss = 0.0

        if (epoch % valid_epoch_step == (valid_epoch_step - 1)):
            # validation
            class_correct = list(
                0. for i in range(len(VSLdataset.class_name_to_id_)))
            class_total = list(
                0. for i in range(len(VSLdataset.class_name_to_id_)))
            class_name = list(VSLdataset.class_name_to_id_.keys())
            model.eval()
            print('Valid:')
            loss_eval = 0.0
            loss_for_display = 0.0

            all_targets = np.zeros((len(valid_dataloader), 1))
            all_predicted_flatten = np.zeros((len(valid_dataloader), 1))

            for index_eval, (data_eval,
                             target_eval) in enumerate(valid_dataloader):
                data_eval, target_eval = data_eval.to(device), target_eval.to(
                    device)
                output_eval = model(data_eval)
                loss_i = loss_function(output_eval, target_eval).item()
                loss_for_display += loss_i
                loss_eval += loss_i

                all_targets[
                    index_eval, :] = target_eval[0].cpu().detach().numpy()
                #all_scores[index_eval, :] = output_eval[0].cpu().detach().numpy()

                _, predicted = torch.max(output_eval, 1)
                all_predicted_flatten[
                    index_eval, :] = predicted[0].cpu().detach().numpy()

                c = (predicted == target_eval).squeeze()
                for i in range(valid_batch_size):
                    try:
                        label = target_eval[i]
                        class_correct[label] += c[i].item()
                    except:
                        label = target_eval
                        class_correct[label] += c.item()
                    class_total[label] += 1
                if index_eval % 10 == 9:  # print every 10 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, index_eval + 1, loss_for_display / 10))
                    loss_for_display = 0.0

            print("F1:")
            evaluate.calculate_f1(all_targets, all_predicted_flatten)

            for i in range(len(VSLdataset.class_name_to_id_)):
                accuracy = 100 * (class_correct[i] + 1) / (class_total[i] + 1)
                print('Accuracy of %5s : %2d %%' % (class_name[i], accuracy))
                # Record loss and accuracy from the test run into the writer
                writer.add_scalar('Valid/Accuracy ' + str(class_name[i]),
                                  accuracy, epoch)
                writer.flush()
            print('avg_loss: ', loss_eval / len(valid_dataloader))
            scheduler.step(loss_eval / len(valid_dataloader))
            writer.add_scalar('Valid/Loss ', loss_eval / len(valid_dataloader),
                              epoch)
            writer.flush()
            loss_eval = 0.0

            # 每次 eval 都进行保存
            # save current model
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'epoch': epoch,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss
                }, save_file)
            print("saved model.")
def infer(model_in):
    # === dataloader defination ===
    train_batch_size = 1
    valid_batch_size = 1
    test_batch_size = 1
    dataloaders = VSLdataset.create_dataloader_train_valid_test(
        train_batch_size, valid_batch_size, test_batch_size)
    #dataloaders = VSLdataset.create_dataloader_valid(valid_batch_size)

    valid_dataloader = dataloaders['valid']
    # =============================

    # === got model ===
    save_file = os.path.join('../saved_model',
                             'CLSTM_50_l10_h512_loss021_best.pth')
    model = load_checkpoint(model_in, save_file)
    # =================
    print(model)

    # === runing no gpu ===
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    torch.backends.cudnn.benchmark = True
    # =====================

    loss_function = nn.CrossEntropyLoss()

    # validation
    class_correct = list(0. for i in range(len(VSLdataset.class_name_to_id_)))
    class_total = list(0. for i in range(len(VSLdataset.class_name_to_id_)))
    class_name = list(VSLdataset.class_name_to_id_.keys())
    model.eval()
    print('Test:')

    all_targets = np.zeros((len(valid_dataloader), 1))
    all_scores = np.zeros((len(valid_dataloader), 8))
    all_predicted_flatten = np.zeros((len(valid_dataloader), 1))

    loss_eval = 0.0
    for index_eval, (data_eval, target_eval) in enumerate(valid_dataloader):
        data_eval, target_eval = data_eval.to(device), target_eval.to(device)
        output_eval = model(data_eval)

        loss_i = loss_function(output_eval, target_eval).item()
        loss_eval += loss_i

        all_targets[index_eval, :] = target_eval[0].cpu().detach().numpy()
        all_scores[index_eval, :] = output_eval[0].cpu().detach().numpy()

        _, predicted = torch.max(output_eval, 1)
        all_predicted_flatten[
            index_eval, :] = predicted[0].cpu().detach().numpy()
        if (predicted != target_eval):  # batch_size, timesteps, C, H, W
            print('mis_classified: ', index_eval)
            #visualize_mis_class(data_eval[0].permute(0, 2, 3, 1).cpu(), str(index_eval) + '.png', class_name[target_eval[0].cpu().numpy()], class_name[predicted[0].cpu().numpy()])

        c = (predicted == target_eval).squeeze()
        for i in range(valid_batch_size):
            try:
                label = target_eval[i]
                class_correct[label] += c[i].item()
            except:
                label = target_eval
                class_correct[label] += c.item()
            class_total[label] += 1

    for i in range(len(VSLdataset.class_name_to_id_)):
        accuracy = 100 * (class_correct[i] + 1) / (class_total[i] + 1)
        print('Accuracy of %5s : %2d %%' % (class_name[i], accuracy))

    print('avg_loss: ', loss_eval / len(valid_dataloader))

    # === draw roc and confusion mat ===
    evaluate.draw_roc_bin(all_targets, all_scores)
    evaluate.draw_confusion_matrix(all_targets, all_predicted_flatten)
Exemplo n.º 3
0
def infer(model_in, file_name, fp16=False, int8=False):
    # === dataloader defination ===
    train_batch_size = 1
    valid_batch_size = 1
    test_batch_size = 1
    dataloaders = VSLdataset.create_dataloader_train_valid_test(
        train_batch_size, valid_batch_size, test_batch_size)

    valid_dataloader = dataloaders['valid']
    # =============================

    save_file = os.path.join('../saved_model', file_name)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    model = model_in.eval().cuda()
    data_example = torch.randn(8, 10, 3, 224, 224)
    data_example = data_example.to(device)

    model_trt = torch2trt(model_in, [data_example],
                          max_workspace_size=1 << 26,
                          fp16_mode=fp16,
                          int8_mode=int8)
    model_trt.load_state_dict(torch.load(save_file))

    loss_function = nn.CrossEntropyLoss()

    # validation
    class_correct = list(0. for i in range(len(VSLdataset.class_name_to_id_)))
    class_total = list(0. for i in range(len(VSLdataset.class_name_to_id_)))
    class_name = list(VSLdataset.class_name_to_id_.keys())

    print('Test:')

    all_targets = np.zeros((len(valid_dataloader), 1))
    all_scores = np.zeros((len(valid_dataloader), 8))
    all_predicted_flatten = np.zeros((len(valid_dataloader), 1))

    loss_eval = 0.0
    for index_eval, (data_eval, target_eval) in enumerate(valid_dataloader):
        data_eval, target_eval = data_eval.to(device), target_eval.to(device)
        output_eval = model_trt(data_eval)

        loss_i = loss_function(output_eval, target_eval).item()
        loss_eval += loss_i

        all_targets[index_eval, :] = target_eval[0].cpu().detach().numpy()
        all_scores[index_eval, :] = output_eval[0].cpu().detach().numpy()

        _, predicted = torch.max(output_eval, 1)
        all_predicted_flatten[
            index_eval, :] = predicted[0].cpu().detach().numpy()
        if (predicted != target_eval):  # batch_size, timesteps, C, H, W
            print('mis_classified: ', index_eval)
            #visualize_mis_class(data_eval[0].permute(0, 2, 3, 1).cpu(), str(index_eval) + '.png', class_name[target_eval[0].cpu().numpy()], class_name[predicted[0].cpu().numpy()])

        c = (predicted == target_eval).squeeze()
        for i in range(valid_batch_size):
            try:
                label = target_eval[i]
                class_correct[label] += c[i].item()
            except:
                label = target_eval
                class_correct[label] += c.item()
            class_total[label] += 1

    for i in range(len(VSLdataset.class_name_to_id_)):
        accuracy = 100 * (class_correct[i] + 1) / (class_total[i] + 1)
        print('Accuracy of %5s : %2d %%' % (class_name[i], accuracy))

    print('avg_loss: ', loss_eval / len(valid_dataloader))

    # === draw roc and confusion mat ===
    evaluate.draw_roc_bin(all_targets, all_scores)
    evaluate.draw_confusion_matrix(all_targets, all_predicted_flatten)