예제 #1
0
def _enlighten_and_reward(data, action, dist, scores, rewards, reward_history,
                          tbar):
    """Run one GAN prediction and record the reward for ``action``.

    Shared by the in-loop steps and the final step of ``train`` so both score,
    reward and update the running baseline identically.

    Args:
        data: current sample dict; ``data["mask"]`` is passed to ``get_score``.
        action: policy action being rewarded (passed to ``get_reward``).
        dist: policy distribution that produced ``action`` (progress text only).
        scores: list of scores so far; the new score is appended in place.
        rewards: list of per-step rewards; the new reward is appended in place.
        reward_history: indexable ``[baseline, decay]``; ``baseline`` is
            updated in place as an exponential moving average of the reward.
        tbar: tqdm progress bar whose description is refreshed.

    Returns:
        The ``visuals`` returned by ``gan.predict()``.
    """
    visuals, fake_B_tensor = gan.predict()
    scores.append(get_score(fake_B_tensor, data["mask"]))
    reward = get_reward(scores, action)
    # EMA baseline: baseline <- baseline * decay + reward * (1 - decay).
    reward_history[0] = (reward_history[0] * reward_history[1] +
                         reward * (1 - reward_history[1]))
    # Reward relative to the running baseline, scaled by gamma.
    rewards.append(gamma * (reward - reward_history[0]))
    tbar.set_description(
        'policy: %s, action: %s, reward: %.1f, reward_history: %.1f, #Enlighten: %d'
        % (str(dist.probs.data), str(action.data), reward, reward_history[0],
           len(scores) - 1))
    return visuals


def train(epoch, reward_history):
    """Train the enhancement policy for one epoch with policy-gradient updates.

    For every sample the policy repeatedly decides whether to run the GAN
    enhancer again (``action > 0``) until it stops or more than ``N_limit``
    scores exist. Each step, including the final stopping one, is rewarded
    via ``get_reward``, and ``sum(-log_prob(action) * reward)`` is
    back-propagated. Gradients accumulate and are applied every
    ``update_interval`` samples. Relies on module-level globals (``policy``,
    ``gan``, ``optimizer``, ``scheduler``, ``dataset``, ``writer``,
    ``visualizer``, ...).

    Args:
        epoch: current epoch index (used for logging steps and the checkpoint).
        reward_history: indexable ``[baseline, decay]``; ``baseline`` is
            updated in place as an exponential moving average of rewards.

    Returns:
        ``reward_history`` (the same object, updated).
    """
    optimizer.zero_grad()
    policy.train()
    # NOTE(review): stepping the scheduler before any optimizer.step() follows
    # the pre-1.1 PyTorch convention; newer versions warn about it. Kept as-is
    # so the LR schedule is unchanged.
    scheduler.step()
    tbar = tqdm(dataset)
    # Start counters at a small epsilon so TP/P and TN/N never divide by zero.
    P = TP = N = TN = 0.001
    for i, data in enumerate(tbar):
        policy.init_hidden(batch_size=1)
        # data["A"] is (b, 3, h, w); data["A_gray"] is (b, 1, h, w).
        rewards = []
        dists = []
        actions = []
        scores = [get_score(data["A"], data["mask"])]

        action, dist = policy(data["A"].cuda())
        dists.append(dist)
        actions.append(action)

        gan.set_input(data)
        # Keep enhancing while the policy asks for it, capped by N_limit.
        while action > 0 and len(scores) <= N_limit:
            visuals = _enlighten_and_reward(data, action, dist, scores,
                                            rewards, reward_history, tbar)
            # Feed the enhanced image back in as the next input.
            data = image_cycle(data, visuals)
            gan.set_input(data)
            action, dist = policy(data["A"].cuda())
            dists.append(dist)
            actions.append(action)
        # One more GAN pass so the final (stopping) action is rewarded too.
        visuals = _enlighten_and_reward(data, action, dist, scores, rewards,
                                        reward_history, tbar)

        # REINFORCE-style loss plus positive/negative action bookkeeping.
        loss = 0
        for step_dist, step_action, step_reward in zip(dists, actions, rewards):
            loss += -step_dist.log_prob(step_action) * step_reward
            if step_reward > 0:
                P += 1
                if step_action > 0:
                    TP += 1
            elif step_reward < 0:
                N += 1
                if step_action < 1:
                    TN += 1
        torch.autograd.backward(loss)

        # Apply the accumulated gradients every `update_interval` samples.
        # (No `i > 0` guard: it only mattered for update_interval == 1, where
        # it wrongly skipped the very first update.)
        # NOTE(review): gradients of a trailing partial window are discarded
        # by the zero_grad() at the start of the next epoch.
        if (i + 1) % update_interval == 0:
            optimizer.step()
            optimizer.zero_grad()
            writer.add_scalars(
                'IoU', {
                    'loss': loss.data.detach().cpu().numpy(),
                    'TP': 1. * TP / P,
                    'TN': 1. * TN / N
                },
                epoch * len(tbar) + i)
            P = TP = N = TN = 0.001

        img_path = gan.get_image_paths()
        visualizer.save_images(webpage, visuals, img_path)

    utils_seg.save_checkpoint(
        {
            'epoch': epoch + 1,
            'state_dict': policy.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        opt_seg,
        True)
    return reward_history
예제 #2
0
    def validation(self, epoch=None):
        """Evaluate the segmentation model on GAN-enhanced validation images.

        Each batch is first passed through ``gan.netG_A`` (together with a
        grayscale map computed from the input), and the clamped output is
        segmented by the model. Pixel accuracy and mIoU are accumulated over
        the whole loader, shown on the progress bar, logged and written to
        TensorBoard. Per-batch intersection/union lists are dumped as JSON to
        the files ``name2inter`` and ``name2union``. When ``epoch`` is given
        and ``(pixAcc + mIoU) / 2`` beats ``self.best_pred``, a checkpoint is
        saved.

        Args:
            epoch: current epoch index, or ``None`` to skip checkpointing.
        """
        def eval_batch(model, image, target):
            # Inverted luminance map; assumes ``image`` is normalised to
            # [-1, 1] so that +1 maps it to [0, 2] -- TODO confirm.
            r = image[:, 0, :, :] + 1
            g = image[:, 1, :, :] + 1
            b = image[:, 2, :, :] + 1
            gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
            gray = gray.unsqueeze(1)  # (b, h, w) -> (b, 1, h, w)
            # Autograd is already disabled by the caller (torch.no_grad() on
            # torch >= 0.4, a volatile Variable on 0.3). No extra guard is
            # needed here, and torch.no_grad does not exist on 0.3.
            fake_B, _, _ = gan.netG_A.forward(image, gray)
            outputs = model(fake_B.clamp(-1, 1))

            pred = outputs[0]
            # Resize logits to the label resolution before scoring.
            pred = F.upsample(pred, size=(target.size(1), target.size(2)),
                              mode='bilinear')

            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        # Defaults so logging/checkpointing still work if the loader is empty.
        pixAcc, mIoU = 0.0, 0.0
        name2inter = {}
        name2union = {}
        tbar = tqdm(self.valloader, desc='\r')

        for image, target, name, _ in tbar:
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero.
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            # NOTE(review): keyed by the first name in the batch only; with
            # batch size > 1 the values are presumably batch totals.
            name2inter[name[0]] = inter.tolist()
            name2union[name[0]] = union.tolist()

            tbar.set_description('pixAcc: %.2f, mIoU: %.2f' % (pixAcc, mIoU))
        self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        self.writer.add_scalars('IoU', {'validation iou': mIoU}, epoch)
        with open("name2inter", 'w') as fp:
            json.dump(name2inter, fp)
        with open("name2union", 'w') as fp:
            json.dump(name2union, fp)

        if epoch is not None:
            new_pred = (pixAcc + mIoU) / 2
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                utils.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
예제 #3
0
    def validation(self, epoch=None):
        """Evaluate the segmentation model, tiling very large images.

        Images with at most 2,250,000 pixels (1500x1500) are segmented in one
        forward pass, and the logits are upsampled to the label size. Larger
        images are cut into ``size_p`` patches with ``global2patch``,
        segmented ``sub_batch_size`` patches at a time, and merged back with
        ``patch2global``. Pixel accuracy and mIoU are accumulated over the
        loader, logged and written to TensorBoard. Per-batch
        intersection/union lists are dumped as JSON to ``name2inter`` and
        ``name2union``. When ``epoch`` is given and ``(pixAcc + mIoU) / 2``
        beats ``self.best_pred``, a checkpoint is saved.

        Args:
            epoch: current epoch index, or ``None`` to skip checkpointing.
        """
        size_p = (1000, 1000)  # patch size for the tiled path
        sub_batch_size = 5  # patches per forward pass in the tiled path

        def eval_batch(model, image, target):
            if image.size(2) * image.size(3) <= 2250000:  # 1500x1500
                outputs = model(image)
                pred = outputs[0]
                # Upsample in case the input was downsampled due to its size.
                pred = F.upsample(
                    pred,
                    size=(target.size(1), target.size(2)),
                    mode='bilinear')
                correct, labeled = utils.batch_pix_accuracy(pred.data, target)
                inter, union = utils.batch_intersection_union(
                    pred.data, target, self.nclass)
                return correct, labeled, inter, union

            # Tiled path: predict patches per image, then merge them.
            patches, coordinates, sizes = global2patch(image, size_p)
            predicted_patches = [
                torch.zeros(len(coordinates[k]), self.nclass, size_p[0],
                            size_p[1]) for k in range(len(image))
            ]
            for k in range(len(image)):
                for j in range(0, len(coordinates[k]), sub_batch_size):
                    outputs = model(patches[k][j:j + sub_batch_size])[0]
                    predicted_patches[k][j:j + outputs.size()[0]] = outputs
            # merge softmax scores from patches (overlaps)
            pred = patch2global(predicted_patches, self.nclass, sizes,
                                coordinates, size_p)
            inter, union, correct, labeled = 0, 0, 0, 0
            for k in range(len(image)):
                correct_k, labeled_k = utils.batch_pix_accuracy(
                    pred[k].unsqueeze(0), target[k])
                inter_k, union_k = utils.batch_intersection_union(
                    pred[k].unsqueeze(0), target[k], self.nclass)
                correct += correct_k
                labeled += labeled_k
                inter += inter_k
                union += union_k
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        # Defaults so logging/checkpointing still work if the loader is empty.
        pixAcc, mIoU = 0.0, 0.0
        name2inter = {}
        name2union = {}
        tbar = tqdm(self.valloader, desc='\r')

        for image, target, name, _ in tbar:
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero.
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            # NOTE(review): keyed by the first name in the batch only.
            name2inter[name[0]] = inter.tolist()
            name2union[name[0]] = union.tolist()

            tbar.set_description('pixAcc: %.2f, mIoU: %.2f' % (pixAcc, mIoU))
        self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        self.writer.add_scalars('IoU', {'validation iou': mIoU}, epoch)
        with open("name2inter", 'w') as fp:
            json.dump(name2inter, fp)
        with open("name2union", 'w') as fp:
            json.dump(name2union, fp)

        torch.cuda.empty_cache()

        if epoch is not None:
            new_pred = (pixAcc + mIoU) / 2
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': self.model.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, self.args, is_best)
예제 #4
0
    def validation(self, epoch=None):
        """Evaluate the segmentation model on the validation loader.

        Pixel accuracy and mIoU are accumulated over the whole loader, shown
        on the progress bar and logged. Per-batch intersection/union lists
        are dumped as JSON to the files ``name2inter`` and ``name2union``.
        When ``epoch`` is given and ``(pixAcc + mIoU) / 2`` beats
        ``self.best_pred``, a checkpoint is saved.

        The auxiliary weather/time-of-day evaluation that used to live here
        had been fully commented out and has been removed.

        Args:
            epoch: current epoch index, or ``None`` to skip checkpointing.
        """
        def eval_batch(model, image, target):
            outputs = model(image)
            pred = outputs[0]
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        # Defaults so logging/checkpointing still work if the loader is empty.
        pixAcc, mIoU = 0.0, 0.0
        name2inter = {}
        name2union = {}
        tbar = tqdm(self.valloader, desc='\r')

        for image, target, name in tbar:
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            # np.spacing(1) guards against division by zero.
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            # NOTE(review): keyed by the first name in the batch only.
            name2inter[name[0]] = inter.tolist()
            name2union[name[0]] = union.tolist()

            tbar.set_description('pixAcc: %.2f, mIoU: %.2f' % (pixAcc, mIoU))
        self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        with open("name2inter", 'w') as fp:
            json.dump(name2inter, fp)
        with open("name2union", 'w') as fp:
            json.dump(name2union, fp)

        if epoch is not None:
            new_pred = (pixAcc + mIoU) / 2
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': self.model.module.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'best_pred': self.best_pred,
                    }, self.args, is_best)
def _rebalance_clip_bounds(total_dists, r_neg, r_pos, threshold=0.55,
                           floor=0.1):
    """Shift the reward clip bounds to counteract policy collapse.

    Takes the arg-max action of every recorded policy distribution. If the
    most frequent action accounts for at least ``threshold`` of them,
    ``(percent - threshold) / 2`` is moved between the bounds against that
    action:
    - If the dominant action enhances (``> 0``), ``r_neg`` grows and
      ``r_pos`` shrinks.
    - Otherwise ``r_pos`` grows and ``r_neg`` shrinks.

    The shrunk bound never drops below ``floor``.

    Args:
        total_dists: (N, num_actions) tensor of action probabilities.
        r_neg: current negative clip bound.
        r_pos: current positive clip bound.
        threshold: dominance fraction that triggers rebalancing.
        floor: minimum value of the bound being shrunk.

    Returns:
        The updated ``(r_neg, r_pos)``.
    """
    infer_actions = total_dists.argmax(1)
    mode_action = int(infer_actions.mode()[0].item())
    # .item() keeps the bounds as Python floats. Arithmetic with a CUDA tensor
    # here would turn r_neg/r_pos into CUDA tensors that np.clip cannot take.
    percent = ((infer_actions == mode_action).sum().float() /
               len(infer_actions)).item()
    print("policy collapse:", round(percent, 3), mode_action)
    print("r_pos:", r_pos, "r_neg:", r_neg, "==>")
    if percent >= threshold:
        delta = (percent - threshold) / 2.
        if mode_action > 0:
            r_neg += delta
            r_pos = max(r_pos - delta, floor)
        else:
            r_pos += delta
            r_neg = max(r_neg - delta, floor)
    print("r_pos:", r_pos, "r_neg:", r_neg)
    return r_neg, r_pos


def train(epoch, reward_history, r_neg, r_pos):
    """Train the enhancement policy for one epoch with adaptive reward clipping.

    For each batch the policy picks one action per sample from the RGB image,
    its Lab version, and the output of ``seg``. Under ``torch.no_grad()``
    every sample is then run through the GAN:

    * positive action: the image is enhanced ``action`` times, each output
      being fed back in with a recomputed grayscale map. The reward is the
      segmentation-score gain clipped to ``[-r_neg, r_pos]``.
    * otherwise: one enhancement pass is run only for scoring. The reward is
      the score drop it would have caused, clipped to ``[-r_pos, r_neg]``.

    Rewards are offset by ``reward_history[0]`` and scaled by ``gamma``. The
    policy loss is ``-log_prob(action) * reward`` minus an
    ``lambd_entropy``-weighted entropy bonus. Gradients are applied every
    ``update_interval`` batches. Every ``update_reward_interval`` batches the
    clip bounds are rebalanced if one action dominates.

    Args:
        epoch: current epoch index (logging step and checkpoint).
        reward_history: only ``reward_history[0]`` is read, as the baseline;
            it is not updated here.
        r_neg: current negative clip bound.
        r_pos: current positive clip bound.

    Returns:
        ``(reward_history, r_neg, r_pos)`` with possibly rebalanced bounds.
    """
    optimizer.zero_grad()
    policy.train()
    # NOTE(review): pre-1.1 PyTorch scheduler ordering; kept so the LR
    # schedule is unchanged.
    scheduler.step()
    tbar = tqdm(dataset)
    P = N = 0
    mIoU_gain = []
    niqe_gain = []
    rewards = torch.zeros(opt_gan.batchSize).cuda()
    # Action probabilities seen since the last rebalance; 4 = action count.
    total_dists = torch.empty(0, 4).cuda()
    for i, data in enumerate(tbar):
        # data["A"] is (b, 3, h, w); data["A_gray"] is (b, 1, h, w).
        batch = data["A"].size(0)
        # Reset per-sample rewards; the last batch may be smaller.
        rewards.resize_(batch).zero_()

        with torch.no_grad():
            seg_A = seg(data["A"].cuda())[0]
        actions, dists, _ = policy(data["A"].cuda(), data["A_Lab"].cuda(),
                                   seg_A)
        # Only used for argmax statistics: detach so autograd history is not
        # retained across batches.
        total_dists = torch.cat([total_dists, dists.probs.detach()], dim=0)

        with torch.no_grad():
            for j in range(batch):
                A = data["A"][j:j + 1]
                A_gray = data["A_gray"][j:j + 1]
                mask = data["mask"][j]
                score0 = get_score(A, mask)
                niqe0 = get_niqe(A)
                gan.set_input_A(A, A_gray)
                gan.set_input_A_origin(A)
                if actions[j] > 0:
                    # Enhance actions[j] times, feeding each output back in.
                    for _ in range(actions[j]):
                        _, fake_B_tensor = gan.predict()
                        A = fake_B_tensor.clamp(-1, 1)
                        r, g, b = (A[0:1, 0:1] + 1, A[0:1, 1:2] + 1,
                                   A[0:1, 2:3] + 1)
                        A_gray = 1. - (0.299 * r + 0.587 * g +
                                       0.114 * b) / 2.
                        gan.set_input_A(A, A_gray)
                    score1 = get_score(A, mask)
                    mIoU_gain.append(score1 - score0)
                    niqe1 = get_niqe(A)
                    niqe_gain.append(niqe1 - niqe0)
                    reward = np.clip(score1 - score0, -r_neg, r_pos)
                else:
                    # No enhancement chosen: reward by how much one pass
                    # would have lowered the score.
                    _, fake_B_tensor = gan.predict()
                    A = fake_B_tensor.clamp(-1, 1)
                    score1 = get_score(A, mask)
                    mIoU_gain.append(0)
                    niqe_gain.append(0)
                    reward = np.clip(score0 - score1, -r_pos, r_neg)
                rewards[j] = gamma * (reward - reward_history[0])
                tbar.set_description(
                    'policy: %s, action: %s, reward: %.1f, reward_history: %.1f'
                    % (str(torch.round(dists.probs.data[j] * 10**2) / 10**2),
                       str(actions.data[j]), rewards[j], reward_history[0]))

        # Vectorised loss over the batch:
        # sum_i [ -lambd_entropy * H(pi_i) - log pi_i(a_i) * r_i ]
        loss = (-lambd_entropy * dists.entropy() -
                dists.log_prob(actions) * rewards).sum()
        P += int((rewards > 0).sum())
        N += int((rewards <= 0).sum())
        torch.autograd.backward(loss)

        # No `i > 0` guard: it only skipped the first update when an
        # interval equals 1.
        if (i + 1) % update_interval == 0:
            optimizer.step()
            optimizer.zero_grad()
            step = epoch * len(tbar) + i
            writer.add_scalars('Loss',
                               {'loss': loss.data.detach().cpu().numpy()},
                               step)
            writer.add_scalars('Policy', {
                'P': 1. * P / (P + N),
                'N': 1. * N / (P + N)
            }, step)
            writer.add_scalars('mIoU', {"mIoU_gain": np.mean(mIoU_gain)},
                               step)
            writer.add_scalars('NIQE', {"NIQE_gain": np.mean(niqe_gain)},
                               step)
            mIoU_gain = []
            niqe_gain = []
            P = N = 0

        if (i + 1) % update_reward_interval == 0:
            r_neg, r_pos = _rebalance_clip_bounds(total_dists, r_neg, r_pos)
            total_dists = torch.empty(0, 4).cuda()

    utils_seg.save_checkpoint(
        {
            'epoch': epoch + 1,
            'state_dict': policy.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        opt_seg,
        True)
    return reward_history, r_neg, r_pos