Example #1
    def train(self, batch):
        T = batch["s"].size(0)
        self.optimizer.zero_grad()

        # Bootstrap the return from the value estimate at the last step;
        # terminal states contribute no future reward.
        bht = batch.hist(T - 1)
        R = self.decide(bht["s"])["V"].squeeze()
        for i, terminal in enumerate(bht["terminal"]):
            if terminal:
                R[i] = 0.0

        err = None
        verr = None
        perr = None

        # Walk backward in time, accumulating the discounted return.
        for t in range(T - 2, -1, -1):
            bht = batch.hist(t)
            state = self.forward(bht["s"])

            r = batch["r"][t]
            R = self.gamma * R + r
            for i, terminal in enumerate(bht["terminal"]):
                if terminal:
                    R[i] = 0.0

            V = state["V"].squeeze()

            # Advantage estimate R - V, wrapped in a fresh Variable so the
            # policy term does not backpropagate into the value head.
            coef = Variable(R - V.data)
            pi = state["pi"]
            a = bht["a"]

            # A small epsilon keeps the log finite when pi has zero entries.
            log_pi = (pi + 1e-6).log()

            # Scale the log-probabilities by the advantage. The backward hook
            # below is an alternative that only works on PyTorch 0.2.0, so the
            # explicit multiply is used instead.
            def bw_hook(grad_in):
                return grad_in.mul(coef.view(-1, 1))

            # log_pi.register_hook(bw_hook)
            log_pi = log_pi.mul(coef.view(-1, 1))

            # Policy-gradient loss (advantage-weighted NLL) and value regression.
            nlll = nn.NLLLoss()(log_pi, Variable(a))
            mse = nn.MSELoss()(V, Variable(R))

            verr = add_err(verr, mse.data[0])
            perr = add_err(perr, nlll.data[0])
            err = add_err(err, mse)
            err = add_err(err, nlll)

        self.counter.stats["vcost"].feed(verr / (T - 1))
        self.counter.stats["pcost"].feed(perr / (T - 1))
        self.counter.stats["cost"].feed(err.data[0] / (T - 1))

        err.backward()
        self.optimizer.step()
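Every example on this page accumulates its losses through an add_err helper that is not shown here. Judging only from the call sites (it is seeded with None and later divided or fed to stats), a minimal sketch of what such a helper plausibly does is the following; the body is an assumption, not the repository's actual implementation:

    def add_err(total, err):
        # Start the accumulator on first use, otherwise keep summing.
        # Works for plain floats as well as loss Variables/tensors.
        if total is None:
            return err
        return total + err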
Example #2
    def update(self, mi, batch, stats):
        m = mi["model"]
        args = self.args

        T = batch["a"].size(0)
        total_predict_err = None

        hs = []

        for t in range(0, T - 1):
            # forwarded policy should be identical with current policy
            bht = batch.hist(t)
            state_curr = m(bht)
            # prev_a is only consumed below when t > 0 (the inner loop over i
            # is empty at t == 0).
            if t > 0:
                prev_a = batch["a"][t - 1]
            # Detach the current hidden state; it serves only as a target.
            h = state_curr["h"].data

            # Roll every earlier hidden state one step forward through the
            # transition model and penalize its distance to the detached h.
            for i in range(0, t):
                future_pred = m.transition(hs[i], prev_a)
                pred_h = future_pred["hf"]
                this_err = self.prediction_loss(pred_h, Variable(h))
                total_predict_err = add_err(total_predict_err, this_err)
                hs[i] = pred_h

            # Mask gradients through steps where the episode terminated. Bind
            # `term` as a default argument so each hook keeps the value from its
            # own iteration; a bare closure would see only the last one.
            term = Variable(1.0 - batch["terminal"][t].float()).view(-1, 1)
            for _h in hs:
                _h.register_hook(lambda grad, term=term: grad.mul(term))

            hs.append(Variable(h))

        stats["predict_err"].feed(total_predict_err.item())
        total_predict_err.backward()
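The default-argument binding used above matters because Python closures capture variables by name, not by value; hooks registered in earlier iterations would otherwise all see the final value of term when backward() eventually fires. A small standalone illustration of the difference, independent of these examples and assuming only a recent PyTorch:

    import torch

    x = torch.ones(3, requires_grad=True)
    seen = []
    for scale in (0.0, 1.0):
        # Binding scale as a default argument freezes its current value.
        # With a bare `lambda g: seen.append(g * scale)` both hooks would
        # use the final value (1.0) when the backward pass runs.
        x.register_hook(lambda g, scale=scale: seen.append(g * scale))

    x.sum().backward()
    # seen[0] is all zeros, seen[1] is all ones.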
Example #3
    def update(self, mi, batch, stats):
        m = mi["model"]
        args = self.args

        T = batch["a"].size(0)
        total_predict_err = None

        hs = []

        for t in range(0, T - 1):
            # forwarded policy should be identical with current policy
            bht = batch.hist(t)
            state_curr = m(bht)
            if t > 0:
                prev_a = batch["a"][t - 1]
            h = state_curr["h"].data

            for i in range(0, t):
                future_pred = m.transition(hs[i], prev_a)
                pred_h = future_pred["hf"]
                this_err = self.prediction_loss(pred_h, Variable(h))
                total_predict_err = add_err(total_predict_err, this_err)
                hs[i] = pred_h

            # Mask gradients through steps where the episode terminated; bind
            # `term` as a default argument so each hook keeps its own value.
            term = Variable(1.0 - batch["terminal"][t].float()).view(-1, 1)
            for _h in hs:
                _h.register_hook(lambda grad, term=term: grad.mul(term))

            hs.append(Variable(h))

        stats["predict_err"].feed(total_predict_err.data[0])
        total_predict_err.backward()
Example #4
    def update(self, mi, batch, stats):
        ''' Update given batch '''
        # Current timestep.
        state_curr = mi["model"](batch.hist(0))
        total_loss = None
        eps = 1e-6
        targets = batch.hist(0)["a"]
        for i, pred in enumerate(state_curr["a"]):
            # Report top-1 / top-5 accuracy for the first prediction head only.
            if i == 0:
                prec1, prec5 = self.accuracy(pred.data, targets[:, i].contiguous(), topk=(1, 5))
                stats["top1_acc"].feed(prec1[0])
                stats["top5_acc"].feed(prec5[0])
            # Accumulate the per-head policy loss; later heads are down-weighted
            # by 1 / (i + 1).
            loss = self.policy_loss((pred + eps).log(), Variable(targets[:, i]))
            stats["loss" + str(i)].feed(loss.data[0])
            total_loss = add_err(total_loss, loss / (i + 1))

        stats["total_loss"].feed(total_loss.data[0])
        total_loss.backward()
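The accuracy / topk_accuracy helpers used here and in the following examples are not shown on this page. A plausible sketch in the usual top-k precision style (the exact signature and return convention are inferred from the call sites, not confirmed by the source):

    import torch

    def topk_accuracy(output, target, topk=(1,)):
        # Percentage of samples whose true label appears among the top-k
        # scores, one entry per requested k.
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res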
Example #5
    def update(self, mi, batch, stats):
        ''' Update given batch '''
        # Current timestep.
        
        # print("\u001b[31;1m|py|\u001b[0m\u001b[37m", "MultiplePrediction::", inspect.currentframe().f_code.co_name)

        state_curr = mi["model"](batch)
        total_policy_loss = None
        eps = 1e-6
        targets = batch["offline_a"]
        if "pis" not in state_curr:
            state_curr["pis"] = [state_curr["pi"]]

        for i, pred in enumerate(state_curr["pis"]):
            if i == 0:
                prec1, prec5 = topk_accuracy(
                    pred.data, targets[:, i].contiguous(), topk=(1, 5))
                stats["top1_acc"].feed(prec1[0])
                stats["top5_acc"].feed(prec5[0])

            # Accumulate the per-head policy loss; later heads are down-weighted
            # by 1 / (i + 1).
            loss = self.policy_loss(
                (pred + eps).log(), Variable(targets[:, i]))
            stats["loss" + str(i)].feed(loss.data[0])
            total_policy_loss = add_err(total_policy_loss, loss / (i + 1))

        # Optional value head trained against the recorded game outcome.
        total_value_loss = None
        if "V" in state_curr and "winner" in batch:
            total_value_loss = self.value_loss(
                state_curr["V"], Variable(batch["winner"]))

        stats["total_policy_loss"].feed(total_policy_loss.data[0])
        if total_value_loss is not None:
            stats["total_value_loss"].feed(total_value_loss.data[0])
            total_loss = total_policy_loss + total_value_loss
        else:
            total_loss = total_policy_loss

        stats["total_loss"].feed(total_loss.data[0])
        if self.options.multipred_backprop:
            total_loss.backward()
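The policy_loss and value_loss members are never defined in these snippets; they are fed log-probabilities with integer targets and a value prediction with a scalar outcome, respectively. A plausible setup consistent with those call sites (an assumption, not the repository's actual choice) is NLL for the policy and MSE for the value:

    import torch
    import torch.nn as nn

    policy_loss = nn.NLLLoss()   # consumes log-probabilities and integer action targets
    value_loss = nn.MSELoss()    # regresses the value head toward the recorded winner

    # Tiny self-contained usage: batch of 4 samples, 9 possible actions.
    logits = torch.randn(4, 9, requires_grad=True)
    log_pi = torch.log_softmax(logits, dim=1)
    actions = torch.randint(0, 9, (4,))
    V = torch.randn(4, requires_grad=True)
    winner = torch.empty(4).uniform_(-1, 1)

    total = policy_loss(log_pi, actions) + value_loss(V, winner)
    total.backward()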
Example #6
    def update(self, mi, batch, stats):
        ''' Update given batch '''
        # Current timestep.
        state_curr = mi["model"](batch.hist(0))
        total_loss = None
        eps = 1e-6
        targets = batch.hist(0)["offline_a"]
        for i, pred in enumerate(state_curr["pis"]):
            if i == 0:
                prec1, prec5 = topk_accuracy(pred.data, targets[:, i].contiguous(), topk=(1, 5))
                stats["top1_acc"].feed(prec1[0])
                stats["top5_acc"].feed(prec5[0])

            # Accumulate the per-head policy loss; later heads are down-weighted
            # by 1 / (i + 1).
            loss = self.policy_loss((pred + eps).log(), Variable(targets[:, i]))
            stats["loss" + str(i)].feed(loss.data[0])
            total_loss = add_err(total_loss, loss / (i + 1))

        stats["total_loss"].feed(total_loss.data[0])
        if not self.args.multipred_no_backprop:
            total_loss.backward()
Example #7
    def update(self, mi, batch, stats):
        ''' Update given batch '''
        # Current timestep.
        state_curr = mi["model"](batch)
        total_policy_loss = None
        eps = 1e-6
        targets = batch["offline_a"]
        if "pis" not in state_curr:
            state_curr["pis"] = [state_curr["pi"]]

        for i, pred in enumerate(state_curr["pis"]):
            if i == 0:
                prec1, prec5 = topk_accuracy(
                    pred.data, targets[:, i].contiguous(), topk=(1, 5))
                stats["top1_acc"].feed(prec1[0])
                stats["top5_acc"].feed(prec5[0])

            # Accumulate the per-head policy loss; later heads are down-weighted
            # by 1 / (i + 1).
            loss = self.policy_loss(
                (pred + eps).log(), Variable(targets[:, i]))
            stats["loss" + str(i)].feed(loss.data[0])
            total_policy_loss = add_err(total_policy_loss, loss / (i + 1))

        total_value_loss = None
        if "V" in state_curr and "winner" in batch:
            total_value_loss = self.value_loss(
                state_curr["V"], Variable(batch["winner"]))

        stats["total_policy_loss"].feed(total_policy_loss.data[0])
        if total_value_loss is not None:
            stats["total_value_loss"].feed(total_value_loss.data[0])
            total_loss = total_policy_loss + total_value_loss
        else:
            total_loss = total_policy_loss

        stats["total_loss"].feed(total_loss.data[0])
        if self.options.multipred_backprop:
            total_loss.backward()
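Unlike Example #1, these update() methods only call backward(); zeroing gradients and stepping the optimizer is left to whatever drives them. A hedged sketch of such a driving step (train_step, method, and mi are stand-in names, not confirmed by the source):

    def train_step(optimizer, method, mi, batch, stats):
        # The update() methods above only accumulate gradients via backward();
        # the caller owns zero_grad() and the optimizer step.
        optimizer.zero_grad()
        method.update(mi, batch, stats)
        optimizer.step()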
Example #8
    def update(self, mi, batch, stats):
        ''' Actor critic model '''
        m = mi["model"]
        args = self.args

        T = batch["a"].size(0)

        # Bootstrap the discounted return from the value estimate at the last
        # step, and remember its hidden state and policy.
        state_curr = m(batch.hist(T - 1))

        self.discounted_reward.setR(state_curr["V"].squeeze().data, stats)

        next_h = state_curr["h"].data
        policies = [0] * T
        policies[T - 1] = state_curr["pi"].data

        for t in range(T - 2, -1, -1):
            bht = batch.hist(t)
            state_curr = m.forward(bht)

            # go through the sample and get the rewards.
            a = batch["a"][t]
            V = state_curr["V"].squeeze()

            R = self.discounted_reward.feed(dict(
                r=batch["r"][t], terminal=batch["terminal"][t]),
                                            stats=stats)

            pi = state_curr["pi"]
            policies[t] = pi.data

            overall_err = None

            if not args.fixed_policy:
                # Standard actor-critic terms: a policy-gradient loss weighted
                # by the advantage (R - V) plus value regression toward R.
                overall_err = self.pg.feed(R - V.data,
                                           state_curr,
                                           bht,
                                           stats,
                                           old_pi_s=bht)
                overall_err += self.value_matcher.feed(dict(V=V, target=R),
                                                       stats)

            if args.h_smooth:
                # Train the transition model: predict the next hidden state
                # from the current hidden state and the taken action.
                curr_h = state_curr["h"]
                # Block the gradient from flowing back through h.
                curr_h = Variable(curr_h.data)
                future_pred = m.transition(curr_h, a)
                pred_h = future_pred["hf"]
                predict_err = self.prediction_loss(pred_h, Variable(next_h))
                overall_err = add_err(overall_err, predict_err)

                stats["predict_err"].feed(predict_err.data[0])

                if args.contrastive_V:
                    # Sample an action other than the current action.
                    prob = pi.data.clone().fill_(1 / (pi.size(1) - 1))
                    # Make the selected entry zero.
                    prob.scatter_(1, a.view(-1, 1), 0.0)
                    other_a = prob.multinomial(1)
                    other_future_pred = m.transition(curr_h, other_a)
                    other_pred_h = other_future_pred["hf"]

                    # Encourage the value predicted from the ground-truth
                    # action's successor state to rank above the one predicted
                    # from the alternative action. Gradients are stopped at the
                    # predicted hidden states.
                    pi_V = m.decision(pred_h.data)
                    pi_V_other = m.decision(other_pred_h.data)
                    all_one = R.clone().view(-1, 1).fill_(1.0)

                    rank_err = self.rank_loss(pi_V["V"], pi_V_other["V"],
                                              Variable(all_one))
                    value_err = self.prediction_loss(pi_V["V"], Variable(R))

                    stats["rank_err"].feed(rank_err.data[0])
                    stats["value_err"].feed(value_err.data[0])

                    overall_err = add_err(overall_err, rank_err)
                    overall_err = add_err(overall_err, value_err)

            if overall_err is not None:
                overall_err.backward()

            next_h = state_curr["h"].data

            if overall_err is not None:
                stats["cost"].feed(overall_err.data[0])
            #print("[%d]: reward=%.4f, sum_reward=%.2f, acc_reward=%.4f, value_err=%.4f, policy_err=%.4f" % (i, r.mean(), r.sum(), R.mean(), value_err.data[0], policy_err.data[0]))

        if args.h_match_policy or args.h_match_action:
            state_curr = m.forward(batch.hist(0))
            h = state_curr["h"]
            if args.fixed_policy:
                h = Variable(h.data)

            total_policy_err = None
            for t in range(0, T - 1):
                # forwarded policy should be identical with current policy
                V_pi = m.decision_fix_weight(h)
                a = batch["a"][t]
                pi_h = V_pi["pi"]

                # Nothing to learn when t = 0
                if t > 0:
                    if args.h_match_policy:
                        policy_err = self.policy_match_loss(
                            pi_h, Variable(policies[t]))
                        stats["policy_match_err%d" % t].feed(
                            policy_err.data[0])
                    elif args.h_match_action:
                        # Add normalization constant
                        logpi_h = (pi_h + args.min_prob).log()
                        policy_err = self.policy_max_action_loss(
                            logpi_h, Variable(a))
                        stats["policy_match_a_err%d" % t].feed(
                            policy_err.data[0])

                    total_policy_err = add_err(total_policy_err, policy_err)

                future_pred = m.transition(h, a)
                h = future_pred["hf"]

            total_policy_err.backward()
            stats["total_policy_match_err"].feed(total_policy_err.data[0])
Example #9
    def update(self, mi, batch, stats):
        ''' Actor critic model '''
        m = mi["model"]
        args = self.args

        T = batch["a"].size(0)

        state_curr = m(batch.hist(T - 1))

        self.discounted_reward.setR(state_curr["V"].squeeze().data, stats)

        next_h = state_curr["h"].data
        policies = [0] * T
        policies[T - 1] = state_curr["pi"].data

        for t in range(T - 2, -1, -1):
            bht = batch.hist(t)
            state_curr = m.forward(bht)

            # go through the sample and get the rewards.
            a = batch["a"][t]
            V = state_curr["V"].squeeze()

            R = self.discounted_reward.feed(
                dict(r=batch["r"][t], terminal=batch["terminal"][t]),
                stats=stats)

            pi = state_curr["pi"]
            policies[t] = pi.data

            overall_err = None

            if not args.fixed_policy:
                overall_err = self.pg.feed(R - V.data, state_curr, bht, stats, old_pi_s=bht)
                overall_err += self.value_matcher.feed(dict(V=V, target=R), stats)

            if args.h_smooth:
                curr_h = state_curr["h"]
                # Block gradient
                curr_h = Variable(curr_h.data)
                future_pred = m.transition(curr_h, a)
                pred_h = future_pred["hf"]
                predict_err = self.prediction_loss(pred_h, Variable(next_h))
                overall_err = add_err(overall_err, predict_err)

                stats["predict_err"].feed(predict_err.data[0])

                if args.contrastive_V:
                    # Sample an action other than the current action.
                    prob = pi.data.clone().fill_(1 / (pi.size(1) - 1))
                    # Make the selected entry zero.
                    prob.scatter_(1, a.view(-1, 1), 0.0)
                    other_a = prob.multinomial(1)
                    other_future_pred = m.transition(curr_h, other_a)
                    other_pred_h = other_future_pred["hf"]

                    # Encourage the value predicted from the ground-truth
                    # action's successor state to rank above the one predicted
                    # from the alternative action. Gradients are stopped at the
                    # predicted hidden states.
                    pi_V = m.decision(pred_h.data)
                    pi_V_other = m.decision(other_pred_h.data)
                    all_one = R.clone().view(-1, 1).fill_(1.0)

                    rank_err = self.rank_loss(pi_V["V"], pi_V_other["V"], Variable(all_one))
                    value_err = self.prediction_loss(pi_V["V"], Variable(R))

                    stats["rank_err"].feed(rank_err.data[0])
                    stats["value_err"].feed(value_err.data[0])

                    overall_err = add_err(overall_err, rank_err)
                    overall_err = add_err(overall_err, value_err)

            if overall_err is not None:
                overall_err.backward()

            next_h = state_curr["h"].data

            if overall_err is not None:
                stats["cost"].feed(overall_err.data[0])
            #print("[%d]: reward=%.4f, sum_reward=%.2f, acc_reward=%.4f, value_err=%.4f, policy_err=%.4f" % (i, r.mean(), r.sum(), R.mean(), value_err.data[0], policy_err.data[0]))

        if args.h_match_policy or args.h_match_action:
            state_curr = m.forward(batch.hist(0))
            h = state_curr["h"]
            if args.fixed_policy:
                h = Variable(h.data)

            total_policy_err = None
            for t in range(0, T - 1):
                # forwarded policy should be identical with current policy
                V_pi = m.decision_fix_weight(h)
                a = batch["a"][t]
                pi_h = V_pi["pi"]

                # Nothing to learn when t = 0
                if t > 0:
                    if args.h_match_policy:
                        policy_err = self.policy_match_loss(pi_h, Variable(policies[t]))
                        stats["policy_match_err%d" % t].feed(policy_err.data[0])
                    elif args.h_match_action:
                        # Add normalization constant
                        logpi_h = (pi_h + args.min_prob).log()
                        policy_err = self.policy_max_action_loss(logpi_h, Variable(a))
                        stats["policy_match_a_err%d" % t].feed(policy_err.data[0])

                    total_policy_err = add_err(total_policy_err, policy_err)

                future_pred = m.transition(h, a)
                h = future_pred["hf"]

            total_policy_err.backward()
            stats["total_policy_match_err"].feed(total_policy_err.data[0])