Example #1
def efficient_test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        complexity = SI(data)
        projected_complexity = project_IC_func(complexity)

        batch_idx_output_list = []
        for batch_idx in range(projected_complexity.size(0)):
            # pick a per-image resolution from its projected complexity
            downsampling_size, _, _ = project_hyperparam(projected_complexity[batch_idx])
            single_img = data[batch_idx, :, :, :].unsqueeze(0)
            single_img = torch.nn.Upsample(size=(downsampling_size, downsampling_size),
                                           mode='nearest')(single_img)

            single_output = model(single_img)
            batch_idx_output_list.append(single_output)

        output = torch.cat(batch_idx_output_list, 0)
        test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
        correct += pred.eq(target.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))
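Example #1 depends on three helpers defined elsewhere in the repository: SI (a per-image spatial-information complexity score), project_IC_func, and project_hyperparam. Their real definitions are not shown here; the following is a minimal sketch under assumed signatures and placeholder formulas, only so the example can be exercised end to end.

import torch

def SI(images):
    # Assumed: per-image spatial-information proxy from gradient magnitudes.
    gray = images.mean(dim=1)                    # (N, H, W)
    dh = gray[:, :, 1:] - gray[:, :, :-1]        # horizontal differences
    dv = gray[:, 1:, :] - gray[:, :-1, :]        # vertical differences
    return dh.abs().mean(dim=(1, 2)) + dv.abs().mean(dim=(1, 2))  # (N,)

def project_IC_func(complexity):
    # Assumed: normalize raw complexity scores into [0, 1].
    lo, hi = complexity.min(), complexity.max()
    return (complexity - lo) / (hi - lo + 1e-8)

def project_hyperparam(projected):
    # Assumed: map a scalar in [0, 1] to (downsampling_size, keep_ratio, loss_scale);
    # the ranges below are illustrative, not taken from the source.
    downsampling_size = int(16 + 16 * float(projected))  # e.g. 16..32 pixels
    keep_ratio = 0.5 + 0.5 * float(projected)            # keep more of the hard images
    loss_scale = 1.0 + float(projected)                  # weight hard images higher
    return downsampling_size, keep_ratio, loss_scale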
Example #2
def train(epoch, file):
    print('\nEpoch: %d' % epoch)
    net.train()
    writer = csv.writer(file)
    train_loss = 0
    correct = 0
    total = 0
    writer.writerow(["image_name", "image_name_copy", "image_complexity"])
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        print("Train Size: ")
        print(SI(inputs).size())
        print("Train tensor: ")
        si_score = SI(inputs).tolist()
        tar = targets.tolist()
        print(si_score)
        for idx in range(len(si_score)):  # one row per image in the batch
            writer.writerow(
                [str(batch_idx) + "_" + str(idx), tar[idx], si_score[idx]])
Example #3
def test(epoch, file):
    global best_acc
    net.eval()
    writer = csv.writer(file)
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        writer.writerow(["image_name", "image_name_copy", "image_complexity"])
        for batch_idx, (inputs, targets) in enumerate(testloader):
            print("Train Size: ")
            print(SI(inputs).size())
            print("Train tensor: ")
            si_score = SI(inputs).tolist()
            tar = targets.tolist()
            print(si_score)
            for idx in range(len(si_score)):
                writer.writerow(
                    [str(batch_idx) + "_" + str(idx), tar[idx], si_score[idx]])
Example #4
def validate(val_loader, model, criterion, file):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    writer = csv.writer(file)
    writer.writerow(["image_name", "image_name_copy", "image_complexity"])
    for i, (input, target) in enumerate(val_loader):
        print("Train Size: ")
        print(SI(input).size())
        print("Train tensor: ")
        si_score = SI(input).tolist()
        tar = target.tolist()
        print(si_score)
        for idx in range(len(si_score)):
            writer.writerow([str(i) + "_" + str(idx), tar[idx], si_score[idx]])
Example #5
def test(file):
    model.eval()
    test_loss = 0
    correct = 0
    writer = csv.writer(file)
    writer.writerow(["image_name", "image_name_copy", "image_complexity"])
    for i, (data, target) in enumerate(test_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # print("Train Size: ")
        # print(SI(data).size())
        # print("Train tensor: ")
        si_score = SI(data).tolist()
        tar = target.tolist()
        # print(si_score)
        for idx in range(len(si_score)):
            writer.writerow([str(i) + "_" + str(idx), tar[idx], si_score[idx]])
Example #6
def train(epoch, file):
    train_loss = 0.0
    model.train()
    loss = 0
    global history_score
    avg_loss = 0.
    train_acc = 0.
    writer = csv.writer(file)
    writer.writerow(["image_name", "image_name_copy", "image_complexity"])
    for i, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # print("Train Size: ")
        # print(SI(data).size())
        # print("Train tensor: ")
        si_score = SI(data).tolist()
        tar = target.tolist()
        # print(si_score)
        for idx in range(len(si_score)):
            writer.writerow([str(i) + "_" + str(idx), tar[idx], si_score[idx]])
Example #7
def train(epoch):
    train_loss = 0.0
    model.train()
    global history_score
    avg_loss = 0.
    train_acc = 0.
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        complexity = SI(data)
        projected_complexity = project_IC_func(complexity)
        # cluster the batch into three groups by projected complexity
        indices_list, mean_list = cluster_func(projected_complexity)
        [indices_simple, indices_middle, indices_hard] = indices_list
        [mean_simple, mean_middle, mean_hard] = mean_list
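        # each group gets its own (downsampling size, keep ratio, loss scale),
        # derived from that group's mean projected complexity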
        downsampling_size_simple, keep_ratio_simple, loss_scale_simple = project_hyperparam(mean_simple)
        downsampling_size_middle, keep_ratio_middle, loss_scale_middle = project_hyperparam(mean_middle)
        downsampling_size_hard, keep_ratio_hard, loss_scale_hard = project_hyperparam(mean_hard)

        images_simple = torch.index_select(data, 0, indices_simple)
        # print("images_simple define")
        # print(images_simple)
        images_middle = torch.index_select(data, 0, indices_middle)
        images_hard = torch.index_select(data, 0, indices_hard)

        target_simple = torch.index_select(target, 0, indices_simple)
        target_middle = torch.index_select(target, 0, indices_middle)
        target_hard = torch.index_select(target, 0, indices_hard)

        weight_dict = {"simple": loss_scale_simple, "middle": loss_scale_middle, "hard": loss_scale_hard}

        keep_prob_dict = {"simple": keep_ratio_simple, "middle": keep_ratio_middle, "hard": keep_ratio_hard}

        downsample_dict = {"simple": downsampling_size_simple, "middle": downsampling_size_middle,
                           "hard": downsampling_size_hard}

        # subsample each group with its keep probability; a group whose sampled
        # size rounds to zero is replaced by an empty placeholder tensor
        if images_simple.shape[0] > 0:
            image_pick_prob_tensor = torch.ones(images_simple.shape[0]).cuda()
            images_simple_keep_num = int(round(images_simple.shape[0] * keep_prob_dict["simple"]))
            if images_simple_keep_num:
                indices = torch.multinomial(image_pick_prob_tensor, images_simple_keep_num)
                images_simple = torch.index_select(images_simple, 0, indices)
                target_simple = torch.index_select(target_simple, 0, indices)
            else:
                images_simple = torch.zeros(0, 3, crop_size, crop_size).cuda()
                target_simple = torch.zeros(0, crop_size, crop_size).cuda()

        if images_middle.shape[0] > 0:
            image_pick_prob_tensor = torch.ones(images_middle.shape[0]).cuda()
            images_middle_keep_num = int(round(images_middle.shape[0] * keep_prob_dict["middle"]))
            if images_middle_keep_num:
                indices = torch.multinomial(image_pick_prob_tensor, images_middle_keep_num)
                images_middle = torch.index_select(images_middle, 0, indices)
                target_middle = torch.index_select(target_middle, 0, indices)
            else:
                images_middle = torch.zeros(0, 3, crop_size, crop_size).cuda()
                target_middle = torch.zeros(0, crop_size, crop_size).cuda()

        if images_hard.shape[0] > 0:
            image_pick_prob_tensor = torch.ones(images_hard.shape[0]).cuda()
            images_hard_keep_num = int(round(images_hard.shape[0] * keep_prob_dict["hard"]))
            if images_hard_keep_num:
                indices = torch.multinomial(image_pick_prob_tensor, images_hard_keep_num)
                images_hard = torch.index_select(images_hard, 0, indices)
                target_hard = torch.index_select(target_hard, 0, indices)
            else:
                images_hard = torch.zeros(0, 3, crop_size, crop_size).cuda()
                target_hard = torch.zeros(0, crop_size, crop_size).cuda()

        # if a group contains exactly one sample, drop it: BatchNorm requires
        # more than one sample per batch in training mode
        if images_simple.shape[0] == 1:
            images_simple = torch.zeros(0, 3, crop_size, crop_size).cuda()
            target_simple = torch.zeros(0, crop_size, crop_size).cuda()
        if images_middle.shape[0] == 1:
            images_middle = torch.zeros(0, 3, crop_size, crop_size).cuda()
            target_middle = torch.zeros(0, crop_size, crop_size).cuda()
        if images_hard.shape[0] == 1:
            images_hard = torch.zeros(0, 3, crop_size, crop_size).cuda()
            target_hard = torch.zeros(0, crop_size, crop_size).cuda()

        # images_reorder = torch.cat((images_simple, images_middle, images_hard), 0)
        # target_reorder = torch.cat((target_simple, target_middle, target_hard), 0)

        run_flag_simple = False
        run_flag_middle = False
        run_flag_hard = False
        # print("image_simple before")
        # print(images_simple)
        if (images_simple.shape[0]):
            images_simple = torch.nn.Upsample(size=(downsample_dict["simple"], downsample_dict["simple"]),
                                              mode='nearest')(images_simple)
            # print(downsample_dict["simple"])
            # print("image simple after upsample")
            # print(images_simple)
            run_flag_simple = True
        if images_middle.shape[0] > 0:
            images_middle = torch.nn.Upsample(size=(downsample_dict["middle"], downsample_dict["middle"]),
                                              mode='nearest')(images_middle)
            run_flag_middle = True
        if images_hard.shape[0] > 0:
            images_hard = torch.nn.Upsample(size=(downsample_dict["hard"], downsample_dict["hard"]), mode='nearest')(
                images_hard)
            run_flag_hard = True

        optimizer.zero_grad()
        if run_flag_simple:
            output_simple = model(images_simple)
            # output_simple = torch.nn.Upsample(size=(crop_size, crop_size), mode='nearest')(
            #     output_simple)
            loss_simple = F.cross_entropy(output_simple, target_simple)
        else:
            # output_simple = torch.zeros(0, nclass, crop_size, crop_size).cuda()
            loss_simple = torch.zeros(1).cuda()
        # weight each group's loss by its loss scale times the number of kept samples
        if weight_dict["simple"] is not None:
            weight_simple = weight_dict["simple"] * images_simple.shape[0]
        else:
            weight_simple = 0.0
        weighted_loss_simple = weight_simple * loss_simple

        if run_flag_middle:
            output_middle = model(images_middle)
            # output_middle = torch.nn.Upsample(size=(crop_size, crop_size), mode='nearest')(
            #     output_middle)
            loss_middle = F.cross_entropy(output_middle, target_middle)
        else:
            # output_middle = torch.zeros(0, nclass, crop_size, crop_size).cuda()
            loss_middle = torch.zeros(1).cuda()
        if weight_dict["middle"] is not None:
            weight_middle = weight_dict["middle"] * images_middle.shape[0]
        else:
            weight_middle = 0.0
        weighted_loss_middle = weight_middle * loss_middle

        if run_flag_hard:
            output_hard = model(images_hard)
            # output_hard = torch.nn.Upsample(size=(crop_size, crop_size), mode='nearest')(
            #     output_hard)
            loss_hard = F.cross_entropy(output_hard, target_hard)
        else:
            # output_hard = torch.zeros(0, nclass, crop_size, crop_size).cuda()
            loss_hard = torch.zeros(1).cuda()
        if weight_dict["hard"] is not None:
            weight_hard = weight_dict["hard"] * images_hard.shape[0]
        else:
            weight_hard = 0.0
        weighted_loss_hard = weight_hard * loss_hard

        # normalized weighted sum of the three group losses; assumes at least
        # one group is non-empty, otherwise the denominator is zero
        loss = (weighted_loss_simple + weighted_loss_middle + weighted_loss_hard) / (
                weight_simple + weight_middle + weight_hard)
        train_loss += loss.item()
        loss.backward()

        if args.sr:
            updateBN()
        optimizer.step()
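Example #7 additionally relies on cluster_func, which splits the batch's projected complexities into a simple, a middle, and a hard group. Its definition is not shown above; the following is a minimal sketch assuming a plain tercile split over the sorted scores (the original could equally use k-means or fixed thresholds):

import torch

def cluster_func(projected_complexity):
    # Hypothetical stand-in: three equal terciles of the sorted scores,
    # lowest complexity first ("simple"), highest last ("hard").
    n = projected_complexity.numel()
    order = torch.argsort(projected_complexity)
    thirds = [order[:n // 3], order[n // 3:2 * n // 3], order[2 * n // 3:]]
    indices_list = [t.cuda() for t in thirds]  # on GPU, since data is on GPU
    mean_list = [projected_complexity[t].mean() if t.numel() else torch.zeros(())
                 for t in thirds]
    return indices_list, mean_list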