Example #2
class BehavioralDistAgent(Agent):
    def __init__(self, load_dataset=True):

        super(BehavioralDistAgent, self).__init__()

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta,
                                                     self.data)
            self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta,
                                                    self.data)
            self.full_dataset = DemonstrationMemory("full", self.meta,
                                                    self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                           train=True)
            self.val_sampler = DemonstrationBatchSampler(self.val_dataset,
                                                         train=False)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset,
                                                          train=False)
            self.episodic_sampler = SequentialDemonstrationSampler(
                self.full_dataset)

            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_sampler=self.train_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_sampler=self.test_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.val_loader = torch.utils.data.DataLoader(
                self.val_dataset,
                batch_sampler=self.val_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)

            self.episodic_loader = torch.utils.data.DataLoader(
                self.full_dataset,
                sampler=self.episodic_sampler,
                batch_size=self.batch,
                num_workers=args.cpu_workers)

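        # distributional losses: either plain cross-entropy over the return
        # atoms or a Wasserstein metric (n=1) over the atom support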
        if not self.wasserstein:
            self.loss_fn_vs = torch.nn.CrossEntropyLoss(size_average=True)
            self.loss_fn_qs = torch.nn.CrossEntropyLoss(size_average=True)
            self.loss_fn_vl = torch.nn.CrossEntropyLoss(size_average=True)
            self.loss_fn_ql = torch.nn.CrossEntropyLoss(size_average=True)
        else:
            self.loss_fn_vs = wasserstein_metric(support=args.atoms_short, n=1)
            self.loss_fn_qs = wasserstein_metric(support=args.atoms_short, n=1)
            self.loss_fn_vl = wasserstein_metric(support=args.atoms_long, n=1)
            self.loss_fn_ql = wasserstein_metric(support=args.atoms_long, n=1)

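        # inverse-frequency action weights, clipped at 10, to compensate for
        # the imbalanced action histogram of the demonstrations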
        self.histogram = torch.from_numpy(self.meta['histogram']).float()
        m = self.histogram.max()
        self.histogram = m / self.histogram
        self.histogram = torch.clamp(self.histogram, 0, 10).cuda()

        self.loss_fn_beta = torch.nn.CrossEntropyLoss(size_average=True,
                                                      weight=self.histogram)
        self.loss_fn_pi_s = torch.nn.CrossEntropyLoss(reduce=False,
                                                      size_average=True)
        self.loss_fn_pi_l = torch.nn.CrossEntropyLoss(reduce=False,
                                                      size_average=True)
        self.loss_fn_pi_s_tau = torch.nn.CrossEntropyLoss(reduce=False,
                                                          size_average=True)
        self.loss_fn_pi_l_tau = torch.nn.CrossEntropyLoss(reduce=False,
                                                          size_average=True)

        # alpha weighted sum

        self.alpha_b = 1  # 1 / 0.7

        self.alpha_vs = 1  # 1 / 0.02
        self.alpha_qs = 1

        self.alpha_vl = 1  # 1 / 0.02
        self.alpha_ql = 1

        self.alpha_pi_s = 1  # 1 / 0.02
        self.alpha_pi_l = 1

        self.alpha_pi_s_tau = 1  # 1 / 0.02
        self.alpha_pi_l_tau = 1

        self.model = BehavioralDistNet()
        self.model.cuda()

        # configure learning

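        # group parameters by name: "rn_" marks the shared representation,
        # "on_*" marks the per-head output layers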
        net_parameters = [
            p[1] for p in self.model.named_parameters() if "rn_" in p[0]
        ]
        vl_params = [
            p[1] for p in self.model.named_parameters() if "on_vl" in p[0]
        ]
        ql_params = [
            p[1] for p in self.model.named_parameters() if "on_ql" in p[0]
        ]
        vs_params = [
            p[1] for p in self.model.named_parameters() if "on_vs" in p[0]
        ]
        qs_params = [
            p[1] for p in self.model.named_parameters() if "on_qs" in p[0]
        ]
        beta_params = [
            p[1] for p in self.model.named_parameters() if "on_beta" in p[0]
        ]

        pi_s_params = [
            p[1] for p in self.model.named_parameters() if "on_pi_s" in p[0]
        ]
        pi_l_params = [
            p[1] for p in self.model.named_parameters() if "on_pi_l" in p[0]
        ]
        pi_tau_s_params = [
            p[1] for p in self.model.named_parameters()
            if "on_pi_tau_s" in p[0]
        ]
        pi_tau_l_params = [
            p[1] for p in self.model.named_parameters()
            if "on_pi_tau_l" in p[0]
        ]

        self.parameters_group_a = net_parameters + vl_params + ql_params + vs_params + qs_params + beta_params
        self.parameters_group_b = pi_s_params + pi_l_params + pi_tau_s_params + pi_tau_l_params

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER
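        # one optimizer/scheduler pair per head; the value, Q and beta
        # optimizers also update the shared representation, while the policy
        # optimizers update only their own head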
        self.optimizer_vl = BehavioralDistAgent.set_optimizer(
            net_parameters + vl_params, args.lr_vl)
        self.scheduler_vl = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_vl, self.decay)

        self.optimizer_beta = BehavioralDistAgent.set_optimizer(
            net_parameters + beta_params, args.lr_beta)
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta, self.decay)

        self.optimizer_vs = BehavioralDistAgent.set_optimizer(
            net_parameters + vs_params, args.lr_vs)
        self.scheduler_vs = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_vs, self.decay)

        self.optimizer_qs = BehavioralDistAgent.set_optimizer(
            net_parameters + qs_params, args.lr_qs)
        self.scheduler_qs = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_qs, self.decay)

        self.optimizer_ql = BehavioralDistAgent.set_optimizer(
            net_parameters + ql_params, args.lr_ql)
        self.scheduler_ql = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_ql, self.decay)

        self.optimizer_pi_l = BehavioralDistAgent.set_optimizer(
            pi_l_params, args.lr_pi_l)
        self.scheduler_pi_l = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi_l, self.decay)

        self.optimizer_pi_s = BehavioralDistAgent.set_optimizer(
            pi_s_params, args.lr_pi_s)
        self.scheduler_pi_s = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi_s, self.decay)

        self.optimizer_pi_l_tau = BehavioralDistAgent.set_optimizer(
            pi_tau_l_params, args.lr_pi_tau_l)
        self.scheduler_pi_l_tau = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi_l_tau, self.decay)

        self.optimizer_pi_s_tau = BehavioralDistAgent.set_optimizer(
            pi_tau_s_params, args.lr_pi_tau_s)
        self.scheduler_pi_s_tau = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi_s_tau, self.decay)

        actions = torch.FloatTensor(consts.hotvec_matrix) / (3**(0.5))
        actions = Variable(actions, requires_grad=False).cuda()

        self.actions_matrix = actions.unsqueeze(0)
        self.reverse_excitation_index = consts.hotvec_inv

        self.short_bins = consts.short_bins[
            args.game][:-1] / self.meta['avg_score']
        # the long bins are already normalized
        self.long_bins = consts.long_bins[args.game][:-1]

        self.short_bins_torch = Variable(torch.from_numpy(
            consts.short_bins[args.game] / self.meta['avg_score']),
                                         requires_grad=False).cuda()
        self.long_bins_torch = Variable(torch.from_numpy(
            consts.long_bins[args.game]),
                                        requires_grad=False).cuda()

        self.batch_range = np.arange(self.batch)

        self.zero = Variable(torch.zeros(1))

    def flip_grad(self, parameters):
        for p in parameters:
            p.requires_grad = not p.requires_grad

    @staticmethod
    def individual_loss_fn_l2(argument):
        return abs(argument.data.cpu().numpy())**2

    @staticmethod
    def individual_loss_fn_l1(argument):
        return abs(argument.data.cpu().numpy())

    def save_checkpoint(self, path, aux=None):

        cpu_state = self.model.state_dict()
        for k in cpu_state:
            cpu_state[k] = cpu_state[k].cpu()

        state = {
            'state_dict': self.model.state_dict(),
            'state_dict_cpu': cpu_state,
            'optimizer_vl_dict': self.optimizer_vl.state_dict(),
            'optimizer_beta_dict': self.optimizer_beta.state_dict(),
            'optimizer_vs_dict': self.optimizer_vs.state_dict(),
            'optimizer_ql_dict': self.optimizer_ql.state_dict(),
            'optimizer_qs_dict': self.optimizer_qs.state_dict(),
            'optimizer_pi_s_dict': self.optimizer_pi_s.state_dict(),
            'optimizer_pi_l_dict': self.optimizer_pi_l.state_dict(),
            'optimizer_pi_s_tau_dict': self.optimizer_pi_s_tau.state_dict(),
            'optimizer_pi_l_tau_dict': self.optimizer_pi_l_tau.state_dict(),
            'aux': aux
        }

        torch.save(state, path)

    def one_hot(self, y, nb_digits):
        batch_size = y.shape[0]
        y_onehot = torch.zeros(batch_size, nb_digits)
        return y_onehot.scatter_(1, y.unsqueeze(1), 1)

    def load_checkpoint(self, path):

        if self.cuda:
            state = torch.load(path)
            self.model.load_state_dict(state['state_dict'])
        else:
            state = torch.load(path,
                               map_location=lambda storage, location: storage)
            self.model.load_state_dict(state['state_dict_cpu'])
        self.optimizer_vl.load_state_dict(state['optimizer_vl_dict'])
        self.optimizer_beta.load_state_dict(state['optimizer_beta_dict'])
        self.optimizer_vs.load_state_dict(state['optimizer_vs_dict'])
        self.optimizer_ql.load_state_dict(state['optimizer_ql_dict'])
        self.optimizer_qs.load_state_dict(state['optimizer_qs_dict'])
        self.optimizer_pi_s_tau.load_state_dict(
            state['optimizer_pi_s_tau_dict'])
        self.optimizer_pi_l_tau.load_state_dict(
            state['optimizer_pi_l_tau_dict'])
        self.optimizer_pi_s.load_state_dict(state['optimizer_pi_s_dict'])
        self.optimizer_pi_l.load_state_dict(state['optimizer_pi_l_dict'])

        return state['aux']

    def resume(self, model_path):

        aux = self.load_checkpoint(model_path)
        # self.update_target()
        return aux

    def update_target(self):
        self.target.load_state_dict(self.model.state_dict())

    def dummy_episodic_evaluator(self):
        while True:
            yield {
                'q_diff': torch.zeros(100),
                'a_agent': torch.zeros(100, self.action_space),
                'a_player': torch.zeros(100).long()
            }

    def _episodic_evaluator(self):
        pass

    def get_weighted_loss(self, x, bins):
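        # for each row of x (a distribution over bins), return the bin value
        # at which the cumulative probability crosses self.quantile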
        xd = x.data
        inds = xd.cumsum(1) <= self.quantile
        inds = inds.sum(1).long()
        return bins[inds]

    def learn(self, n_interval, n_tot):

        self.model.train()
        # self.target.eval()
        results = {
            'n': [],
            'loss_vs': [],
            'loss_b': [],
            'loss_vl': [],
            'loss_qs': [],
            'loss_ql': [],
            'loss_pi_s': [],
            'loss_pi_l': [],
            'loss_pi_s_tau': [],
            'loss_pi_l_tau': []
        }

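        # toggle requires_grad on the policy heads (group b) so that training
        # starts with group a (representation + value/Q/beta heads)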
        self.flip_grad(self.parameters_group_b)
        train_net = True

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda(), requires_grad=False)

            a_index = Variable(sample['a_index'].cuda(non_blocking=True),
                               requires_grad=False)

            rl = np.digitize(sample['score'].numpy(),
                             self.long_bins,
                             right=True)
            rs = np.digitize(sample['f'].numpy(), self.short_bins, right=True)

            Rl = Variable(sample['score'].cuda(), requires_grad=False)
            Rs = Variable(sample['f'].cuda(), requires_grad=False)

            if self.wasserstein:
                rl = Variable(self.one_hot(torch.LongTensor(rl),
                                           self.atoms_long).cuda(),
                              requires_grad=False)
                rs = Variable(self.one_hot(torch.LongTensor(rs),
                                           self.atoms_short).cuda(),
                              requires_grad=False)
            else:
                rl = Variable(torch.from_numpy(rl).cuda(), requires_grad=False)
                rs = Variable(torch.from_numpy(rs).cuda(), requires_grad=False)

            vs, vl, beta, qs, ql, phi, pi_s, pi_l, pi_s_tau, pi_l_tau = self.model(
                s, a)

            # policy learning

            if self.alpha_vs and train_net:
                loss_vs = self.alpha_vs * self.loss_fn_vs(vs, rs)
                self.optimizer_vs.zero_grad()
                loss_vs.backward(retain_graph=True)
                self.optimizer_vs.step()
            else:
                loss_vs = self.zero

            if self.alpha_vl and train_net:
                loss_vl = self.alpha_vl * self.loss_fn_vl(vl, rl)
                self.optimizer_vl.zero_grad()
                loss_vl.backward(retain_graph=True)
                self.optimizer_vl.step()
            else:
                loss_vl = self.zero

            if self.alpha_b and train_net:
                loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)
                self.optimizer_beta.zero_grad()
                loss_b.backward(retain_graph=True)
                self.optimizer_beta.step()
            else:
                loss_b = self.zero

            if self.alpha_qs and train_net:
                loss_qs = self.alpha_qs * self.loss_fn_qs(qs, rs)
                self.optimizer_qs.zero_grad()
                loss_qs.backward(retain_graph=True)
                self.optimizer_qs.step()
            else:
                loss_qs = self.zero

            if self.alpha_ql and train_net:
                loss_ql = self.alpha_ql * self.loss_fn_ql(ql, rl)
                self.optimizer_ql.zero_grad()
                loss_ql.backward(retain_graph=True)
                self.optimizer_ql.step()
            else:
                loss_ql = self.zero

            a_index_np = sample['a_index'].numpy()
            self.batch_range = np.arange(self.batch)

            beta_sfm = F.softmax(beta, 1)
            pi_s_sfm = F.softmax(pi_s, 1)
            pi_l_sfm = F.softmax(pi_l, 1)
            pi_s_tau_sfm = F.softmax(pi_s_tau, 1)
            pi_l_tau_sfm = F.softmax(pi_l_tau, 1)

            beta_fix = Variable(beta_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_fix = Variable(pi_s_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_l_fix = Variable(pi_l_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_tau_fix = Variable(pi_s_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)
            pi_l_tau_fix = Variable(pi_l_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)

            if self.alpha_pi_s and not train_net:
                loss_pi_s = self.alpha_pi_s * self.loss_fn_pi_s(pi_s, a_index)
                loss_pi_s = (loss_pi_s * Rs *
                             self.off_factor(pi_s_fix, beta_fix)).mean()
                self.optimizer_pi_s.zero_grad()
                loss_pi_s.backward(retain_graph=True)
                self.optimizer_pi_s.step()
            else:
                loss_pi_s = self.zero

            if self.alpha_pi_l and not train_net:
                loss_pi_l = self.alpha_pi_l * self.loss_fn_pi_l(pi_l, a_index)
                loss_pi_l = (loss_pi_l * Rl *
                             self.off_factor(pi_l_fix, beta_fix)).mean()
                self.optimizer_pi_l.zero_grad()
                loss_pi_l.backward(retain_graph=True)
                self.optimizer_pi_l.step()
            else:
                loss_pi_l = self.zero

            if self.alpha_pi_s_tau and not train_net:
                loss_pi_s_tau = self.alpha_pi_s_tau * self.loss_fn_pi_s_tau(
                    pi_s_tau, a_index)
                w = self.get_weighted_loss(F.softmax(qs, 1),
                                           self.short_bins_torch)
                loss_pi_s_tau = (
                    loss_pi_s_tau * w *
                    self.off_factor(pi_s_tau_fix, beta_fix)).mean()
                self.optimizer_pi_s_tau.zero_grad()
                loss_pi_s_tau.backward(retain_graph=True)
                self.optimizer_pi_s_tau.step()
            else:
                loss_pi_s_tau = self.zero

            if self.alpha_pi_l_tau and not train_net:
                loss_pi_l_tau = self.alpha_pi_l_tau * self.loss_fn_pi_l_tau(
                    pi_l_tau, a_index)
                w = self.get_weighted_loss(F.softmax(ql, 1),
                                           self.long_bins_torch)
                loss_pi_l_tau = (
                    loss_pi_l_tau * w *
                    self.off_factor(pi_l_tau_fix, beta_fix)).mean()
                self.optimizer_pi_l_tau.zero_grad()
                loss_pi_l_tau.backward()
                self.optimizer_pi_l_tau.step()
            else:
                loss_pi_l_tau = self.zero

            # add results
            results['loss_vs'].append(loss_vs.data.cpu().numpy()[0])
            results['loss_vl'].append(loss_vl.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_qs'].append(loss_qs.data.cpu().numpy()[0])
            results['loss_ql'].append(loss_ql.data.cpu().numpy()[0])
            results['loss_pi_s'].append(loss_pi_s.data.cpu().numpy()[0])
            results['loss_pi_l'].append(loss_pi_l.data.cpu().numpy()[0])
            results['loss_pi_s_tau'].append(
                loss_pi_s_tau.data.cpu().numpy()[0])
            results['loss_pi_l_tau'].append(
                loss_pi_l_tau.data.cpu().numpy()[0])
            results['n'].append(n)

            # if not n % self.update_target_interval:
            #     # self.update_target()

            # if an index is drawn more than once during an
            # update_memory_interval period, only the last occurrence affects
            # its updated sampling probability
            if not (
                    n + 1
            ) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            # update a global n_step parameter

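            # once every ten intervals, train the policy heads (group b) for a
            # single interval, then switch back and step the policy schedulers;
            # on all other intervals step the group-a schedulers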
            if not (n + 1) % self.update_n_steps_interval:
                # self.train_dataset.update_n_step(n + 1)
                d = np.divmod(n + 1, self.update_n_steps_interval)[0]
                if d % 10 == 1:
                    self.flip_grad(self.parameters_group_b +
                                   self.parameters_group_a)
                    train_net = not train_net
                if d % 10 == 2:
                    self.flip_grad(self.parameters_group_b +
                                   self.parameters_group_a)
                    train_net = not train_net

                    self.scheduler_pi_s.step()
                    self.scheduler_pi_l.step()
                    self.scheduler_pi_s_tau.step()
                    self.scheduler_pi_l_tau.step()
                else:
                    self.scheduler_vs.step()
                    self.scheduler_beta.step()
                    self.scheduler_vl.step()
                    self.scheduler_qs.step()
                    self.scheduler_ql.step()

            if not (n + 1) % n_interval:
                yield results
                self.model.train()
                # self.target.eval()
                results = {key: [] for key in results}

    def off_factor(self, pi, beta):
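        # truncated importance-sampling ratio min(pi / beta, 1) used as the
        # off-policy correction factor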
        return torch.clamp(pi / beta, 0, 1)

    def test(self, n_interval, n_tot):

        self.model.eval()
        # self.target.eval()

        results = {
            'n': [],
            'loss_vs': [],
            'loss_b': [],
            'loss_vl': [],
            'loss_qs': [],
            'loss_ql': [],
            'act_diff': [],
            'a_agent': [],
            'a_player': [],
            'loss_pi_s': [],
            'loss_pi_l': [],
            'loss_pi_s_tau': [],
            'loss_pi_l_tau': []
        }

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda().unsqueeze(1), requires_grad=False)

            a_index = Variable(sample['a_index'].cuda(non_blocking=True),
                               requires_grad=False)

            rl = np.digitize(sample['score'].numpy(),
                             self.long_bins,
                             right=True)
            rs = np.digitize(sample['f'].numpy(), self.short_bins, right=True)

            Rl = Variable(sample['score'].cuda(), requires_grad=False)
            Rs = Variable(sample['f'].cuda(), requires_grad=False)

            if self.wasserstein:
                rl = Variable(self.one_hot(torch.LongTensor(rl),
                                           self.atoms_long).cuda(),
                              requires_grad=False)
                rs = Variable(self.one_hot(torch.LongTensor(rs),
                                           self.atoms_short).cuda(),
                              requires_grad=False)
            else:
                rl = Variable(torch.from_numpy(rl).cuda(), requires_grad=False)
                rs = Variable(torch.from_numpy(rs).cuda(), requires_grad=False)

            vs, vl, beta, qs, ql, phi, pi_s, pi_l, pi_s_tau, pi_l_tau = self.model(
                s, a)

            qs = qs.squeeze(1)
            ql = ql.squeeze(1)

            # policy learning

            loss_vs = self.alpha_vs * self.loss_fn_vs(vs, rs)
            loss_vl = self.alpha_vl * self.loss_fn_vl(vl, rl)
            loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)
            loss_qs = self.alpha_qs * self.loss_fn_qs(qs, rs)
            loss_ql = self.alpha_ql * self.loss_fn_ql(ql, rl)

            a_index_np = sample['a_index'].numpy()
            self.batch_range = np.arange(self.batch)

            beta_sfm = F.softmax(beta, 1)
            pi_s_sfm = F.softmax(pi_s, 1)
            pi_l_sfm = F.softmax(pi_l, 1)
            pi_s_tau_sfm = F.softmax(pi_s_tau, 1)
            pi_l_tau_sfm = F.softmax(pi_l_tau, 1)

            beta_fix = Variable(beta_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_fix = Variable(pi_s_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_l_fix = Variable(pi_l_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_tau_fix = Variable(pi_s_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)
            pi_l_tau_fix = Variable(pi_l_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)

            loss_pi_s = self.alpha_pi_s * self.loss_fn_pi_s(pi_s, a_index)
            loss_pi_s = (loss_pi_s * Rs *
                         self.off_factor(pi_s_fix, beta_fix)).mean()

            loss_pi_l = self.alpha_pi_l * self.loss_fn_pi_l(pi_l, a_index)
            loss_pi_l = (loss_pi_l * Rl *
                         self.off_factor(pi_l_fix, beta_fix)).mean()

            loss_pi_s_tau = self.alpha_pi_s_tau * self.loss_fn_pi_s_tau(
                pi_s_tau, a_index)
            w = self.get_weighted_loss(F.softmax(qs, 1), self.short_bins_torch)
            loss_pi_s_tau = (loss_pi_s_tau * w *
                             self.off_factor(pi_s_tau_fix, beta_fix)).mean()

            loss_pi_l_tau = self.alpha_pi_l_tau * self.loss_fn_pi_l_tau(
                pi_l_tau, a_index)
            w = self.get_weighted_loss(F.softmax(ql, 1), self.long_bins_torch)
            loss_pi_l_tau = (loss_pi_l_tau * w *
                             self.off_factor(pi_l_tau_fix, beta_fix)).mean()

            # collect actions statistics
            a_index_np = a_index.data.cpu().numpy()

            _, beta_index = beta.data.cpu().max(1)
            beta_index = beta_index.numpy()
            act_diff = (a_index_np != beta_index).astype(int)

            # add results
            results['act_diff'].append(act_diff)
            results['a_agent'].append(beta_index)
            results['a_player'].append(a_index_np)
            results['loss_vs'].append(loss_vs.data.cpu().numpy()[0])
            results['loss_vl'].append(loss_vl.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_qs'].append(loss_qs.data.cpu().numpy()[0])
            results['loss_ql'].append(loss_ql.data.cpu().numpy()[0])
            results['loss_pi_s'].append(loss_pi_s.data.cpu().numpy()[0])
            results['loss_pi_l'].append(loss_pi_l.data.cpu().numpy()[0])
            results['loss_pi_s_tau'].append(
                loss_pi_s_tau.data.cpu().numpy()[0])
            results['loss_pi_l_tau'].append(
                loss_pi_l_tau.data.cpu().numpy()[0])
            results['n'].append(n)

            if not (n + 1) % n_interval:
                results['s'] = s.data.cpu()
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_agent'] = np.concatenate(results['a_agent'])
                results['a_player'] = np.concatenate(results['a_player'])
                yield results
                self.model.eval()
                # self.target.eval()
                results = {key: [] for key in results}

    def play_stochastic(self, n_tot):
        raise NotImplementedError
        # self.model.eval()
        # env = Env()
        # render = args.render
        #
        # n_human = 60
        # humans_trajectories = iter(self.data)
        #
        # for i in range(n_tot):
        #
        #     env.reset()
        #
        #     observation = next(humans_trajectories)
        #     print("Observation %s" % observation)
        #     trajectory = self.data[observation]
        #
        #     j = 0
        #
        #     while not env.t:
        #
        #         if j < n_human:
        #             a = trajectory[j, self.meta['action']]
        #
        #         else:
        #
        #             if self.cuda:
        #                 s = Variable(env.s.cuda(), requires_grad=False)
        #             else:
        #                 s = Variable(env.s, requires_grad=False)
        #             _, q, _, _, _, _ = self.model(s, self.actions_matrix)
        #
        #             q = q.squeeze(2)
        #
        #             q = q.data.cpu().numpy()
        #             a = np.argmax(q)
        #
        #         env.step(a)
        #
        #         j += 1
        #
        #     yield {'o': env.s.cpu().numpy(),
        #            'score': env.score}

    def play_episode(self, n_tot):

        self.model.eval()
        env = Env()

        n_human = 120
        humans_trajectories = iter(self.data)
        softmax = torch.nn.Softmax(dim=-1)

        # mask = torch.FloatTensor(consts.actions_mask[args.game])
        # mask = Variable(mask.cuda(), requires_grad=False)

        vsx = torch.FloatTensor(consts.short_bins[args.game])
        vlx = torch.FloatTensor(consts.long_bins[args.game])

        for i in range(n_tot):

            env.reset()
            observation = next(humans_trajectories)
            trajectory = self.data[observation]
            choices = np.arange(self.global_action_space, dtype=int)

            j = 0

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)
                vs, vl, beta, qs, ql, phi, pi_s, pi_l, pi_s_tau, pi_l_tau = self.model(
                    s, self.actions_matrix)
                beta = beta.squeeze(0)
                pi_l = pi_l.squeeze(0)
                pi_s = pi_s.squeeze(0)
                pi_l_tau = pi_l_tau.squeeze(0)
                pi_s_tau = pi_s_tau.squeeze(0)

                temp = 1

                # consider only 3 most frequent actions
                beta_np = beta.data.cpu().numpy()
                indices = np.argsort(beta_np)

                maskb = Variable(torch.FloatTensor(
                    [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
                                 requires_grad=False).cuda()
                # maskb = Variable(torch.FloatTensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
                #                  requires_grad=False).cuda()

                # pi = maskb * (beta / beta.max())

                pi = beta
                self.greedy = False

                beta_prob = pi

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:
                    eps = np.random.rand()
                    # a = np.random.choice(choices)
                    if self.greedy and eps > 0.1:
                        a = pi.data.cpu().numpy()
                        a = np.argmax(a)
                    else:
                        a = softmax(pi / temp).data.cpu().numpy()
                        a = np.random.choice(choices, p=a)

                env.step(a)

                vs = softmax(vs)
                vl = softmax(vl)
                vs = torch.sum(vsx * vs.data.cpu())
                vl = torch.sum(vlx * vl.data.cpu())

                yield {
                    'o': env.s.cpu().numpy(),
                    'vs': np.array([vs]),
                    'vl': np.array([vl]),
                    's': phi.data.cpu().numpy(),
                    'score': env.score,
                    'beta': beta_prob.data.cpu().numpy(),
                    'phi': phi.squeeze(0).data.cpu().numpy(),
                    'qs': qs.squeeze(0).data.cpu().numpy(),
                    'ql': ql.squeeze(0).data.cpu().numpy(),
                }

                j += 1

        return

    def policy(self, vs, vl, beta, qs, ql):
        pass
Example #4
    def __init__(self, load_dataset=True):

        super(ACDQNLSTMAgent, self).__init__()

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta, self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset, train=True)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset, train=False)

            self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_sampler=self.train_sampler,
                                                            num_workers=args.cpu_workers, pin_memory=True, drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_sampler=self.test_sampler,
                                                           num_workers=args.cpu_workers, pin_memory=True, drop_last=False)

        self.loss_v_beta = torch.nn.L1Loss(size_average=True, reduce=True)
        self.loss_q_beta = torch.nn.L1Loss(size_average=True, reduce=True)

        self.loss_v_pi = torch.nn.L1Loss(size_average=True, reduce=True)
        self.loss_q_pi = torch.nn.L1Loss(size_average=True, reduce=True)

        self.loss_p = torch.nn.L1Loss(size_average=True, reduce=True)

        self.histogram = torch.from_numpy(self.meta['histogram']).float()
        weights = self.histogram.max() / self.histogram
        weights = torch.clamp(weights, 0, 10).cuda()

        self.loss_beta = torch.nn.CrossEntropyLoss(size_average=True)
        self.loss_pi = torch.nn.CrossEntropyLoss(reduce=False)

        # actor critic setting

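        # three copies of the network: a behavioral model (beta), the learned
        # model (pi) and its target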
        self.model_b_single = ACDQNLSTM().cuda()
        self.model_single = ACDQNLSTM().cuda()
        self.target_single = ACDQNLSTM().cuda()

        if self.parallel:
            self.model_b = torch.nn.DataParallel(self.model_b_single)
            self.model = torch.nn.DataParallel(self.model_single)
            self.target = torch.nn.DataParallel(self.target_single)
        else:
            self.model_b = self.model_b_single
            self.model = self.model_single
            self.target = self.target_single

        self.target_single.reset_target()
        # configure learning

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER


        self.optimizer_q_pi = ACDQNLSTMAgent.set_optimizer(self.model.parameters(), 0.0002)
        self.scheduler_q_pi = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_q_pi, self.decay)

        self.optimizer_pi = ACDQNLSTMAgent.set_optimizer(self.model.parameters(), 0.0002)
        self.scheduler_pi = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_pi, self.decay)

        self.optimizer_q_beta = ACDQNLSTMAgent.set_optimizer(self.model_b.parameters(), 0.0002)
        self.scheduler_q_beta = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_q_beta, self.decay)

        self.optimizer_beta = ACDQNLSTMAgent.set_optimizer(self.model_b.parameters(), 0.0008)
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_beta, self.decay)

        actions = torch.LongTensor(consts.hotvec_matrix).cuda()
        self.actions_matrix = Variable(actions.unsqueeze(0), requires_grad=False)

        self.batch_actions_matrix = self.actions_matrix.repeat(self.batch, 1, 1)

        self.batch_range = np.arange(self.batch)
        self.zero = Variable(torch.zeros(1))
        self.a_post_mat = Variable(torch.from_numpy(consts.a_post_mat).long(), requires_grad=False).cuda()
        self.a_post_mat = self.a_post_mat.unsqueeze(0).repeat(self.batch, 1, 1)
Example #5
class DetachedAgent(Agent):
    def __init__(self, load_dataset=True):

        print("Detached Agent")
        super(DetachedAgent, self).__init__()

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset_by_episodes(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta,
                                                     self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta,
                                                    self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                           train=True)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset,
                                                          train=False)

            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_sampler=self.train_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_sampler=self.test_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)

        self.norm = 2
        self.loss_v_beta = torch.nn.MSELoss(size_average=True, reduce=True)
        self.loss_q_pi = torch.nn.MSELoss(size_average=True, reduce=True)

        self.loss_q_beta = torch.nn.MSELoss(size_average=True, reduce=True)

        self.histogram = torch.from_numpy(
            self.meta['histogram']).float().cuda()
        # weights = self.histogram.max() / self.histogram
        # weights = torch.clamp(weights, 0, 10)
        # weights = 1 - self.histogram

        if self.balance:
            self.loss_beta = torch.nn.CrossEntropyLoss(size_average=True)
        else:
            weights = self.histogram + args.balance_epsilone
            weights = weights.max() / weights
            self.loss_beta = torch.nn.CrossEntropyLoss(size_average=True,
                                                       weight=weights)

        self.loss_pi = torch.nn.CrossEntropyLoss(reduce=False)

        # actor critic setting
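        # separate networks (each with a target copy) for the behavioral
        # policy, the learned policy, and the V/Q critics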
        self.beta_net = DPiN().cuda()
        self.beta_target = DPiN().cuda()

        self.pi_net = DPiN().cuda()
        self.pi_target = DPiN().cuda()

        self.vb_net = DVN().cuda()
        self.vb_target = DVN().cuda()

        self.qb_net = DQN().cuda()
        self.qb_target = DQN().cuda()

        self.q_net = DQN().cuda()
        self.q_target = DQN().cuda()

        # configure learning

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER

        self.optimizer_q_pi = DetachedAgent.set_optimizer(
            self.q_net.parameters(), 0.0001)  # 0.0002
        self.scheduler_q_pi = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_q_pi, self.decay)

        self.optimizer_q_beta = DetachedAgent.set_optimizer(
            self.qb_net.parameters(), 0.001)  # 0.0002 0.0001
        self.scheduler_q_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_q_beta, self.decay)

        self.optimizer_pi = DetachedAgent.set_optimizer(
            self.pi_net.parameters(), 0.0002)
        self.scheduler_pi = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi, self.decay)

        self.optimizer_v_beta = DetachedAgent.set_optimizer(
            self.vb_net.parameters(), 0.001)  # 0.0001
        self.scheduler_v_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_v_beta, self.decay)

        self.optimizer_beta = DetachedAgent.set_optimizer(
            self.beta_net.parameters(), 0.01)  # 0.0008 0.0006
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta, self.decay)

        actions = torch.LongTensor(consts.hotvec_matrix).cuda()
        self.actions_matrix = Variable(actions.unsqueeze(0),
                                       requires_grad=False)

        self.batch_actions_matrix = self.actions_matrix.repeat(
            self.batch, 1, 1)

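        # additive log-space action mask: disallowed actions get -inf (zero
        # probability after softmax), allowed actions get 0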
        self.mask_beta = Variable(torch.FloatTensor(
            consts.behavioral_mask[args.game]),
                                  requires_grad=False).cuda()
        self.mask_beta[self.mask_beta == 0] = -float("Inf")
        self.mask_beta[self.mask_beta == 1] = 0
        self.mask_beta_batch = self.mask_beta.repeat(self.batch, 1)

        self.mask_q = Variable(torch.FloatTensor(
            consts.behavioral_mask[args.game]),
                               requires_grad=False).cuda()
        self.mask_q_batch = self.mask_q.repeat(self.batch, 1)

        self.zero = Variable(torch.zeros(1))

        self.mc = True

    def save_checkpoint(self, path, aux=None):

        state = {
            'beta_net': self.beta_net.state_dict(),
            'beta_target': self.beta_target.state_dict(),
            'pi_net': self.pi_net.state_dict(),
            'pi_target': self.pi_target.state_dict(),
            'vb_net': self.vb_net.state_dict(),
            'vb_target': self.vb_target.state_dict(),
            'q_net': self.q_net.state_dict(),
            'q_target': self.q_target.state_dict(),
            'qb_net': self.qb_net.state_dict(),
            'qb_target': self.qb_target.state_dict(),
            'optimizer_q_pi': self.optimizer_q_pi.state_dict(),
            'optimizer_pi': self.optimizer_pi.state_dict(),
            'optimizer_v_beta': self.optimizer_v_beta.state_dict(),
            'optimizer_q_beta': self.optimizer_q_beta.state_dict(),
            'optimizer_beta': self.optimizer_beta.state_dict(),
            'aux': aux
        }

        torch.save(state, path)

    def load_checkpoint(self, path):

        state = torch.load(path)
        self.beta_net.load_state_dict(state['beta_net'])
        self.beta_target.load_state_dict(state['beta_target'])

        self.pi_net.load_state_dict(state['pi_net'])
        self.pi_target.load_state_dict(state['pi_target'])

        self.vb_net.load_state_dict(state['vb_net'])
        self.vb_target.load_state_dict(state['vb_target'])

        self.q_net.load_state_dict(state['q_net'])
        self.q_target.load_state_dict(state['q_target'])

        self.qb_net.load_state_dict(state['qb_net'])
        self.qb_target.load_state_dict(state['qb_target'])

        self.optimizer_q_pi.load_state_dict(state['optimizer_q_pi'])
        self.optimizer_pi.load_state_dict(state['optimizer_pi'])
        self.optimizer_v_beta.load_state_dict(state['optimizer_v_beta'])
        self.optimizer_q_beta.load_state_dict(state['optimizer_q_beta'])
        self.optimizer_beta.load_state_dict(state['optimizer_beta'])

        return state['aux']

    def resume(self, model_path):
        aux = self.load_checkpoint(model_path)
        return aux

    def learn(self, n_interval, n_tot):

        self.beta_net.train()
        self.beta_target.train()

        self.pi_net.train()
        self.pi_target.train()

        self.vb_net.train()
        self.vb_target.train()

        self.q_net.train()
        self.q_target.train()

        self.qb_net.train()
        self.qb_target.train()

        results = {
            'n': [],
            'loss_v_beta': [],
            'loss_q_beta': [],
            'loss_beta': [],
            'loss_v_pi': [],
            'loss_q_pi': [],
            'loss_pi': []
        }

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(async=True), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(async=True),
                             requires_grad=False)

            a_index = Variable(sample['a_index'].cuda(async=True),
                               requires_grad=False)

            r = Variable(sample['r'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            r_mc = Variable(sample['f'].cuda(async=True).unsqueeze(1),
                            requires_grad=False)

            t = Variable(sample['t'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            k = Variable(sample['k'].cuda(async=True), requires_grad=False)

            a_index_unsqueezed = a_index.unsqueeze(1)

            # Behavioral nets
            beta, _ = self.beta_net(s)
            v_beta, _ = self.vb_net(s)
            q_beta, _ = self.qb_net(s)

            # Critic nets
            q_pi, _ = self.q_net(s)

            # Actor nets:
            pi, _ = self.pi_net(s)

            # target networks
            # pi_target, _ = self.pi_target(s)
            q_pi_target, _ = self.q_target(s)

            pi_tag_target, _ = self.pi_target(s_tag)
            q_pi_tag_target, _ = self.q_target(s_tag)

            beta_target, _ = self.beta_target(s)
            v_beta_target, _ = self.vb_target(s)
            # q_beta_target, _ = self.qb_target(s)

            # gather q values
            q_pi = q_pi.gather(1, a_index_unsqueezed)
            q_beta = q_beta.gather(1, a_index_unsqueezed)
            # q_beta_target = q_beta_target.gather(1, a_index_unsqueezed)
            q_pi_target = q_pi_target.gather(1, a_index_unsqueezed)

            # behavioral networks
            # V^{\beta} is learned with MC return
            loss_v_beta = self.loss_v_beta(v_beta, r_mc)

            # beta is learned with policy gradient and Q=1
            loss_beta = self.loss_beta(beta, a_index)

            # MC Q-value return to boost the learning of Q^{\pi}
            loss_q_beta = self.loss_q_beta(q_beta, r_mc)

            # critic importance sampling
            # pi_target_sfm = F.softmax(pi_target, 1)
            #
            # cc = torch.clamp(pi_target_sfm / beta_target_sfm, 0, 1)
            # cc = cc.gather(1, a_index_unsqueezed)

            # Critic evaluation

            # evaluate V^{\pi}(s')
            # V^{\pi}(s') = \sum_{a} Q^{\pi}(s',a) \pi(a|s')
            pi_sfm_tag_target = F.softmax(pi_tag_target + self.mask_beta_batch,
                                          1)
            # consider only common actions

            v_tag = (q_pi_tag_target * pi_sfm_tag_target).sum(1)
            v_tag = v_tag.unsqueeze(1)
            v_tag = v_tag.detach()
            # rho = ((1 - cc) * q_beta_target + cc * (r + (self.discount ** k) * (v_tag * (1 - t)))).detach()

            loss_q_pi = self.loss_q_pi(
                q_pi, r + (self.discount**k) * (v_tag * (1 - t)))
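            # k-step off-policy critic target: y = r + gamma^k * V^pi(s'),
            # with the bootstrap term zeroed on terminal transitions (t == 1)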

            # actor importance sampling
            pi_sfm = F.softmax(pi, 1)
            beta_target_sfm = F.softmax(beta_target, 1)
            ca = torch.clamp(pi_sfm / beta_target_sfm, 0, 1)
            ca = ca.gather(1, a_index_unsqueezed)

            # Actor evaluation

            loss_pi = self.loss_pi(pi, a_index)

            # total weight is C^{pi/beta}(s,a) * (Q^{pi}(s,a) - V^{beta}(s))

            # if self.balance:
            #     v_beta_bias = (q_beta * beta_sfm).sum(1).unsqueeze(1)
            # else:
            #     v_beta_bias = v_beta

            weight = (ca * (q_pi_target - v_beta_target)).detach()
            loss_pi = (loss_pi * weight.squeeze(1)).mean()
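            # the per-sample actor loss is re-weighted by the truncated importance
            # ratio ca times the advantage estimate Q^pi_target(s,a) - V^beta_target(s);
            # the weight is detached so gradients flow only through pi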

            # Learning part

            self.optimizer_beta.zero_grad()
            loss_beta.backward()
            self.optimizer_beta.step()

            self.optimizer_q_beta.zero_grad()
            loss_q_beta.backward()
            self.optimizer_q_beta.step()

            self.optimizer_v_beta.zero_grad()
            loss_v_beta.backward()
            self.optimizer_v_beta.step()

            if not self.mc:

                self.optimizer_pi.zero_grad()
                loss_pi.backward()
                self.optimizer_pi.step()

                self.optimizer_q_pi.zero_grad()
                loss_q_pi.backward()
                self.optimizer_q_pi.step()

            J = (ca * q_pi).squeeze(1)

            R = r_mc.abs().mean()
            Q_n = (q_pi / R).mean()
            # V_n = (v_beta / R).mean()
            LV_n = (loss_v_beta / R**self.norm).mean()**(1 / self.norm)
            LQB_n = (loss_q_beta / R**self.norm).mean()**(1 / self.norm)
            LQ_n = (loss_q_pi / R**self.norm).mean()**(1 / self.norm)
            LPi_n = (J / R).mean()
            LBeta_n = 1 - torch.exp(-loss_beta).mean()
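            # the logged quantities above are normalized by R, the mean absolute MC
            # return, so the reported losses are roughly scale-free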

            # add results
            results['loss_beta'].append(LBeta_n.data.cpu().numpy()[0])
            results['loss_v_beta'].append(LV_n.data.cpu().numpy()[0])
            results['loss_q_beta'].append(LQB_n.data.cpu().numpy()[0])
            results['loss_pi'].append(LPi_n.data.cpu().numpy()[0])
            results['loss_v_pi'].append(Q_n.data.cpu().numpy()[0])
            results['loss_q_pi'].append(LQ_n.data.cpu().numpy()[0])
            results['n'].append(n)

            if not n % self.update_target_interval:
                self.q_target.load_state_dict(self.q_net.state_dict())
                self.pi_target.load_state_dict(self.pi_net.state_dict())
                self.beta_target.load_state_dict(self.beta_net.state_dict())
                self.vb_target.load_state_dict(self.vb_net.state_dict())
                self.qb_target.load_state_dict(self.qb_net.state_dict())

            if not (n + 1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step()

            # start training the model with behavioral initialization
            if (n + 1) == self.update_target_interval * 8:
                self.mc = False
                self.q_target.load_state_dict(self.qb_net.state_dict())
                self.q_net.load_state_dict(self.qb_net.state_dict())
                self.pi_net.load_state_dict(self.beta_net.state_dict())
                self.pi_target.load_state_dict(self.beta_net.state_dict())
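                # from this point on self.mc is False, so the actor (pi) and critic
                # (q_pi) optimizers above start updating the nets from this
                # behavioral-cloning initialization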

            if not (n + 1) % n_interval:

                yield results
                self.beta_net.train()
                self.beta_target.train()

                self.pi_net.train()
                self.pi_target.train()

                self.vb_net.train()
                self.vb_target.train()

                self.q_net.train()
                self.q_target.train()

                self.qb_net.train()
                self.qb_target.train()
                results = {key: [] for key in results}

    def test(self, n_interval, n_tot):

        self.beta_net.eval()
        self.beta_target.eval()

        self.pi_net.eval()
        self.pi_target.eval()

        self.vb_net.eval()
        self.vb_target.eval()

        self.q_net.eval()
        self.q_target.eval()

        self.qb_net.eval()
        self.qb_target.eval()

        results = {
            'n': [],
            'act_diff': [],
            'a_agent': [],
            'a_player': [],
            'loss_v_beta': [],
            'loss_q_beta': [],
            'loss_beta': [],
            'loss_v_pi': [],
            'loss_q_pi': [],
            'loss_pi': []
        }

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(async=True), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(async=True),
                             requires_grad=False)

            a_index = Variable(sample['a_index'].cuda(async=True),
                               requires_grad=False)

            r = Variable(sample['r'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            r_mc = Variable(sample['f'].cuda(async=True).unsqueeze(1),
                            requires_grad=False)

            t = Variable(sample['t'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            k = Variable(sample['k'].cuda(async=True), requires_grad=False)

            a_index_unsqueezed = a_index.unsqueeze(1)

            # Behavioral nets
            beta, _ = self.beta_net(s)
            v_beta, _ = self.vb_net(s)
            q_beta, _ = self.qb_net(s)

            # Critic nets
            q_pi, _ = self.q_net(s)

            # Actor nets:
            pi, _ = self.pi_net(s)

            # target networks
            # pi_target, _ = self.pi_target(s)
            q_pi_target, _ = self.q_target(s)

            pi_tag_target, _ = self.pi_target(s_tag)
            q_pi_tag_target, _ = self.q_target(s_tag)

            beta_target, _ = self.beta_target(s)
            v_beta_target, _ = self.vb_target(s)
            # q_beta_target, _ = self.qb_target(s)

            # gather q values
            q_pi = q_pi.gather(1, a_index_unsqueezed)
            q_beta = q_beta.gather(1, a_index_unsqueezed)
            # q_beta_target = q_beta_target.gather(1, a_index_unsqueezed)
            q_pi_target = q_pi_target.gather(1, a_index_unsqueezed)

            # behavioral networks
            # V^{\beta} is learned with MC return
            loss_v_beta = self.loss_v_beta(v_beta, r_mc)

            # beta is learned with policy gradient and Q=1
            loss_beta = self.loss_beta(beta, a_index)

            # MC Q-value return to boost the learning of Q^{\pi}
            loss_q_beta = self.loss_q_beta(q_beta, r_mc)

            # critic importance sampling
            # pi_target_sfm = F.softmax(pi_target, 1)
            #
            # cc = torch.clamp(pi_target_sfm / beta_target_sfm, 0, 1)
            # cc = cc.gather(1, a_index_unsqueezed)

            # Critic evaluation

            # evaluate V^{\pi}(s')
            # V^{\pi}(s') = \sum_{a} Q^{\pi}(s',a) \pi(a|s')
            pi_sfm_tag_target = F.softmax(pi_tag_target + self.mask_beta_batch,
                                          1)
            # consider only common actions

            v_tag = (q_pi_tag_target * pi_sfm_tag_target).sum(1)
            v_tag = v_tag.unsqueeze(1)
            v_tag = v_tag.detach()
            # rho = ((1 - cc) * q_beta_target + cc * (r + (self.discount ** k) * (v_tag * (1 - t)))).detach()

            loss_q_pi = self.loss_q_pi(
                q_pi, r + (self.discount**k) * (v_tag * (1 - t)))

            # actor importance sampling
            pi_sfm = F.softmax(pi, 1)
            beta_target_sfm = F.softmax(beta_target, 1)
            ca = torch.clamp(pi_sfm / beta_target_sfm, 0, 1)
            ca = ca.gather(1, a_index_unsqueezed)

            # Actor evaluation

            loss_pi = self.loss_pi(pi, a_index)

            # total weight is C^{pi/beta}(s,a) * (Q^{pi}(s,a) - V^{beta}(s))

            # if self.balance:
            #     v_beta_bias = (q_beta * beta_sfm).sum(1).unsqueeze(1)
            # else:
            #     v_beta_bias = v_beta

            weight = (ca * (q_pi_target - v_beta_target)).detach()
            loss_pi = (loss_pi * weight.squeeze(1)).mean()

            # collect actions statistics
            a_index_np = a_index.data.cpu().numpy()

            _, beta_index = beta.data.cpu().max(1)
            beta_index = beta_index.numpy()
            act_diff = (a_index_np != beta_index).astype(np.int)

            # add results
            results['act_diff'].append(act_diff)
            results['a_agent'].append(beta_index)
            results['a_player'].append(a_index_np)

            J = (ca * q_pi).squeeze(1)

            R = r_mc.abs().mean()
            Q_n = (q_pi / R).mean()
            # V_n = (v_beta / R).mean()
            LV_n = (loss_v_beta / R**self.norm).mean()**(1 / self.norm)
            LQB_n = (loss_q_beta / R**self.norm).mean()**(1 / self.norm)
            LQ_n = (loss_q_pi / R**self.norm).mean()**(1 / self.norm)
            LPi_n = (J / R).mean()
            LBeta_n = 1 - torch.exp(-loss_beta).mean()

            # add results
            results['loss_beta'].append(LBeta_n.data.cpu().numpy()[0])
            results['loss_v_beta'].append(LV_n.data.cpu().numpy()[0])
            results['loss_q_beta'].append(LQB_n.data.cpu().numpy()[0])
            results['loss_pi'].append(LPi_n.data.cpu().numpy()[0])
            results['loss_v_pi'].append(Q_n.data.cpu().numpy()[0])
            results['loss_q_pi'].append(LQ_n.data.cpu().numpy()[0])
            results['n'].append(n)

            if not (n + 1) % n_interval:

                results['s'] = s.data.cpu()
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_agent'] = np.concatenate(results['a_agent'])
                results['a_player'] = np.concatenate(results['a_player'])
                yield results
                self.beta_net.eval()
                self.beta_target.eval()

                self.pi_net.eval()
                self.pi_target.eval()

                self.vb_net.eval()
                self.vb_target.eval()

                self.q_net.eval()
                self.q_target.eval()

                self.qb_net.eval()
                self.qb_target.eval()
                results = {key: [] for key in results}

    def play(self, n_tot, action_offset, player):

        self.beta_net.eval()
        self.beta_target.eval()

        self.pi_net.eval()
        self.pi_target.eval()

        self.vb_net.eval()
        self.vb_target.eval()

        self.q_net.eval()
        self.q_target.eval()

        self.qb_net.eval()
        self.qb_target.eval()

        env = Env(action_offset)

        n_human = 90

        episodes = list(self.data.keys())
        random.shuffle(episodes)
        humans_trajectories = iter(episodes)

        for i in range(n_tot):

            env.reset()
            trajectory = self.data[next(humans_trajectories)]
            choices = np.arange(self.global_action_space, dtype=np.int)
            random_choices = self.mask_q.data.cpu().numpy()
            random_choices = random_choices / random_choices.sum()

            j = 0

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)

                if player == 'beta':
                    pi, _ = self.beta_net(s)
                    pi = pi.squeeze(0)
                    self.greedy = False

                elif player == 'q_b':
                    pi, _ = self.qb_net(s)
                    pi = pi.squeeze(0)
                    self.greedy = True

                elif player == 'pi':
                    pi, _ = self.pi_net(s)
                    pi = pi.squeeze(0)
                    self.greedy = False

                elif player == 'q_pi':
                    pi, _ = self.q_net(s)
                    pi = pi.squeeze(0)
                    self.greedy = True

                else:
                    raise NotImplementedError

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:
                    eps = np.random.rand()
                    # eps = 1
                    # a = np.random.choice(choices)
                    if self.greedy:
                        if eps > 0.01:
                            a = (pi * self.mask_q).data.cpu().numpy()
                            a = np.argmax(a)
                        else:
                            a = np.random.choice(choices, p=random_choices)
                    else:
                        a = F.softmax(pi + self.mask_beta,
                                      dim=0).data.cpu().numpy()
                        a = np.random.choice(choices, p=a)

                env.step(a)

                j += 1

            yield {'score': env.score, 'frames': j}

        return

    def play_episode(self, n_tot):

        self.beta_net.eval()
        self.beta_target.eval()

        self.pi_net.eval()
        self.pi_target.eval()

        self.vb_net.eval()
        self.vb_target.eval()

        self.q_net.eval()
        self.q_target.eval()

        self.qb_net.eval()
        self.qb_target.eval()

        env = Env()

        n_human = 120
        humans_trajectories = iter(self.data)
        softmax = torch.nn.Softmax()

        for i in range(n_tot):

            env.reset()
            observation = next(humans_trajectories)
            trajectory = self.data[observation]
            choices = np.arange(self.global_action_space, dtype=np.int)
            mask = Variable(torch.FloatTensor(
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0]),
                            requires_grad=False).cuda()
            j = 0
            temp = 1

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)

                beta, phi = self.beta_net(s)
                pi, _ = self.pi_net(s)
                q, _ = self.q_net(s)
                vb, _ = self.vb_net(s)

                pi = beta.squeeze(0)
                self.greedy = False

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:
                    # eps = np.random.rand()
                    eps = 1
                    # a = np.random.choice(choices)
                    if self.greedy and eps > 0.01:
                        a = pi.data.cpu().numpy()
                        a = np.argmax(a)
                    else:
                        a = softmax(pi / temp).data.cpu().numpy()
                        a = np.random.choice(choices, p=a)

                q = q[0, a]
                q = q.squeeze(0)

                env.step(a)

                yield {
                    'o': env.s.cpu().numpy(),
                    'v': vb.squeeze(0).data.cpu().numpy(),
                    'vb': vb.squeeze(0).data.cpu().numpy(),
                    'qb': q.squeeze(0).data.cpu().numpy(),
                    # 's': x[0, :512].data.cpu().numpy(),
                    'score': env.score,
                    'beta': pi.data.cpu().numpy(),
                    'phi': phi.squeeze(0).data.cpu().numpy(),
                    'q': q.squeeze(0).data.cpu().numpy()
                }

                j += 1

        return
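
A minimal driver sketch (not part of the original example) showing how the generator-style learn() loop above could be consumed; the interval sizes and checkpoint path below are illustrative assumptions:

import numpy as np

# assumes the agent class shown above is importable as BehavioralDistAgent
agent = BehavioralDistAgent(load_dataset=True)

for i, results in enumerate(agent.learn(n_interval=1000, n_tot=100000)):
    # each yielded dict holds the normalized losses accumulated over the last interval
    print(i, np.mean(results['loss_beta']), np.mean(results['loss_q_pi']))
    agent.save_checkpoint('checkpoint.pkl', aux={'interval': i})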
Example #6
    def __init__(self):

        super(LfdAgent, self).__init__()

        # demonstration source
        self.meta, self.data = preprocess_demonstrations()
        self.meta = divide_dataset(self.meta)

        # datasets
        self.train_dataset = DemonstrationMemory("train", self.meta, self.data)
        self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
        self.test_dataset = DemonstrationMemory("test", self.meta, self.data)
        self.full_dataset = DemonstrationMemory("full", self.meta, self.data)

        self.train_sampler = DemonstrationBatchSampler(self.train_dataset, train=True)
        self.val_sampler = DemonstrationBatchSampler(self.train_dataset, train=False)
        self.test_sampler = DemonstrationBatchSampler(self.test_dataset, train=False)
        self.episodic_sampler = SequentialDemonstrationSampler(self.full_dataset)

        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_sampler=self.train_sampler,
                                                        num_workers=args.cpu_workers, pin_memory=True, drop_last=False)
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_sampler=self.test_sampler,
                                                       num_workers=args.cpu_workers, pin_memory=True, drop_last=False)
        self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_sampler=self.val_sampler,
                                                      num_workers=args.cpu_workers, pin_memory=True, drop_last=False)

        self.episodic_loader = torch.utils.data.DataLoader(self.full_dataset, sampler=self.episodic_sampler,
                                                           batch_size=self.batch, num_workers=args.cpu_workers)

        # set learn validate test play parameters based on arguments
        # configure learning
        if not args.value_advantage:
            self.learn = self.learn_q
            self.test = self.test_q
            self.player = QPlayer
            self.agent_type = 'q'
            # loss function and optimizer

            if self.l1_loss:
                self.loss_fn = torch.nn.L1Loss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l1
            else:
                self.loss_fn = torch.nn.MSELoss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l2

            # Choose a model according to the configurations
            models = {(0,): DQN, (1,): DQNDueling}
            Model = models[(self.dueling,)]

            self.model_single = Model(self.action_space)
            self.target_single = Model(self.action_space)

        else:

            if args.value_only:
                self.alpha_v, self.alpha_a = 1, 0
            else:
                self.alpha_v, self.alpha_a = 1, 1

            if self.l1_loss:
                self.loss_fn_value = torch.nn.L1Loss(size_average=True)
                self.loss_fn_advantage = torch.nn.L1Loss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l1
            else:
                self.loss_fn_value = torch.nn.MSELoss(size_average=True)
                self.loss_fn_advantage = torch.nn.MSELoss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l2

            if not args.input_actions:
                self.learn = self.learn_va
                self.test = self.test_va
                self.player = AVPlayer
                self.agent_type = 'av'
                self.model_single = DVAN_ActionOut(self.action_space)
                self.target_single = DVAN_ActionOut(self.action_space)

            else:
                self.learn = self.learn_ava
                self.test = self.test_ava
                self.player = AVAPlayer
                self.agent_type = 'ava'
                self.model_single = DVAN_ActionIn(3)
                self.target_single = DVAN_ActionIn(3)

                # model specific parameters
                self.action_space = consts.action_space
                self.excitation = torch.LongTensor(consts.excitation_map)
                self.excitation_length = self.excitation.shape[0]
                self.mask = torch.LongTensor(consts.excitation_mask[args.game])
                self.mask_dup = self.mask.unsqueeze(0).repeat(self.action_space, 1)

                actions = Variable(self.mask_dup * self.excitation, requires_grad=False)
                actions = actions.cuda()

                self.actions_matrix = actions.unsqueeze(0)
                self.actions_matrix = self.actions_matrix.repeat(self.batch, 1, 1).float()
                self.reverse_excitation_index = consts.reverse_excitation_index

        if not args.play:
            self.play = self.dummy_play
        elif args.gpu_workers == 0:
            self.play = self.single_play
        else:
            self.play = self.multi_play

        q_functions = {(0, 0): self.simple_q, (0, 1): self.double_q, (1, 0): self.simple_on_q,
                       (1, 1): self.simple_on_q}
        self.q_estimator = q_functions[(self.double_q, self.on_policy)]

        # configure learning
        if args.cuda:
            self.model_single = self.model_single.cuda()
            self.target_single = self.target_single.cuda()
            self.model = torch.nn.DataParallel(self.model_single)
            self.target = torch.nn.DataParallel(self.target_single)
        else:
            self.model = self.model_single
            self.target = self.target_single
        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER
        self.optimizer = LfdAgent.set_optimizer(self.model.parameters())
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, self.decay)
Example #7
class ACDQNLSTMAgent(Agent):

    def __init__(self, load_dataset=True):

        super(ACDQNLSTMAgent, self).__init__()

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta, self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset, train=True)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset, train=False)

            self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_sampler=self.train_sampler,
                                                            num_workers=args.cpu_workers, pin_memory=True, drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_sampler=self.test_sampler,
                                                           num_workers=args.cpu_workers, pin_memory=True, drop_last=False)

        self.loss_v_beta = torch.nn.L1Loss(size_average=True, reduce=True)
        self.loss_q_beta = torch.nn.L1Loss(size_average=True, reduce=True)

        self.loss_v_pi = torch.nn.L1Loss(size_average=True, reduce=True)
        self.loss_q_pi = torch.nn.L1Loss(size_average=True, reduce=True)

        self.loss_p = torch.nn.L1Loss(size_average=True, reduce=True)

        self.histogram = torch.from_numpy(self.meta['histogram']).float()
        weights = self.histogram.max() / self.histogram
        weights = torch.clamp(weights, 0, 10).cuda()

        self.loss_beta = torch.nn.CrossEntropyLoss(size_average=True)
        self.loss_pi = torch.nn.CrossEntropyLoss(reduce=False)

        # actor critic setting

        self.model_b_single = ACDQNLSTM().cuda()
        self.model_single = ACDQNLSTM().cuda()
        self.target_single = ACDQNLSTM().cuda()

        if self.parallel:
            self.model_b = torch.nn.DataParallel(self.model_b_single)
            self.model = torch.nn.DataParallel(self.model_single)
            self.target = torch.nn.DataParallel(self.target_single)
        else:
            self.model_b = self.model_b_single
            self.model = self.model_single
            self.target = self.target_single

        self.target_single.reset_target()
        # configure learning

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER


        self.optimizer_q_pi = ACDQNLSTMAgent.set_optimizer(self.model.parameters(), 0.0002)
        self.scheduler_q_pi = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_q_pi, self.decay)

        self.optimizer_pi = ACDQNLSTMAgent.set_optimizer(self.model.parameters(), 0.0002)
        self.scheduler_pi = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_pi, self.decay)

        self.optimizer_q_beta = ACDQNLSTMAgent.set_optimizer(self.model_b.parameters(), 0.0002)
        self.scheduler_q_beta = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_q_beta, self.decay)

        self.optimizer_beta = ACDQNLSTMAgent.set_optimizer(self.model_b.parameters(), 0.0008)
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(self.optimizer_beta, self.decay)

        actions = torch.LongTensor(consts.hotvec_matrix).cuda()
        self.actions_matrix = Variable(actions.unsqueeze(0), requires_grad=False)

        self.batch_actions_matrix = self.actions_matrix.repeat(self.batch, 1, 1)

        self.batch_range = np.arange(self.batch)
        self.zero = Variable(torch.zeros(1))
        self.a_post_mat = Variable(torch.from_numpy(consts.a_post_mat).long(), requires_grad=False).cuda()
        self.a_post_mat = self.a_post_mat.unsqueeze(0).repeat(self.batch, 1, 1)

    def save_checkpoint(self, path, aux=None):

        state = {'model_b': self.model_b.state_dict(),
                 'model': self.model.state_dict(),
                 'target': self.target.state_dict(),
                 'optimizer_q_pi': self.optimizer_q_pi.state_dict(),
                 'optimizer_pi': self.optimizer_pi.state_dict(),
                 'optimizer_q_beta': self.optimizer_q_beta.state_dict(),
                 'optimizer_beta': self.optimizer_beta.state_dict(),
                 'aux': aux}

        torch.save(state, path)

    def load_checkpoint(self, path):

        state = torch.load(path)
        self.model_b.load_state_dict(state['model_b'])
        self.model.load_state_dict(state['model'])
        self.target.load_state_dict(state['target'])
        self.optimizer_q_pi.load_state_dict(state['optimizer_q_pi'])
        self.optimizer_pi.load_state_dict(state['optimizer_pi'])
        self.optimizer_q_beta.load_state_dict(state['optimizer_q_beta'])
        self.optimizer_beta.load_state_dict(state['optimizer_beta'])

        return state['aux']

    def resume(self, model_path):

        aux = self.load_checkpoint(model_path)
        # self.update_target()
        return aux

    def update_target(self):
        self.target.load_state_dict(self.model.state_dict())

    def learn(self, n_interval, n_tot):

        self.model_b.train()
        self.model.train()
        self.target.eval()

        results = {'n': [], 'loss_v_beta': [], 'loss_q_beta': [], 'loss_beta': [],
                   'loss_v_pi': [], 'loss_q_pi': [], 'loss_pi': []}

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(async=True), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(async=True), requires_grad=False)

            a_pre = Variable(sample['horizon_pre'].cuda(async=True), requires_grad=False)
            a_pre_tag = Variable(sample['horizon_pre_tag'].cuda(async=True), requires_grad=False)

            a_index = Variable(sample['matched_option'].cuda(async=True), requires_grad=False)

            r = Variable(sample['r'].cuda(async=True).unsqueeze(1), requires_grad=False)
            r_mc = Variable(sample['f'].cuda(async=True).unsqueeze(1), requires_grad=False)

            t = Variable(sample['t'].cuda(async=True).unsqueeze(1), requires_grad=False)
            k = Variable(sample['k'].cuda(async=True), requires_grad=False)

            beta, q_beta, _ = self.model_b(s, a_pre, self.a_post_mat)

            q_beta = q_beta.squeeze(2)
            q_beta = q_beta.gather(1, a_index)

            loss_beta = self.loss_beta(beta, a_index.squeeze(1))
            loss_v_beta = r_mc.abs().mean()
            loss_q_beta = self.loss_q_beta(q_beta, r_mc)

            pi, q_pi, _ = self.model(s, a_pre, self.a_post_mat)
            pi_tag, q_tag, _ = self.target(s_tag, a_pre_tag, self.a_post_mat)

            q_pi = q_pi.squeeze(2)
            q_pi = q_pi.gather(1, a_index)

            # ignore negative q:
            if self.double_q:
                _,  q_tag_model, _ = self.model(s_tag, a_pre_tag, self.a_post_mat)
                _, a_tag = F.relu(q_tag_model).max(1)
                q_max = q_tag.gather(1, a_tag.unsqueeze(1))
                q_max = F.relu(q_max).squeeze(1)
            else:
                q_max, _ = F.relu(q_tag).max(1)

            q_max = q_max.detach()

            loss_v_pi = (r + (self.discount ** k) * (q_max * (1 - t))).abs().mean()
            loss_q_pi = self.loss_q_pi(q_pi, r + (self.discount ** k) * (q_max * (1 - t)))
            loss_pi = self.loss_pi(pi, a_index.squeeze(1))

            beta_sfm = F.softmax(beta, 1)
            pi_sfm = F.softmax(pi, 1)
            c = torch.clamp(pi_sfm/beta_sfm, 0, 1)

            c = c.gather(1, a_index)
            weight = (c * q_pi).detach()
            loss_pi = (loss_pi * weight.squeeze(1)).sum()
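            # the per-sample policy loss is weighted by the truncated importance
            # ratio c = clamp(pi / beta, 0, 1) times the critic value Q^pi(s,a),
            # detached so only the policy head receives gradients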

            self.optimizer_beta.zero_grad()
            loss_beta.backward(retain_graph=True)
            self.optimizer_beta.step()

            self.optimizer_q_beta.zero_grad()
            loss_q_beta.backward()
            self.optimizer_q_beta.step()

            self.optimizer_pi.zero_grad()
            loss_pi.backward(retain_graph=True)
            self.optimizer_pi.step()

            # self.optimizer_v_pi.zero_grad()
            # loss_v_pi.backward()
            # self.optimizer_v_pi.step()

            self.optimizer_q_pi.zero_grad()
            loss_q_pi.backward()
            self.optimizer_q_pi.step()

            R = (r_mc ** 1).mean()

            # add results
            results['loss_beta'].append(loss_beta.data.cpu().numpy()[0])
            results['loss_v_beta'].append((loss_v_beta / R).data.cpu().numpy()[0])
            results['loss_q_beta'].append((loss_q_beta / R).data.cpu().numpy()[0])
            # results['loss_q_beta'].append(loss_p.data.cpu().numpy()[0])
            results['loss_pi'].append(loss_pi.data.cpu().numpy()[0])
            results['loss_v_pi'].append((loss_v_pi / R).data.cpu().numpy()[0])
            results['loss_q_pi'].append((loss_q_pi / R).data.cpu().numpy()[0])
            results['n'].append(n)


            if not n % self.update_target_interval:
                self.update_target()

            # if an index is sampled more than once during the update_memory_interval period, only the last occurrence affects the updated priorities
            if not (n+1) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            if not (n+1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step()

            # start training the model with behavioral initialization
            if (n+1) == self.update_n_steps_interval:
                self.target_single.reset_target()
                self.model.load_state_dict(self.model_b.state_dict())
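                # one-time behavioral initialization: the actor-critic model is
                # warm-started from the behavioral model and the target is reset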

            if not (n+1) % n_interval:
                yield results
                self.model_b.train()
                self.model.train()
                self.target.eval()
                results = {key: [] for key in results}

    def test(self, n_interval, n_tot):

        self.model.eval()
        self.target.eval()
        self.model_b.eval()

        results = {'n': [], 'act_diff': [], 'a_agent': [], 'a_player': [],
                   'loss_v_beta': [], 'loss_q_beta': [], 'loss_beta': [],
                   'loss_v_pi': [], 'loss_q_pi': [], 'loss_pi': []}

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(async=True), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(async=True), requires_grad=False)

            a_pre = Variable(sample['horizon_pre'].cuda(async=True), requires_grad=False)
            a_pre_tag = Variable(sample['horizon_pre_tag'].cuda(async=True), requires_grad=False)

            a_index = Variable(sample['matched_option'].cuda(async=True), requires_grad=False)

            r = Variable(sample['r'].cuda(async=True).unsqueeze(1), requires_grad=False)
            r_mc = Variable(sample['f'].cuda(async=True).unsqueeze(1), requires_grad=False)

            t = Variable(sample['t'].cuda(async=True).unsqueeze(1), requires_grad=False)
            k = Variable(sample['k'].cuda(async=True), requires_grad=False)

            beta, q_beta, _ = self.model_b(s, a_pre, self.a_post_mat)

            q_beta = q_beta.squeeze(2)
            q_beta = q_beta.gather(1, a_index)

            loss_beta = self.loss_beta(beta, a_index.squeeze(1))
            loss_v_beta = r_mc.abs().mean()
            loss_q_beta = self.loss_q_beta(q_beta, r_mc)

            pi, q_pi, _ = self.model(s, a_pre, self.a_post_mat)
            pi_tag, q_tag, _ = self.target(s_tag, a_pre_tag, self.a_post_mat)

            q_pi = q_pi.squeeze(2)
            q_pi = q_pi.gather(1, a_index)

            # ignore negative q:
            if self.double_q:
                _,  q_tag_model, _ = self.model(s_tag, a_pre_tag, self.a_post_mat)
                _, a_tag = F.relu(q_tag_model).max(1)
                q_max = q_tag.gather(1, a_tag.unsqueeze(1))
                q_max = F.relu(q_max).squeeze(1)
            else:
                q_max, _ = F.relu(q_tag).max(1)

            q_max = q_max.detach()

            loss_v_pi = (r + (self.discount ** k) * (q_max * (1 - t))).abs().mean()
            loss_q_pi = self.loss_q_pi(q_pi, r + (self.discount ** k) * (q_max * (1 - t)))
            loss_pi = self.loss_pi(pi, a_index.squeeze(1))

            beta_sfm = F.softmax(beta, 1)
            pi_sfm = F.softmax(pi, 1)
            c = torch.clamp(pi_sfm/beta_sfm, 0, 1)

            c = c.gather(1, a_index)
            weight = (c * q_pi).detach()
            loss_pi = (loss_pi * weight.squeeze(1)).sum()

            # collect actions statistics
            a_index_np = a_index.data.cpu().numpy()

            _, beta_index = beta.data.cpu().max(1)
            beta_index = beta_index.numpy()
            act_diff = (a_index_np != beta_index).astype(np.int)

            R = (r_mc ** 1).mean()

            # add results
            results['act_diff'].append(act_diff)
            results['a_agent'].append(beta_index)
            results['a_player'].append(a_index_np)
            results['loss_beta'].append(loss_beta.data.cpu().numpy()[0])
            results['loss_v_beta'].append((loss_v_beta / R).data.cpu().numpy()[0])
            results['loss_v_pi'].append((loss_v_pi / R).data.cpu().numpy()[0])
            results['loss_q_beta'].append((loss_q_beta / R).data.cpu().numpy()[0])
            # results['loss_q_beta'].append(loss_p.data.cpu().numpy()[0])
            results['loss_pi'].append(loss_pi.data.cpu().numpy()[0])
            results['loss_q_pi'].append((loss_q_pi / R).data.cpu().numpy()[0])
            results['n'].append(n)

            if not (n+1) % n_interval:
                results['s'] = s.data.cpu()
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_agent'] = np.concatenate(results['a_agent'])
                results['a_player'] = np.concatenate(results['a_player'])
                yield results
                self.model.eval()
                self.target.eval()
                self.model_b.eval()
                results = {key: [] for key in results}

    def play_stochastic(self, n_tot):
        raise NotImplementedError

    def play_episode(self, n_tot):

        self.model.eval()
        self.model_b.eval()
        env = Env()

        n_human = 120
        humans_trajectories = iter(self.data)
        softmax = torch.nn.Softmax()

        for i in range(n_tot):

            env.reset()
            observation = next(humans_trajectories)
            trajectory = self.data[observation]
            choices = np.arange(self.global_action_space, dtype=np.int)
            mask = Variable(torch.FloatTensor([0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
                             requires_grad=False).cuda()
            j = 0
            temp = 1

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)

                beta, vb, qb, _, _ = self.model_b(s, self.actions_matrix)
                pi, v, q, adv, x = self.model(s, self.actions_matrix, beta.detach())

                pi = pi.squeeze(0)
                self.greedy = False

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:
                    eps = np.random.rand()
                    # a = np.random.choice(choices)
                    if self.greedy and eps > 0.1:
                        a = pi.data.cpu().numpy()
                        a = np.argmax(a)
                    else:
                        a = softmax(pi/temp).data.cpu().numpy()
                        a = np.random.choice(choices, p=a)

                q = q[0, a, 0]
                q = q.squeeze(0)

                qb = qb[0, a, 0]
                qb = qb.squeeze(0)

                env.step(a)

                yield {'o': env.s.cpu().numpy(),
                       'v': v.squeeze(0).data.cpu().numpy(),
                       'vb': vb.squeeze(0).data.cpu().numpy(),
                       'qb': qb.squeeze(0).data.cpu().numpy(),
                       's': x[0, :512].data.cpu().numpy(),
                       'score': env.score,
                       'beta': pi.data.cpu().numpy(),
                       'phi': x[0, :512].data.cpu().numpy(),
                       'q': q.squeeze(0).data.cpu().numpy()}

                j += 1

        return

    def policy(self, vs, vl, beta, qs, ql):
        pass
Example #8
    def __init__(self, load_dataset=True):

        super(BehavioralHotAgent, self).__init__()

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta,
                                                     self.data)
            self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta,
                                                    self.data)
            self.full_dataset = DemonstrationMemory("full", self.meta,
                                                    self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                           train=True)
            self.val_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                         train=False)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset,
                                                          train=False)
            self.episodic_sampler = SequentialDemonstrationSampler(
                self.full_dataset)

            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_sampler=self.train_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_sampler=self.test_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.val_loader = torch.utils.data.DataLoader(
                self.val_dataset,
                batch_sampler=self.val_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)

            self.episodic_loader = torch.utils.data.DataLoader(
                self.full_dataset,
                sampler=self.episodic_sampler,
                batch_size=self.batch,
                num_workers=args.cpu_workers)

        if self.l1_loss:
            self.loss_fn_value = torch.nn.L1Loss(size_average=True)
            self.loss_fn_q = torch.nn.L1Loss(size_average=True)
        else:
            self.loss_fn_value = torch.nn.MSELoss(size_average=True)
            self.loss_fn_q = torch.nn.MSELoss(size_average=True)

        self.loss_fn_r = torch.nn.MSELoss(size_average=True)
        self.loss_fn_p = torch.nn.L1Loss(size_average=True)

        if self.weight_by_expert:
            self.loss_fn_beta = torch.nn.CrossEntropyLoss(reduce=False)
        else:
            self.loss_fn_beta = torch.nn.CrossEntropyLoss(reduce=True)

        # alpha weighted sum

        self.alpha_v = 1  # 1 / 0.02
        self.alpha_b = 1  # 1 / 0.7

        self.alpha_r = 1  # 1 / 0.7
        self.alpha_p = 0  # 1 / 0.7
        self.alpha_q = 1

        self.model = BehavioralHotNet()
        self.model.cuda()

        # configure learning

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER
        self.optimizer_v = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr)
        self.scheduler_v = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_v, self.decay)

        self.optimizer_beta = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_beta)
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta, self.decay)

        self.optimizer_q = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_q)
        self.scheduler_q = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_q, self.decay)

        self.optimizer_r = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_r)
        self.scheduler_r = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_r, self.decay)

        self.optimizer_p = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_p)
        self.scheduler_p = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_p, self.decay)

        self.episodic_evaluator = self.dummy_episodic_evaluator

        actions = torch.FloatTensor(consts.hotvec_matrix) / (3**(0.5))
        actions = Variable(actions, requires_grad=False).cuda()

        self.actions_matrix = actions.unsqueeze(0)
Example #9
class LfdAgent(Agent):

    def __init__(self):

        super(LfdAgent, self).__init__()

        # demonstration source
        self.meta, self.data = preprocess_demonstrations()
        self.meta = divide_dataset(self.meta)

        # datasets
        self.train_dataset = DemonstrationMemory("train", self.meta, self.data)
        self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
        self.test_dataset = DemonstrationMemory("test", self.meta, self.data)
        self.full_dataset = DemonstrationMemory("full", self.meta, self.data)

        self.train_sampler = DemonstrationBatchSampler(self.train_dataset, train=True)
        self.val_sampler = DemonstrationBatchSampler(self.train_dataset, train=False)
        self.test_sampler = DemonstrationBatchSampler(self.test_dataset, train=False)
        self.episodic_sampler = SequentialDemonstrationSampler(self.full_dataset)

        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_sampler=self.train_sampler,
                                                        num_workers=args.cpu_workers, pin_memory=True, drop_last=False)
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_sampler=self.test_sampler,
                                                       num_workers=args.cpu_workers, pin_memory=True, drop_last=False)
        self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_sampler=self.val_sampler,
                                                      num_workers=args.cpu_workers, pin_memory=True, drop_last=False)

        self.episodic_loader = torch.utils.data.DataLoader(self.full_dataset, sampler=self.episodic_sampler,
                                                           batch_size=self.batch, num_workers=args.cpu_workers)

        # set learn validate test play parameters based on arguments
        # configure learning
        if not args.value_advantage:
            self.learn = self.learn_q
            self.test = self.test_q
            self.player = QPlayer
            self.agent_type = 'q'
            # loss function and optimizer

            if self.l1_loss:
                self.loss_fn = torch.nn.L1Loss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l1
            else:
                self.loss_fn = torch.nn.MSELoss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l2

            # Choose a model according to the configurations
            models = {(0,): DQN, (1,): DQNDueling}
            Model = models[(self.dueling,)]

            self.model_single = Model(self.action_space)
            self.target_single = Model(self.action_space)

        else:

            if args.value_only:
                self.alpha_v, self.alpha_a = 1, 0
            else:
                self.alpha_v, self.alpha_a = 1, 1

            if self.l1_loss:
                self.loss_fn_value = torch.nn.L1Loss(size_average=True)
                self.loss_fn_advantage = torch.nn.L1Loss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l1
            else:
                self.loss_fn_value = torch.nn.MSELoss(size_average=True)
                self.loss_fn_advantage = torch.nn.MSELoss(size_average=True)
                self.individual_loss_fn = self.individual_loss_fn_l2

            if not args.input_actions:
                self.learn = self.learn_va
                self.test = self.test_va
                self.player = AVPlayer
                self.agent_type = 'av'
                self.model_single = DVAN_ActionOut(self.action_space)
                self.target_single = DVAN_ActionOut(self.action_space)

            else:
                self.learn = self.learn_ava
                self.test = self.test_ava
                self.player = AVAPlayer
                self.agent_type = 'ava'
                self.model_single = DVAN_ActionIn(3)
                self.target_single = DVAN_ActionIn(3)

                # model specific parameters
                self.action_space = consts.action_space
                self.excitation = torch.LongTensor(consts.excitation_map)
                self.excitation_length = self.excitation.shape[0]
                self.mask = torch.LongTensor(consts.excitation_mask[args.game])
                self.mask_dup = self.mask.unsqueeze(0).repeat(self.action_space, 1)

                actions = Variable(self.mask_dup * self.excitation, requires_grad=False)
                actions = actions.cuda()

                self.actions_matrix = actions.unsqueeze(0)
                self.actions_matrix = self.actions_matrix.repeat(self.batch, 1, 1).float()
                self.reverse_excitation_index = consts.reverse_excitation_index

        if not args.play:
            self.play = self.dummy_play
        elif args.gpu_workers == 0:
            self.play = self.single_play
        else:
            self.play = self.multi_play

        q_functions = {(0, 0): self.simple_q, (0, 1): self.double_q, (1, 0): self.simple_on_q,
                       (1, 1): self.simple_on_q}
        self.q_estimator = q_functions[(self.double_q, self.on_policy)]
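        # the bootstrap estimator is selected by the (double_q, on_policy) flag
        # pair; the chosen function computes the target value that is plugged
        # into the n-step TD target in learn_q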

        # configure learning
        if args.cuda:
            self.model_single = self.model_single.cuda()
            self.target_single = self.target_single.cuda()
            self.model = torch.nn.DataParallel(self.model_single)
            self.target = torch.nn.DataParallel(self.target_single)
        else:
            self.model = self.model_single
            self.target = self.target_single
        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER
        self.optimizer = LfdAgent.set_optimizer(self.model.parameters())
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, self.decay)

    @staticmethod
    def individual_loss_fn_l2(argument):
        return abs(argument.data.cpu().numpy())**2

    @staticmethod
    def individual_loss_fn_l1(argument):
        return abs(argument.data.cpu().numpy())

    def resume(self, model_path):

        aux = self.load_checkpoint(model_path)
        self.update_target()
        return aux

    def update_target(self):
        self.target.load_state_dict(self.model.state_dict())

    # double q-learning
    def double_q(self, s_tag, q_target, a_tag, f):
        q_tag = self.model(s_tag)
        _, a_tag = q_tag.max(1)
        return q_target[range(self.batch), a_tag.data].data

    def simple_q(self, s_tag, q_target, a_tag, f):
        return q_target.max(1)[0].data

    def simple_on_q(self, s_tag, q_target, a_tag, f):
        return q_target[range(self.batch), a_tag.data].data

    def complex_q(self, s_tag, q_target, a_tag, f):
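        # margin-based blend: add a margin to every action except the demonstrated
        # one, take the arg-max (off-policy) target, and mix it with the on-policy
        # target using alpha = clamp(f / 2, 0, 1)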
        q_tag = self.model(s_tag)

        le = self.margin * Variable(torch.ones(self.batch, self.action_space).cuda())
        le[range(self.batch), a_tag.data] = 0

        _, a_max = (q_tag + le).max(1)
        off_policy = q_target[range(self.batch), a_max.data]

        on_policy = q_target[range(self.batch), a_tag.data]
        alpha = torch.clamp(f / 2, 0, 1)  # * np.exp(- (n / (32 * n_interval)))

        return torch.mul(on_policy.data, alpha) + torch.mul((1 - alpha), off_policy.data)

    def dummy_episodic_evaluator(self):
        while True:
            yield {'q_diff': torch.zeros(100), 'a_agent': torch.zeros(100, self.action_space), 'a_player': torch.zeros(100).long()}

    def episodic_evaluator(self):

        self.model.eval()
        self.target.eval()
        results = {'q_diff': [], 'a_agent': [], 'a_player': []}

        for n, sample in tqdm(enumerate(self.episodic_loader)):

            is_final = sample['is_final']
            final_indicator, final_index = is_final.max(0)
            final_indicator = final_indicator.numpy()[0]
            final_index = final_index.numpy()[0]+1 if final_indicator else self.batch

            s = Variable(sample['s'][:final_index].cuda(), requires_grad=False)
            a = Variable(sample['a'][:final_index].cuda().unsqueeze(1), requires_grad=False)
            f = sample['f'][:final_index].float().cuda()
            base = sample['base'][:final_index].float()
            r = Variable(sample['r'][:final_index].float().cuda(), requires_grad=False)
            a_index = Variable(sample['a_index'][:final_index].cuda().unsqueeze(1), requires_grad=False)

            if self.agent_type == 'q':
                q = self.model(s)
                q_best, a_best = q.cpu().data.max(1)
                q_diff = r.data.cpu() - q_best if self.myopic else (f.cpu() - base) - q_best
                a_player = a.squeeze(1).data.cpu()
                a_agent = q.data.cpu()

            else:

                if self.agent_type == 'av':
                    value, advantage = self.model(s)
                elif self.agent_type == 'ava':
                    value, advantage = self.model(s, self.actions_matrix[:final_index])
                else:
                    raise NotImplementedError
                value = value.squeeze(1)
                q_diff = r.data.cpu() - value.data.cpu() if self.myopic else (f.cpu() - base) - value.data.cpu()
                a_player = a_index.squeeze(1).data.cpu()
                a_agent = advantage.data.cpu()

            # add results
            results['q_diff'].append(q_diff)
            results['a_agent'].append(a_agent)
            results['a_player'].append(a_player)

            if final_indicator:
                results['q_diff'] = torch.cat(results['q_diff'])
                results['a_agent'] = torch.cat(results['a_agent'])
                results['a_player'] = torch.cat(results['a_player'])
                yield results
                self.model.eval()
                self.target.eval()
                results = {key: [] for key in results}

    def saliency_map(self, s, a):
        self.model.eval()
        pass

    def learn_q(self, n_interval, n_tot):

        self.model.train()
        self.target.eval()
        results = {'loss': [], 'n': []}

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda().unsqueeze(1), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(), requires_grad=False)
            r = Variable(sample['r'].float().cuda(), requires_grad=False)
            t = Variable(sample['t'].float().cuda(), requires_grad=False)
            k = Variable(sample['k'].float().cuda(), requires_grad=False)
            f = sample['f'].float().cuda()
            indexes = sample['i']

            q = self.model(s)
            q_a = q.gather(1, a)[:, 0]

            q_target = self.target(s_tag)

            max_q_target = Variable(self.q_estimator(s_tag, q_target, a_tag, f), requires_grad=False)
            loss = self.loss_fn(q_a, r + (self.discount ** k) * (max_q_target * (1 - t)))
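            # n-step TD target: y = r + gamma^k * max_q_target, with the bootstrap
            # term zeroed on terminal transitions (t == 1)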

            # calculate the td error for the priority replay
            if self.prioritized_replay:
                argument = r + (self.discount ** k) * (max_q_target * (1 - t)) - q_a
                individual_loss = self.individual_loss_fn(argument)
                self.train_dataset.update_td_error(indexes.numpy(), individual_loss)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # add results
            results['loss'].append(loss.data.cpu().numpy()[0])
            results['n'].append(n)

            if not n % self.update_target_interval:
                self.update_target()
                self.scheduler.step()

            # if an index is sampled more than once during an update_memory_interval
            # period, only the last occurrence affects its replay priority
            if not (n+1) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            # update a global n_step parameter
            if not (n+1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step(n+1)

            if not (n+1) % n_interval:
                results['n_steps'] = self.train_dataset.n_steps
                yield results
                self.model.train()
                self.target.eval()
                results = {key: [] for key in results}
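    # Sketch of the regression target built in learn_q above (illustrative,
    # not part of the original source).  Assuming, as the (discount ** k)
    # factor suggests, that r holds the discounted k-step reward sum and that
    # q_estimator returns the bootstrap value max_a Q_target(s', a) (or a
    # variant that also uses a_tag and f), the loop fits
    #
    #     y = r + discount**k * max_q_target * (1 - t)
    #     loss = loss_fn(q_a, y)
    #
    # where t is the terminal indicator that zeroes the bootstrap term.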

    def learn_ava(self, n_interval, n_tot):

        self.model.train()
        self.target.train()
        results = {'loss': [], 'n': []}

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda(), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(), requires_grad=False)
            r = Variable(sample['r'].float().cuda(), requires_grad=False)
            t = Variable(sample['t'].float().cuda(), requires_grad=False)
            k = Variable(sample['k'].float().cuda(), requires_grad=False)
            f = sample['f'].float().cuda()
            indexes = sample['i']

            value, advantage = self.model(s, a)
            value_tag = self.target(s_tag)

            value_target = self.target(s)
            value_tag = Variable(value_tag.data, requires_grad=False)
            value_target = Variable(value_target.data, requires_grad=False)

            value = value.squeeze(1)
            value_tag = value_tag.squeeze(1)
            advantage = advantage.squeeze(1)
            value_target = value_target.squeeze(1)
            loss_v = self.loss_fn_value(value, r + (self.discount ** k) * (value_tag * (1 - t)))
            loss_a = self.loss_fn_advantage(advantage, r + (self.discount ** k) * (value_tag * (1 - t)) - value_target)
            loss = self.alpha_v * loss_v + self.alpha_a * loss_a

            # # calculate the td error for the priority replay
            # if self.prioritized_replay:
            #     argument = r + (self.discount ** k) * (max_q_target * (1 - t)) - q_a
            #     individual_loss = LfdAgent.individual_loss_fn(argument)
            #     self.train_dataset.update_td_error(indexes.numpy(), individual_loss)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # add results
            results['loss'].append(loss.data.cpu().numpy()[0])
            results['n'].append(n)

            if not n % self.update_target_interval:
                self.update_target()
                self.scheduler.step()

            # if an index is sampled more than once during an update_memory_interval
            # period, only the last occurrence affects its replay priority
            if not (n+1) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            if not (n+1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step(n+1)

            if not (n+1) % n_interval:
                results['n_steps'] = self.train_dataset.n_steps
                yield results
                self.model.train()
                self.target.eval()
                results = {key: [] for key in results}

    def learn_va(self, n_interval, n_tot):

        self.model.train()
        self.target.eval()
        results = {'loss': [], 'n': []}

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda().unsqueeze(1), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(), requires_grad=False)
            r = Variable(sample['r'].float().cuda(), requires_grad=False)
            t = Variable(sample['t'].float().cuda(), requires_grad=False)
            k = Variable(sample['k'].float().cuda(), requires_grad=False)
            f = sample['f'].float().cuda()
            indexes = sample['i']

            value, advantage = self.model(s)
            value_tag, advantage_tag = self.target(s_tag)
            advantage_a = advantage.gather(1, a)[:, 0]

            value_target, advantage_target = self.target(s)
            value_tag = Variable(value_tag.data, requires_grad=False)
            value_target = Variable(value_target.data, requires_grad=False)

            value = value.squeeze(1)
            value_tag = value_tag.squeeze(1)
            value_target = value_target.squeeze(1)
            loss_v = self.loss_fn_value(value, r + (self.discount ** k) * (value_tag * (1 - t)))
            loss_a = self.loss_fn_advantage(advantage_a, r + (self.discount ** k) * (value_tag * (1 - t)) - value_target)
            loss = self.alpha_v * loss_v + self.alpha_a * loss_a

            # # calculate the td error for the priority replay
            # if self.prioritized_replay:
            #     argument = r + (self.discount ** k) * (max_q_target * (1 - t)) - q_a
            #     individual_loss = LfdAgent.individual_loss_fn(argument)
            #     self.train_dataset.update_td_error(indexes.numpy(), individual_loss)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # add results
            results['loss'].append(loss.data.cpu().numpy()[0])
            results['n'].append(n)

            if not n % self.update_target_interval:
                self.update_target()
                self.scheduler.step()

            # if an index is sampled more than once during an update_memory_interval
            # period, only the last occurrence affects its replay priority
            if not (n+1) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            if not (n+1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step(n+1)

            if not (n+1) % n_interval:
                results['n_steps'] = self.train_dataset.n_steps
                yield results
                self.model.train()
                self.target.eval()
                results = {key: [] for key in results}
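    # Sketch of the loss decomposition used by learn_va / learn_ava above
    # (illustrative, not part of the original source).  The value head is
    # regressed toward the n-step bootstrap target, and the advantage head
    # toward the TD residual computed against the frozen target value:
    #
    #     y_v = r + discount**k * V_target(s') * (1 - t)
    #     y_a = y_v - V_target(s)
    #     loss = alpha_v * loss_fn_value(V(s), y_v) \
    #          + alpha_a * loss_fn_advantage(A(s, a), y_a)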

    def test_va(self, n_interval, n_tot):

        self.model.eval()
        self.target.eval()

        results = {'loss': [], 'n': [], 'q_diff': [], 'q': [], 'act_diff': [], 'r': [], 'a_best': []}

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda().unsqueeze(1), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(), requires_grad=False)
            r = Variable(sample['r'].float().cuda(), requires_grad=False)
            t = Variable(sample['t'].float().cuda(), requires_grad=False)
            k = Variable(sample['k'].float().cuda(), requires_grad=False)
            f = sample['f'].float().cuda()
            base = sample['base'].float()

            value, advantage = self.model(s)
            value_tag, advantage_tag = self.target(s_tag)
            advantage_a = advantage.gather(1, a)[:, 0]

            value_target, advantage_target = self.target(s)
            value_tag = Variable(value_tag.data, requires_grad=False)
            value_target = Variable(value_target.data, requires_grad=False)

            advantage_best, a_best = advantage.data.cpu().max(1)

            value = value.squeeze(1)
            value_tag = value_tag.squeeze(1)
            value_target = value_target.squeeze(1)
            loss_v = self.loss_fn_value(value, r + (self.discount ** k) * (value_tag * (1 - t)))
            loss_a = self.loss_fn_advantage(advantage_a, r + (self.discount ** k) * (value_tag * (1 - t)) - value_target)
            loss = self.alpha_v * loss_v + self.alpha_a * loss_a

            if self.myopic:
                v_diff = r.data.cpu() - value.data.cpu()
            else:
                v_diff = (f.cpu() - base) - value.data.cpu()

            act_diff = (a.squeeze(1).data.cpu() == a_best).numpy()

            # add results
            results['loss'].append(loss.data.cpu().numpy()[0])
            results['n'].append(n)
            results['q_diff'].append(v_diff)
            results['q'].append(advantage.data.cpu())
            results['act_diff'].append(act_diff)
            results['r'].append(r.data.cpu())
            results['a_best'].append(a_best)

            if not (n+1) % n_interval:
                results['s'] = s.data.cpu()
                results['s_tag'] = s_tag.data.cpu()
                results['q'] = torch.cat(results['q'])
                results['r'] = torch.cat(results['r'])
                results['q_diff'] = torch.cat(results['q_diff'])
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_best'] = torch.cat(results['a_best'])
                results['n_steps'] = self.test_dataset.n_steps
                yield results
                self.model.eval()
                self.target.eval()
                results = {key: [] for key in results}

    def test_ava(self, n_interval, n_tot):

        self.model.eval()
        self.target.eval()

        results = {'loss': [], 'n': [], 'q_diff': [], 'q': [], 'act_diff': [], 'r': [], 'a_best': []}

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda(), requires_grad=False)
            a_index = Variable(sample['a_index'].cuda().unsqueeze(1), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(), requires_grad=False)
            r = Variable(sample['r'].float().cuda(), requires_grad=False)
            t = Variable(sample['t'].float().cuda(), requires_grad=False)
            k = Variable(sample['k'].float().cuda(), requires_grad=False)
            f = sample['f'].float().cuda()
            base = sample['base'].float()

            # value, advantage = self.eval_ava(s, a)
            value, advantage = self.model(s, self.actions_matrix)
            value_tag = self.target(s_tag)

            value_target = self.target(s)
            value_tag = Variable(value_tag.data, requires_grad=False)
            value_target = Variable(value_target.data, requires_grad=False)

            value = value.squeeze(1)
            value_tag = value_tag.squeeze(1)

            advantage = advantage.squeeze(2)
            advantage_a = advantage.gather(1, a_index)[:, 0]

            value_target = value_target.squeeze(1)

            loss_v = self.loss_fn_value(value, r + (self.discount ** k) * (value_tag * (1 - t)))
            loss_a = self.loss_fn_advantage(advantage_a, r + (self.discount ** k) * (value_tag * (1 - t)) - value_target)
            loss = self.alpha_v * loss_v + self.alpha_a * loss_a

            advantage_best, a_best = advantage.data.cpu().max(1)

            if self.myopic:
                v_diff = r.data.cpu() - value.data.cpu()
            else:
                v_diff = (f.cpu() - base) - value.data.cpu()

            act_diff = (a_index.squeeze(1).data.cpu() == a_best).numpy()

            # add results
            results['loss'].append(loss.data.cpu().numpy()[0])
            results['n'].append(n)
            results['q_diff'].append(v_diff)
            results['q'].append(advantage.data.cpu())
            results['act_diff'].append(act_diff)
            results['r'].append(r.data.cpu())
            results['a_best'].append(a_best)

            if not (n+1) % n_interval:
                results['s'] = s.data.cpu()
                results['s_tag'] = s_tag.data.cpu()
                results['q'] = torch.cat(results['q'])
                results['r'] = torch.cat(results['r'])
                results['q_diff'] = torch.cat(results['q_diff'])
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_best'] = torch.cat(results['a_best'])
                results['n_steps'] = self.test_dataset.n_steps
                yield results
                self.model.eval()
                self.target.eval()
                results = {key: [] for key in results}

    def test_q(self, n_interval, n_tot):

        self.model.eval()
        self.target.eval()

        results = {'loss': [], 'n': [], 'q_diff': [], 'q': [], 'act_diff': [], 'r': [], 'a_best': []}

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda().unsqueeze(1), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(), requires_grad=False)
            r = Variable(sample['r'].float().cuda(), requires_grad=False)
            t = Variable(sample['t'].float().cuda(), requires_grad=False)
            k = Variable(sample['k'].float().cuda(), requires_grad=False)
            f = sample['f'].float().cuda()
            base = sample['base'].float()

            q = self.model(s)
            q_a = q.gather(1, a)[:, 0]
            q_best, a_best = q.data.cpu().max(1)

            q_target = self.target(s_tag)

            max_q_target = Variable(self.q_estimator(s_tag, q_target, a_tag, f), requires_grad=False)
            loss = self.loss_fn(q_a, r + (self.discount ** k) * (max_q_target * (1 - t)))

            if self.myopic:
                q_diff = r.data.cpu() - q_best
            else:
                q_diff = (f.cpu() - base) - q_best

            act_diff = (a.squeeze(1).data.cpu() == a_best).numpy()

            # add results
            results['loss'].append(loss.data.cpu().numpy()[0])
            results['n'].append(n)
            results['q_diff'].append(q_diff)
            results['q'].append(q.data.cpu())
            results['act_diff'].append(act_diff)
            results['r'].append(r.data.cpu())
            results['a_best'].append(a_best)

            if not (n+1) % n_interval:
                results['s'] = s.data.cpu()
                results['s_tag'] = s_tag.data.cpu()
                results['q'] = torch.cat(results['q'])
                results['r'] = torch.cat(results['r'])
                results['q_diff'] = torch.cat(results['q_diff'])
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_best'] = torch.cat(results['a_best'])
                results['n_steps'] = self.test_dataset.n_steps
                yield results
                self.model.eval()
                self.target.eval()
                results = {key: [] for key in results}

    def dummy_play(self, n_interval, n_tot):
        while True:
            yield {'scores': [0] * n_interval, 'epoch': range(n_interval)}

    def single_play(self, n_interval, n_tot):

        player = self.player()
        results = {'scores': [], 'epoch': []}

        for epoch in range(0, n_tot, n_interval):

            params = self.model.state_dict()
            for i in tqdm(range(n_interval)):
                score = player.play(params)
                results['scores'].append(score)
                results['epoch'].append(epoch + i)

            yield results
            results = {key: [] for key in results}

    def multi_play(self, n_interval, n_tot):

        ctx = mp.get_context('forkserver')
        queue = ctx.Queue()
        # new = mp.Event()
        jobs = ctx.Queue(n_interval)
        done = ctx.Event()

        processes = []
        for rank in range(args.gpu_workers):
            p = ctx.Process(target=player_worker, args=(queue, jobs, done, self.player))
            p.start()
            processes.append(p)

        try:

            results = {'scores': [], 'epoch': []}

            for epoch in range(0, n_tot, n_interval):

                params = self.model.state_dict()
                [jobs.put(params) for i in range(n_interval)]

                for i in tqdm(range(n_interval)):
                    score = queue.get()
                    results['scores'].append(score)
                    results['epoch'].append(epoch + i)

                yield results
                results = {key: [] for key in results}

            return

        finally:

            done.set()
            for p in processes:
                p.join()
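    # player_worker is defined elsewhere in the project; the commented sketch
    # below is only an assumption about its contract, inferred from how
    # multi_play uses it: pull parameters from 'jobs', run one episode with
    # the supplied player, push the score to 'queue', and exit once 'done'
    # is set.
    #
    # import queue as queue_lib
    #
    # def player_worker(queue, jobs, done, player_cls):
    #     player = player_cls()
    #     while not done.is_set():
    #         try:
    #             params = jobs.get(timeout=1)
    #         except queue_lib.Empty:
    #             continue
    #         queue.put(player.play(params))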



    # def multi_play(self, n_interval, n_tot):
    #
    #     ctx = mp.get_context('forkserver')
    #     queue = ctx.Queue()
    #     envs = [Env() for i in range(args.gpu_workers)]
    #     # new = mp.Event()
    #     jobs = ctx.Queue(n_interval)
    #     done = ctx.Event()
    #
    #     processes = []
    #     for rank in range(args.gpu_workers):
    #         p = ctx.Process(target=self.player, args=(envs[rank], queue, jobs, done, rank))
    #         p.start()
    #         processes.append(p)
    #
    #     try:
    #
    #         results = {'scores': [], 'epoch': []}
    #
    #         for epoch in range(0, n_tot, n_interval):
    #
    #             params = self.model.state_dict()
    #             [jobs.put(params) for i in range(n_interval)]
    #
    #             for i in tqdm(range(n_interval)):
    #                 score = queue.get()
    #                 results['scores'].append(score)
    #                 results['epoch'].append(epoch + i)
    #
    #             yield results
    #             results = {key: [] for key in results}
    #
    #         raise StopIteration
    #
    #     finally:
    #
    #         done.set()
    #         for p in processes:
    #             p.join()




    # # A working example for new process for each iteration
    # def player(env,  model, queue):
    #
    #     print("I am Here")
    #
    #     env.reset()
    #     while not env.t:
    #
    #         s = Variable(env.s, requires_grad=False)
    #
    #         a = model(s)
    #         a = a.data.cpu().numpy()
    #         a = np.argmax(a)
    #         env.step(a)
    #
    #     queue.put(env.score)

    # # A working example for new process for each iteration pretty slow
    # def play(self, n_interval, n_tot):
    #
    #     queue = mp.Queue()
    #     n = args.cpu_workers
    #     envs = [Env() for i in range(args.cpu_workers)]
    #
    #     for epoch in range(0, n_tot, n):
    #
    #         results = {'scores': [], 'epoch': []}
    #         self.models.cpu().eval()
    #         self.models.share_memory()
    #
    #         processes = []
    #         for rank in range(n):
    #             p = mp.Process(target=player, args=(envs[rank], self.models, queue))
    #             p.start()
    #             processes.append(p)
    #
    #         for i in tqdm(range(n)):
    #             score = queue.get()
    #             results['scores'].append(score)
    #             results['epoch'].append(epoch + i)
    #
    #         for p in processes:
    #             p.join()
    #
    #         self.models.cuda().train()
    #         yield results


# def player_q(env, queue, jobs, done, myid):
#
#     # print("P %d: I am Here" % myid)
#     models = {(0,): DQN, (1,): DQNDueling}
#     Model = models[(args.dueling,)]
#     model = Model(consts.n_actions[args.game])
#     model = model.cuda()
#     model = torch.nn.DataParallel(model)
#     greedy = args.greedy
#     action_space = consts.n_actions[args.game]
#
#     while not done.is_set():
#
#         params = jobs.get()
#         model.load_state_dict(params)
#         # print("P %d: got a job" % myid)
#         env.reset()
#         # print("P %d: passed reset" % myid)
#         softmax = torch.nn.Softmax()
#         choices = np.arange(action_space, dtype=np.int)
#
#         while not env.t:
#
#             s = Variable(env.s, requires_grad=False)
#
#             a = model(s)
#             if greedy:
#                 a = a.data.cpu().numpy()
#                 a = np.argmax(a)
#             else:
#                 a = softmax(a).data.squeeze(0).cpu().numpy()
#                 # print(a)
#                 a = np.random.choice(choices, p=a)
#             env.step(a)
#
#         # print("P %d: finished with score %d" % (myid, env.score))
#         queue.put(env.score)
#
#     env.close()
#
# def player_va(env, queue, jobs, done, myid):
#
#     model = DVAN_ActionOut(consts.n_actions[args.game])
#     model = model.cuda()
#     model = torch.nn.DataParallel(model)
#     greedy = args.greedy
#     action_space = consts.n_actions[args.game]
#
#     while not done.is_set():
#
#         params = jobs.get()
#         model.load_state_dict(params)
#         env.reset()
#         softmax = torch.nn.Softmax()
#         choices = np.arange(action_space, dtype=np.int)
#
#         while not env.t:
#
#             s = Variable(env.s, requires_grad=False)
#
#             v, a = model(s)
#             if greedy:
#                 a = a.data.cpu().numpy()
#                 a = np.argmax(a)
#             else:
#                 a = softmax(a).data.squeeze(0).cpu().numpy()
#                 a = np.random.choice(choices, p=a)
#             env.step(a)
#
#         queue.put(env.score)
#
#     env.close()
#
#
# def player_ava(env, queue, jobs, done, myid):
#
#     model = DVAN_ActionIn(3)
#     model = model.cuda()
#     model.eval()
#     model = torch.nn.DataParallel(model)
#     greedy = args.greedy
#     action_space = consts.action_space
#
#     actions_matrix = actions.unsqueeze(0)
#     actions_matrix = actions_matrix.repeat(1, 1, 1).float()
#
#     while not done.is_set():
#
#         params = jobs.get()
#         model.load_state_dict(params)
#         env.reset()
#         softmax = torch.nn.Softmax()
#         choices = np.arange(action_space, dtype=np.int)
#
#         while not env.t:
#
#             s = Variable(env.s, requires_grad=False)
#
#             v, a = model(s, actions_matrix)
#
#             if greedy:
#                 a = a.data.cpu().numpy()
#                 a = np.argmax(a)
#             else:
#                 a = softmax(a).data.squeeze(0).cpu().numpy()
#                 a = np.random.choice(choices, p=a)
#             env.step(a)
#
#         queue.put(env.score)
#
#     env.close()

    # def episodic_evaluator_va(self):
    #
    #     self.model.eval()
    #     self.target.eval()
    #     results = {'q_diff': [], 'a_agent': [], 'a_player': []}
    #
    #     for n, sample in tqdm(enumerate(self.episodic_loader)):
    #
    #         is_final = sample['is_final']
    #         final_indicator, final_index = is_final.max(0)
    #         final_indicator = final_indicator.numpy()[0]
    #         final_index = final_index.numpy()[0]+1 if final_indicator else self.batch
    #
    #         s = Variable(sample['s'][:final_index].cuda(), requires_grad=False)
    #         a = Variable(sample['a'][:final_index].cuda().unsqueeze(1), requires_grad=False)
    #         f = sample['f'][:final_index].float().cuda()
    #         base = sample['base'][:final_index].float()
    #         r = Variable(sample['r'][:final_index].float().cuda(), requires_grad=False)
    #
    #         value, advantage = self.model(s)
    #
    #         value = value.squeeze(1)
    #
    #         if self.myopic:
    #             v_diff = r.data.cpu() - value.data.cpu()
    #         else:
    #             v_diff = (f.cpu() - base) - value.data.cpu()
    #
    #         a_player = a.squeeze(1).data.cpu()
    #         a_agent = advantage.data.cpu()
    #
    #         # add results
    #         results['q_diff'].append(v_diff)
    #         results['a_agent'].append(a_agent)
    #         results['a_player'].append(a_player)
    #
    #         if final_indicator:
    #             results['q_diff'] = torch.cat(results['q_diff'])
    #             results['a_agent'] = torch.cat(results['a_agent'])
    #             results['a_player'] = torch.cat(results['a_player'])
    #             yield results
    #             self.model.eval()
    #             self.target.eval()
    #             results = {key: [] for key in results}
    #
    # def episodic_evaluator_ava(self):
    #
    #     self.model.eval()
    #     self.target.eval()
    #     results = {'q_diff': [], 'a_agent': [], 'a_player': []}
    #
    #     for n, sample in tqdm(enumerate(self.episodic_loader)):
    #
    #         is_final = sample['is_final']
    #         final_indicator, final_index = is_final.max(0)
    #         final_indicator = final_indicator.numpy()[0]
    #         final_index = final_index.numpy()[0]+1 if final_indicator else self.batch
    #
    #         s = Variable(sample['s'][:final_index].cuda(), requires_grad=False)
    #         a = Variable(sample['a'][:final_index].cuda().unsqueeze(1), requires_grad=False)
    #         f = sample['f'][:final_index].float().cuda()
    #         base = sample['base'][:final_index].float()
    #         r = Variable(sample['r'][:final_index].float().cuda(), requires_grad=False)
    #         a_index = Variable(sample['a_index'][:final_index].cuda().unsqueeze(1), requires_grad=False)
    #
    #         value, advantage = self.model(s)
    #         value = value.squeeze(1)
    #
    #         value, advantage = self.model(s, self.actions_matrix)
    #
    #         if self.myopic:
    #             v_diff = r.data.cpu() - value.data.cpu()
    #         else:
    #             v_diff = (f.cpu() - base) - value.data.cpu()
    #
    #         a_player = a.squeeze(1).data.cpu()
    #         a_agent = advantage.data.cpu()
    #
    #         # add results
    #         results['q_diff'].append(v_diff)
    #         results['a_agent'].append(a_agent)
    #         results['a_player'].append(a_player)
    #
    #         if final_indicator:
    #             results['q_diff'] = torch.cat(results['q_diff'])
    #             results['a_agent'] = torch.cat(results['a_agent'])
    #             results['a_player'] = torch.cat(results['a_player'])
    #             yield results
    #             self.model.eval()
    #             self.target.eval()
    #             results = {key: [] for key in results}
Example #10
class BehavioralHotAgent(Agent):
    def __init__(self, load_dataset=True):

        super(BehavioralHotAgent, self).__init__()

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta,
                                                     self.data)
            self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta,
                                                    self.data)
            self.full_dataset = DemonstrationMemory("full", self.meta,
                                                    self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                           train=True)
            self.val_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                         train=False)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset,
                                                          train=False)
            self.episodic_sampler = SequentialDemonstrationSampler(
                self.full_dataset)

            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_sampler=self.train_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_sampler=self.test_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.val_loader = torch.utils.data.DataLoader(
                self.val_dataset,
                batch_sampler=self.val_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)

            self.episodic_loader = torch.utils.data.DataLoader(
                self.full_dataset,
                sampler=self.episodic_sampler,
                batch_size=self.batch,
                num_workers=args.cpu_workers)

        if self.l1_loss:
            self.loss_fn_value = torch.nn.L1Loss(size_average=True)
            self.loss_fn_q = torch.nn.L1Loss(size_average=True)
        else:
            self.loss_fn_value = torch.nn.MSELoss(size_average=True)
            self.loss_fn_q = torch.nn.MSELoss(size_average=True)

        self.loss_fn_r = torch.nn.MSELoss(size_average=True)
        self.loss_fn_p = torch.nn.L1Loss(size_average=True)

        if self.weight_by_expert:
            self.loss_fn_beta = torch.nn.CrossEntropyLoss(reduce=False)
        else:
            self.loss_fn_beta = torch.nn.CrossEntropyLoss(reduce=True)

        # alpha weighted sum

        self.alpha_v = 1  # 1 / 0.02
        self.alpha_b = 1  # 1 / 0.7

        self.alpha_r = 1  # 1 / 0.7
        self.alpha_p = 0  # 1 / 0.7
        self.alpha_q = 1

        self.model = BehavioralHotNet()
        self.model.cuda()

        # configure learning

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER
        self.optimizer_v = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr)
        self.scheduler_v = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_v, self.decay)

        self.optimizer_beta = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_beta)
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta, self.decay)

        self.optimizer_q = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_q)
        self.scheduler_q = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_q, self.decay)

        self.optimizer_r = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_r)
        self.scheduler_r = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_r, self.decay)

        self.optimizer_p = BehavioralHotAgent.set_optimizer(
            self.model.parameters(), args.lr_p)
        self.scheduler_p = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_p, self.decay)

        self.episodic_evaluator = self.dummy_episodic_evaluator

        actions = torch.FloatTensor(consts.hotvec_matrix) / (3**(0.5))
        actions = Variable(actions, requires_grad=False).cuda()

        self.actions_matrix = actions.unsqueeze(0)
        # self.reverse_excitation_index = consts.hotvec_inv

    @staticmethod
    def individual_loss_fn_l2(argument):
        return abs(argument.data.cpu().numpy())**2

    @staticmethod
    def individual_loss_fn_l1(argument):
        return abs(argument.data.cpu().numpy())

    def save_checkpoint(self, path, aux=None):

        cpu_state = self.model.state_dict()
        for k in cpu_state:
            cpu_state[k] = cpu_state[k].cpu()

        state = {
            'state_dict': self.model.state_dict(),
            'state_dict_cpu': cpu_state,
            'optimizer_v_dict': self.optimizer_v.state_dict(),
            'optimizer_beta_dict': self.optimizer_beta.state_dict(),
            'optimizer_p_dict': self.optimizer_p.state_dict(),
            'optimizer_q_dict': self.optimizer_q.state_dict(),
            'optimizer_r_dict': self.optimizer_r.state_dict(),
            'aux': aux
        }

        torch.save(state, path)

    def load_checkpoint(self, path):

        if self.cuda:
            state = torch.load(path)
            self.model.load_state_dict(state['state_dict'])
        else:
            state = torch.load(path,
                               map_location=lambda storage, location: storage)
            self.model.load_state_dict(state['state_dict_cpu'])
        self.optimizer_v.load_state_dict(state['optimizer_v_dict'])
        self.optimizer_beta.load_state_dict(state['optimizer_beta_dict'])
        self.optimizer_p.load_state_dict(state['optimizer_p_dict'])
        self.optimizer_q.load_state_dict(state['optimizer_q_dict'])
        self.optimizer_r.load_state_dict(state['optimizer_r_dict'])

        return state['aux']

    def resume(self, model_path):

        aux = self.load_checkpoint(model_path)
        # self.update_target()
        return aux
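    # Illustrative save / resume round trip (added comment, not part of the
    # original source); the path below is a placeholder:
    #
    #     agent = BehavioralHotAgent()
    #     agent.save_checkpoint('/tmp/behavioral_hot.pkl', aux={'epoch': 0})
    #     ...
    #     aux = agent.resume('/tmp/behavioral_hot.pkl')  # restores model + optimizers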

    def update_target(self):
        self.target.load_state_dict(self.model.state_dict())

    def dummy_episodic_evaluator(self):
        while True:
            yield {
                'q_diff': torch.zeros(100),
                'a_agent': torch.zeros(100, self.action_space),
                'a_player': torch.zeros(100).long()
            }

    def _episodic_evaluator(self):
        pass

    def learn(self, n_interval, n_tot):

        self.model.train()
        # self.target.eval()
        results = {
            'n': [],
            'loss_v': [],
            'loss_b': [],
            'loss_q': [],
            'loss_p': [],
            'loss_r': []
        }

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda(), requires_grad=False)
            a_tag = Variable(sample['a_tag'].float().cuda(),
                             requires_grad=False)
            r = Variable(sample['r'].cuda().unsqueeze(1), requires_grad=False)
            t = Variable(sample['t'].cuda().unsqueeze(1), requires_grad=False)
            k = Variable(sample['k'].cuda().unsqueeze(1), requires_grad=False)
            a_index = Variable(sample['a_index'].cuda(), requires_grad=False)
            f = Variable(sample['f'].cuda().unsqueeze(1), requires_grad=False)
            score = Variable(sample['score'].cuda(non_blocking=True).unsqueeze(1),
                             requires_grad=False)
            w = Variable(sample['w'].cuda(), requires_grad=False)

            indexes = sample['i']

            value, q, beta, reward, p, phi = self.model(s, a)

            _, _, _, _, _, phi_tag = self.model(s_tag, a_tag)

            loss_v = self.alpha_v * self.loss_fn_value(
                value, f + self.final_score_reward * score)
            loss_q = self.alpha_q * self.loss_fn_q(
                q, f + self.final_score_reward * score)

            if self.weight_by_expert:
                loss_b = self.alpha_b * (self.loss_fn_beta(beta, a_index) *
                                         w).sum() / self.batch
            else:
                loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)

            loss_r = self.alpha_r * self.loss_fn_r(reward, r)

            phi_tag = phi_tag.detach()
            loss_p = self.alpha_p * self.loss_fn_p(p, phi_tag)

            if self.alpha_v:
                self.optimizer_v.zero_grad()
                loss_v.backward(retain_graph=True)
                self.optimizer_v.step()

            if self.alpha_q:
                self.optimizer_q.zero_grad()
                loss_q.backward(retain_graph=True)
                self.optimizer_q.step()

            if self.alpha_b:
                self.optimizer_beta.zero_grad()
                loss_b.backward(retain_graph=True)
                self.optimizer_beta.step()

            if self.alpha_r:
                self.optimizer_r.zero_grad()
                loss_r.backward(retain_graph=True)
                self.optimizer_r.step()

            if self.alpha_p:
                self.optimizer_p.zero_grad()
                loss_p.backward()
                self.optimizer_p.step()

                # param = self.model.fc_z_p.bias.data.cpu().numpy()
                # if np.isnan(param.max()):
                #     print("XXX")
                # paramgrad = self.model.fc_z_p.bias.grad.data.cpu().numpy()
                # param_l = loss_q.data.cpu().numpy()[0]
                # print("max: %g | min: %g | max_grad: %g | min_grad: %g | loss: %g " % (param.max(), param.min(), paramgrad.max(), paramgrad. min(), param_l))

            # add results
            results['loss_q'].append(loss_q.data.cpu().numpy()[0])
            results['loss_v'].append(loss_v.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_r'].append(loss_r.data.cpu().numpy()[0])
            results['loss_p'].append(loss_p.data.cpu().numpy()[0])
            results['n'].append(n)

            if not n % self.update_target_interval:
                # self.update_target()
                self.scheduler_v.step()
                self.scheduler_beta.step()
                self.scheduler_q.step()
                self.scheduler_r.step()
                self.scheduler_p.step()

            # if an index is sampled more than once during an update_memory_interval
            # period, only the last occurrence affects its replay priority
            if not (
                    n + 1
            ) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            # update a global n_step parameter
            if not (n + 1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step(n + 1)

            if not (n + 1) % n_interval:
                yield results
                self.model.train()
                # self.target.eval()
                results = {key: [] for key in results}
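    # Note on the training step above (added comment, not in the original
    # source): the five heads (value, q, beta, reward, prediction) share one
    # trunk but are stepped with separate optimizers and learning rates, so
    # every backward() except the last must keep the shared graph alive:
    #
    #     loss_v.backward(retain_graph=True)   # graph reused by later heads
    #     ...
    #     loss_p.backward()                    # last head may free the graph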

    def test(self, n_interval, n_tot):

        self.model.eval()
        # self.target.eval()

        results = {
            'n': [],
            'loss_v': [],
            'loss_b': [],
            'loss_q': [],
            'loss_p': [],
            'loss_r': [],
            'act_diff': [],
            'a_agent': [],
            'a_player': []
        }

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(non_blocking=True),
                         requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(non_blocking=True),
                             requires_grad=False)
            a = Variable(sample['a'].cuda(non_blocking=True).unsqueeze(1),
                         requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(non_blocking=True).unsqueeze(1),
                             requires_grad=False)
            r = Variable(sample['r'].cuda(non_blocking=True).unsqueeze(1),
                         requires_grad=False)
            t = Variable(sample['t'].cuda(non_blocking=True).unsqueeze(1),
                         requires_grad=False)
            k = Variable(sample['k'].cuda(non_blocking=True).unsqueeze(1),
                         requires_grad=False)
            a_index = Variable(sample['a_index'].cuda(non_blocking=True),
                               requires_grad=False)
            score = Variable(sample['score'].cuda(non_blocking=True).unsqueeze(1),
                             requires_grad=False)
            w = Variable(sample['w'].cuda(), requires_grad=False)

            f = Variable(sample['f'].cuda(non_blocking=True).unsqueeze(1),
                         requires_grad=False)
            indexes = sample['i']

            value, q, beta, reward, p, phi = self.model(s, a)
            _, _, _, _, _, phi_tag = self.model(s_tag, a_tag)

            q = q.squeeze(1)
            reward = reward.squeeze(1)

            loss_v = self.alpha_v * self.loss_fn_value(
                value, f + self.final_score_reward * score)
            loss_q = self.alpha_q * self.loss_fn_q(
                q, f + self.final_score_reward * score)

            if self.weight_by_expert:
                loss_b = self.alpha_b * (self.loss_fn_beta(beta, a_index) *
                                         w).sum() / self.batch
            else:
                loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)

            loss_r = self.alpha_r * self.loss_fn_r(reward, r)

            phi_tag = Variable(phi_tag.data, requires_grad=False)
            loss_p = self.alpha_p * self.loss_fn_p(p, phi_tag)

            # collect actions statistics
            a_index_np = a_index.data.cpu().numpy()

            _, beta_index = beta.data.cpu().max(1)
            beta_index = beta_index.numpy()
            act_diff = (a_index_np != beta_index).astype(int)

            # add results
            results['act_diff'].append(act_diff)
            results['a_agent'].append(beta_index)
            results['a_player'].append(a_index_np)
            results['loss_q'].append(loss_q.data.cpu().numpy()[0])
            results['loss_v'].append(loss_v.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_r'].append(loss_r.data.cpu().numpy()[0])
            results['loss_p'].append(loss_p.data.cpu().numpy()[0])
            results['n'].append(n)

            if not (n + 1) % n_interval:
                results['s'] = s.data.cpu()
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_agent'] = np.concatenate(results['a_agent'])
                results['a_player'] = np.concatenate(results['a_player'])
                yield results
                self.model.eval()
                # self.target.eval()
                results = {key: [] for key in results}

    def play_stochastic(self, n_tot):
        raise NotImplementedError
        # self.model.eval()
        # env = Env()
        # render = args.render
        #
        # n_human = 60
        # humans_trajectories = iter(self.data)
        #
        # for i in range(n_tot):
        #
        #     env.reset()
        #
        #     observation = next(humans_trajectories)
        #     print("Observation %s" % observation)
        #     trajectory = self.data[observation]
        #
        #     j = 0
        #
        #     while not env.t:
        #
        #         if j < n_human:
        #             a = trajectory[j, self.meta['action']]
        #
        #         else:
        #
        #             if self.cuda:
        #                 s = Variable(env.s.cuda(), requires_grad=False)
        #             else:
        #                 s = Variable(env.s, requires_grad=False)
        #             _, q, _, _, _, _ = self.model(s, self.actions_matrix)
        #
        #             q = q.squeeze(2)
        #
        #             q = q.data.cpu().numpy()
        #             a = np.argmax(q)
        #
        #         env.step(a)
        #
        #         j += 1
        #
        #     yield {'o': env.s.cpu().numpy(),
        #            'score': env.score}

    def play_episode(self, n_tot):

        self.model.eval()
        env = Env()

        n_human = 120
        humans_trajectories = iter(self.data)
        softmax = torch.nn.Softmax()

        mask = torch.FloatTensor(consts.actions_mask[args.game])
        mask = Variable(mask.cuda(), requires_grad=False)
        # self.actions_matrix = torch.FloatTensor([[0, 0, 0], [1, 0, 0],[0, 1, 0], [0, 0, 1]])

        for i in range(n_tot):

            env.reset()
            observation = next(humans_trajectories)
            trajectory = self.data[observation]
            choices = np.arange(self.global_action_space, dtype=int)

            j = 0

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)
                v, q, beta, _, _, phi = self.model(s, self.actions_matrix)
                beta = beta.squeeze(0)
                q = q.squeeze(2)
                q = q.squeeze(0)

                q = q * mask
                # beta[0] = 0
                temp = 0.1
                if True:  # self.imitation:

                    # consider only 3 most frequent actions
                    beta_np = beta.data.cpu().numpy()
                    indices = np.argsort(beta_np)

                    # maskb = Variable(torch.FloatTensor([i in indices[14:18] for i in range(18)]), requires_grad=False)
                    # maskb = Variable(torch.FloatTensor([1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), requires_grad=False)
                    # maskb = maskb.cuda()
                    # pi = maskb * (q / q.max())

                    maskb = Variable(torch.FloatTensor(
                        [1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                         0]),
                                     requires_grad=False)
                    maskb = maskb.cuda()
                    pi = maskb * (beta / beta.max())
                    # pi = maskb * (q / q.max())
                    self.greedy = False

                    # if j%2:
                    #     pi = maskb * (q / q.max())
                    #     self.greedy = True
                    # else:
                    #     self.greedy = False
                    #     pi = maskb * (beta / beta.max())
                    # pi = (beta > 3).float() * (q / q.max())

                    # pi = beta  # (beta > 5).float() * (q / q.max())
                    # pi[0] = 0
                    # beta_prob = softmax(pi)
                    beta_prob = pi
                else:
                    pi = q / q.max()  # q.max() is the temperature
                    beta_prob = q

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:
                    # a = np.random.choice(choices)
                    if self.greedy:
                        a = pi.data.cpu().numpy()
                        a = np.argmax(a)
                    else:
                        a = softmax(pi / temp).data.cpu().numpy()
                        a = np.random.choice(choices, p=a)

                env.step(a)

                # x = phi.squeeze(0).data.cpu().numpy()
                # print(np.mean(abs(x)))
                # yield v, q, beta, r, p, s
                yield {
                    'o': env.s.cpu().numpy(),
                    'v': v.data.cpu().numpy(),
                    's': phi.data.cpu().numpy(),
                    'score': env.score,
                    'beta': beta_prob.data.cpu().numpy(),
                    'phi': phi.squeeze(0).data.cpu().numpy()
                }

                j += 1

        return
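    # Action selection in play_episode above (added comment, not in the
    # original source): when not greedy, the masked scores are turned into a
    # sampling distribution with a softmax at temperature temp = 0.1; smaller
    # temperatures concentrate the probability mass on the highest-scoring
    # allowed actions:
    #
    #     probs = softmax(pi / temp).data.cpu().numpy()
    #     a = np.random.choice(choices, p=probs)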
Example #11
    def __init__(self, load_dataset=True):

        super(BehavioralAgent, self).__init__()

        self.actions_transform = np.array(consts.action2activation[args.game])

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta,
                                                     self.data)
            self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta,
                                                    self.data)
            self.full_dataset = DemonstrationMemory("full", self.meta,
                                                    self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                           train=True)
            self.val_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                         train=False)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset,
                                                          train=False)
            self.episodic_sampler = SequentialDemonstrationSampler(
                self.full_dataset)

            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_sampler=self.train_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_sampler=self.test_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.val_loader = torch.utils.data.DataLoader(
                self.val_dataset,
                batch_sampler=self.val_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)

            self.episodic_loader = torch.utils.data.DataLoader(
                self.full_dataset,
                sampler=self.episodic_sampler,
                batch_size=self.batch,
                num_workers=args.cpu_workers)

        if self.l1_loss:
            self.loss_fn_value = torch.nn.L1Loss(size_average=True)
            self.individual_loss_fn_value = self.individual_loss_fn_l1
        else:
            self.loss_fn_value = torch.nn.MSELoss(size_average=True)
            self.individual_loss_fn_value = self.individual_loss_fn_l2

        self.loss_fn_r = torch.nn.MSELoss(size_average=True)
        self.individual_loss_fn_r = self.individual_loss_fn_l2

        self.loss_fn_q = torch.nn.L1Loss(size_average=True)
        self.individual_loss_fn_q = self.individual_loss_fn_l1

        self.loss_fn_p = torch.nn.L1Loss(size_average=True)
        self.individual_loss_fn_p = self.individual_loss_fn_l1

        # self.target_single = BehavioralNet(self.global_action_space)

        # alpha weighted sum

        self.alpha_v = 1  # 1 / 0.02
        self.alpha_b = 1  # 1 / 0.7

        self.alpha_r = 1  # 1 / 0.7
        self.alpha_p = 1  # 1 / 0.7
        self.alpha_q = 1

        if args.deterministic:  # 1 / 0.02
            self.loss_fn_beta = torch.nn.L1Loss(size_average=True)
            self.learn = self.learn_deterministic
            self.test = self.test_deterministic
            self.play = self.play_deterministic
            self.play_episode = self.play_episode_deterministic
            self.model_single = BehavioralNetDeterministic(
                self.global_action_space)

        else:
            self.loss_fn_beta = torch.nn.CrossEntropyLoss()
            self.learn = self.learn_stochastic
            self.test = self.test_stochastic
            self.play = self.play_stochastic
            self.play_episode = self.play_episode_stochastic
            self.model_single = BehavioralNet(self.global_action_space)

        # configure learning

        if self.cuda:
            self.model_single = self.model_single.cuda()
            # self.model = torch.nn.DataParallel(self.model_single)
            self.model = self.model_single
            # self.target_single = self.target_single.cuda()
            # self.target = torch.nn.DataParallel(self.target_single)
        else:
            self.model = self.model_single
            # self.target = self.target_single

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER
        self.optimizer_v = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr)
        self.scheduler_v = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_v, self.decay)

        self.optimizer_beta = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_beta)
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta, self.decay)

        self.optimizer_q = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_q)
        self.scheduler_q = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_q, self.decay)

        self.optimizer_r = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_r)
        self.scheduler_r = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_r, self.decay)

        self.optimizer_p = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_p)
        self.scheduler_p = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_p, self.decay)

        self.episodic_evaluator = self.dummy_episodic_evaluator

        # build the action matrix
        # excitation = torch.LongTensor(consts.game_excitation_map[args.game])
        excitation = torch.LongTensor(consts.excitation_map)
        mask = torch.LongTensor(consts.excitation_mask[args.game])
        mask_dup = mask.unsqueeze(0).repeat(consts.action_space, 1)
        actions = Variable(mask_dup * excitation, requires_grad=False)
        actions = Variable(excitation, requires_grad=False)
        if args.cuda:
            actions = actions.cuda()

        self.actions_matrix = actions.unsqueeze(0)
        self.actions_matrix = self.actions_matrix.repeat(1, 1, 1).float()

        self.go_to_max = np.inf  # 4096 * 8 * 4

        self.reverse_excitation_index = consts.reverse_excitation_index
Example #12
class BehavioralAgent(Agent):
    def __init__(self, load_dataset=True):

        super(BehavioralAgent, self).__init__()

        self.actions_transform = np.array(consts.action2activation[args.game])

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta,
                                                     self.data)
            self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta,
                                                    self.data)
            self.full_dataset = DemonstrationMemory("full", self.meta,
                                                    self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                           train=True)
            self.val_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                         train=False)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset,
                                                          train=False)
            self.episodic_sampler = SequentialDemonstrationSampler(
                self.full_dataset)

            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_sampler=self.train_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_sampler=self.test_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.val_loader = torch.utils.data.DataLoader(
                self.val_dataset,
                batch_sampler=self.val_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)

            self.episodic_loader = torch.utils.data.DataLoader(
                self.full_dataset,
                sampler=self.episodic_sampler,
                batch_size=self.batch,
                num_workers=args.cpu_workers)

        if self.l1_loss:
            self.loss_fn_value = torch.nn.L1Loss(size_average=True)
            self.individual_loss_fn_value = self.individual_loss_fn_l1
        else:
            self.loss_fn_value = torch.nn.MSELoss(size_average=True)
            self.individual_loss_fn_value = self.individual_loss_fn_l2

        self.loss_fn_r = torch.nn.MSELoss(size_average=True)
        self.individual_loss_fn_r = self.individual_loss_fn_l2

        self.loss_fn_q = torch.nn.L1Loss(size_average=True)
        self.individual_loss_fn_q = self.individual_loss_fn_l1

        self.loss_fn_p = torch.nn.L1Loss(size_average=True)
        self.individual_loss_fn_p = self.individual_loss_fn_l1

        # self.target_single = BehavioralNet(self.global_action_space)

        # alpha weighted sum

        self.alpha_v = 1  # 1 / 0.02
        self.alpha_b = 1  # 1 / 0.7

        self.alpha_r = 1  # 1 / 0.7
        self.alpha_p = 1  # 1 / 0.7
        self.alpha_q = 1

        if args.deterministic:  # 1 / 0.02
            self.loss_fn_beta = torch.nn.L1Loss(size_average=True)
            self.learn = self.learn_deterministic
            self.test = self.test_deterministic
            self.play = self.play_deterministic
            self.play_episode = self.play_episode_deterministic
            self.model_single = BehavioralNetDeterministic(
                self.global_action_space)

        else:
            self.loss_fn_beta = torch.nn.CrossEntropyLoss()
            self.learn = self.learn_stochastic
            self.test = self.test_stochastic
            self.play = self.play_stochastic
            self.play_episode = self.play_episode_stochastic
            self.model_single = BehavioralNet(self.global_action_space)

        # configure learning

        if self.cuda:
            self.model_single = self.model_single.cuda()
            # self.model = torch.nn.DataParallel(self.model_single)
            self.model = self.model_single
            # self.target_single = self.target_single.cuda()
            # self.target = torch.nn.DataParallel(self.target_single)
        else:
            self.model = self.model_single
            # self.target = self.target_single

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER
        self.optimizer_v = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr)
        self.scheduler_v = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_v, self.decay)

        self.optimizer_beta = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_beta)
        self.scheduler_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta, self.decay)

        self.optimizer_q = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_q)
        self.scheduler_q = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_q, self.decay)

        self.optimizer_r = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_r)
        self.scheduler_r = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_r, self.decay)

        self.optimizer_p = BehavioralAgent.set_optimizer(
            self.model.parameters(), args.lr_p)
        self.scheduler_p = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_p, self.decay)

        self.episodic_evaluator = self.dummy_episodic_evaluator

        # build the action matrix
        # excitation = torch.LongTensor(consts.game_excitation_map[args.game])
        excitation = torch.LongTensor(consts.excitation_map)
        mask = torch.LongTensor(consts.excitation_mask[args.game])
        mask_dup = mask.unsqueeze(0).repeat(consts.action_space, 1)
        actions = Variable(mask_dup * excitation, requires_grad=False)
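        # note: the unmasked assignment on the next line overrides the masked
        # variant above, so the excitation mask is effectively unused here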
        actions = Variable(excitation, requires_grad=False)
        if args.cuda:
            actions = actions.cuda()

        self.actions_matrix = actions.unsqueeze(0)
        self.actions_matrix = self.actions_matrix.repeat(1, 1, 1).float()

        self.go_to_max = np.inf  # 4096 * 8 * 4

        self.reverse_excitation_index = consts.reverse_excitation_index

    @staticmethod
    def individual_loss_fn_l2(argument):
        return abs(argument.data.cpu().numpy())**2

    @staticmethod
    def individual_loss_fn_l1(argument):
        return abs(argument.data.cpu().numpy())
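
    # Quick numeric check of the two helpers above (hypothetical residuals):
    #   individual_loss_fn_l1 on residuals [-2.0, 0.5] -> [2.0, 0.5]
    #   individual_loss_fn_l2 on residuals [-2.0, 0.5] -> [4.0, 0.25]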

    def save_checkpoint(self, path, aux=None):

        cpu_state = self.model.state_dict()
        for k in cpu_state:
            cpu_state[k] = cpu_state[k].cpu()

        state = {
            'state_dict': self.model.state_dict(),
            'state_dict_cpu': cpu_state,
            'optimizer_v_dict': self.optimizer_v.state_dict(),
            'optimizer_beta_dict': self.optimizer_beta.state_dict(),
            'optimizer_p_dict': self.optimizer_p.state_dict(),
            'optimizer_q_dict': self.optimizer_q.state_dict(),
            'optimizer_r_dict': self.optimizer_r.state_dict(),
            'aux': aux
        }

        torch.save(state, path)

    def load_checkpoint(self, path):

        if self.cuda:
            state = torch.load(path)
            self.model.load_state_dict(state['state_dict'])
        else:
            state = torch.load(path,
                               map_location=lambda storage, location: storage)
            self.model.load_state_dict(state['state_dict_cpu'])
        self.optimizer_v.load_state_dict(state['optimizer_v_dict'])
        self.optimizer_beta.load_state_dict(state['optimizer_beta_dict'])
        self.optimizer_p.load_state_dict(state['optimizer_p_dict'])
        self.optimizer_q.load_state_dict(state['optimizer_q_dict'])
        self.optimizer_r.load_state_dict(state['optimizer_r_dict'])

        return state['aux']
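
    # Checkpoint round trip sketch (hypothetical path; aux may hold any
    # picklable bookkeeping such as the last training iteration):
    #
    #   agent.save_checkpoint('/tmp/behavioral.pt', aux={'n': 1000})
    #   aux = agent.load_checkpoint('/tmp/behavioral.pt')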

    def resume(self, model_path):

        aux = self.load_checkpoint(model_path)
        # self.update_target()
        return aux

    def update_target(self):
        self.target.load_state_dict(self.model.state_dict())

    def dummy_episodic_evaluator(self):
        while True:
            yield {
                'q_diff': torch.zeros(100),
                'a_agent': torch.zeros(100, self.action_space),
                'a_player': torch.zeros(100).long()
            }

    def _episodic_evaluator(self):
        pass

    def learn_stochastic(self, n_interval, n_tot):

        self.model.train()
        # self.target.eval()
        results = {
            'n': [],
            'loss_v': [],
            'loss_b': [],
            'loss_q': [],
            'loss_p': [],
            'loss_r': []
        }

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(), requires_grad=False)
            a = Variable(sample['a'].float().cuda(), requires_grad=False)
            a_tag = Variable(sample['a_tag'].float().cuda(),
                             requires_grad=False)
            r = Variable(sample['r'].float().cuda().unsqueeze(1),
                         requires_grad=False)
            t = Variable(sample['t'].float().cuda().unsqueeze(1),
                         requires_grad=False)
            k = Variable(sample['k'].float().cuda().unsqueeze(1),
                         requires_grad=False)
            a_index = Variable(sample['a_index'].cuda(), requires_grad=False)
            f = Variable(sample['f'].float().cuda().unsqueeze(1),
                         requires_grad=False)
            indexes = sample['i']

            value, q, beta, reward, p, phi = self.model(s, a)

            _, _, _, _, _, phi_tag = self.model(s_tag, a_tag)

            # m = (((f - value) > 0).float() + 1) if n > self.go_to_max else Variable(torch.ones(f.data.shape).cuda())
            # m = m.detach()

            loss_v = self.alpha_v * self.loss_fn_value(value, f)
            loss_q = self.alpha_q * self.loss_fn_q(q, f)

            loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)

            loss_r = self.alpha_r * self.loss_fn_r(reward, r)

            phi_tag = phi_tag.detach()
            loss_p = self.alpha_p * self.loss_fn_p(p, phi_tag)

            #
            # # calculate the td error for the priority replay
            # if self.prioritized_replay:
            #     argument = r + (self.discount ** k) * (max_q_target * (1 - t)) - q_a
            #     individual_loss = LfdAgent.individual_loss_fn(argument)
            #     self.train_dataset.update_td_error(indexes.numpy(), individual_loss)
            # self.model.module.conv1.weight
            if self.alpha_v:
                self.optimizer_v.zero_grad()
                loss_v.backward(retain_graph=True)
                self.optimizer_v.step()

            if self.alpha_q:
                self.optimizer_q.zero_grad()
                loss_q.backward(retain_graph=True)
                self.optimizer_q.step()

            if self.alpha_b:
                self.optimizer_beta.zero_grad()
                loss_b.backward(retain_graph=True)
                self.optimizer_beta.step()

            if self.alpha_r:
                self.optimizer_r.zero_grad()
                loss_r.backward(retain_graph=True)
                self.optimizer_r.step()

            if self.alpha_p:
                self.optimizer_p.zero_grad()
                loss_p.backward()
                self.optimizer_p.step()

            # add results
            results['loss_q'].append(loss_q.data.cpu().numpy()[0])
            results['loss_v'].append(loss_v.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_r'].append(loss_r.data.cpu().numpy()[0])
            results['loss_p'].append(loss_p.data.cpu().numpy()[0])
            results['n'].append(n)

            if not n % self.update_target_interval:
                # self.update_target()
                self.scheduler_v.step()
                self.scheduler_beta.step()
                self.scheduler_q.step()
                self.scheduler_r.step()
                self.scheduler_p.step()

            # if an index is rolled more than once during the update_memory_interval period, only the last occurrence affects the probability update
            if not (
                    n + 1
            ) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            # update a global n_step parameter
            if not (n + 1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step(n + 1)

            if not (n + 1) % n_interval:
                yield results
                self.model.train()
                # self.target.eval()
                results = {key: [] for key in results}
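
    # learn_stochastic is a generator: it yields a results dict every
    # n_interval minibatches so the caller can log the running losses and
    # checkpoint without rebuilding the data pipeline. Minimal consumption
    # sketch (hypothetical driver code):
    #
    #   for res in agent.learn_stochastic(n_interval=100, n_tot=10 ** 5):
    #       print(np.mean(res['loss_b']), np.mean(res['loss_v']))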

    def test_stochastic(self, n_interval, n_tot):

        self.model.eval()
        # self.target.eval()

        results = {
            'n': [],
            'loss_v': [],
            'loss_b': [],
            'loss_q': [],
            'loss_p': [],
            'loss_r': [],
            'act_diff': [],
            'a_agent': [],
            'a_player': []
        }

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(async=True), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(async=True),
                             requires_grad=False)
            a = Variable(sample['a'].cuda(async=True).float().unsqueeze(1),
                         requires_grad=False)
            a_tag = Variable(
                sample['a_tag'].cuda(async=True).float().unsqueeze(1),
                requires_grad=False)
            r = Variable(sample['r'].cuda(async=True).float().unsqueeze(1),
                         requires_grad=False)
            t = Variable(sample['t'].cuda(async=True).float().unsqueeze(1),
                         requires_grad=False)
            k = Variable(sample['k'].cuda(async=True).float().unsqueeze(1),
                         requires_grad=False)
            a_index = Variable(sample['a_index'].cuda(async=True),
                               requires_grad=False)
            f = Variable(sample['f'].cuda(async=True).float().unsqueeze(1),
                         requires_grad=False)
            indexes = sample['i']

            value, q, beta, reward, p, phi = self.model(s, a)
            _, _, _, _, _, phi_tag = self.model(s_tag, a_tag)

            q = q.squeeze(1)
            # m = (((f - value) > 0).float() + 1) if n > self.go_to_max else Variable(torch.ones(f.data.shape).cuda())
            # m = m.detach()

            loss_v = self.alpha_v * self.loss_fn_value(value, f)
            loss_q = self.alpha_q * self.loss_fn_q(q, f)

            # zerovar = Variable(torch.zeros(f.data.shape).cuda(), requires_grad=False)
            # if n > self.go_to_max:
            #     loss_v = self.alpha_v * self.loss_fn_value(F.relu(f - value), zerovar)
            #     loss_q = self.alpha_q * self.loss_fn_q(F.relu(f - q), zerovar)
            # else:
            #     loss_v = self.alpha_v * self.loss_fn_value(value - f, zerovar)
            #     loss_q = self.alpha_q * self.loss_fn_q(q - f, zerovar)
            loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)

            loss_r = self.alpha_r * self.loss_fn_r(reward, r)

            phi_tag = Variable(phi_tag.data, requires_grad=False)
            loss_p = self.alpha_p * self.loss_fn_p(p, phi_tag)

            # collect actions statistics
            a_index_np = a_index.data.cpu().numpy()

            _, beta_index = beta.data.cpu().max(1)
            beta_index = beta_index.numpy()
            act_diff = (a_index_np != beta_index).astype(np.int)

            # add results
            results['act_diff'].append(act_diff)
            results['a_agent'].append(beta_index)
            results['a_player'].append(a_index_np)
            results['loss_q'].append(loss_q.data.cpu().numpy()[0])
            results['loss_v'].append(loss_v.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_r'].append(loss_r.data.cpu().numpy()[0])
            results['loss_p'].append(loss_p.data.cpu().numpy()[0])
            results['n'].append(n)

            if not (n + 1) % n_interval:
                results['s'] = s.data.cpu()
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_agent'] = np.concatenate(results['a_agent'])
                results['a_player'] = np.concatenate(results['a_player'])
                yield results
                self.model.eval()
                # self.target.eval()
                results = {key: [] for key in results}

    # def get_action_index(self, a):
    #     m = np.zeros((self.global_action_space, self.skip))
    #     m[a, range(self.skip)] = 1
    #     m = m.sum(1)
    #     a = (1 + np.argmax(m[1:])) * (a.sum() != 0)
    #     # transform a to a valid activation
    #     a = self.actions_transform[a]
    #     return a, a

    def learn_deterministic(self, n_interval, n_tot):

        self.model.train()
        # self.target.eval()
        results = {
            'n': [],
            'loss_v': [],
            'loss_b': [],
            'loss_q': [],
            'loss_p': [],
            'loss_r': []
        }

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(async=True), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(async=True),
                             requires_grad=False)
            a = Variable(sample['a'].cuda(async=True), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(async=True),
                             requires_grad=False)
            r = Variable(sample['r'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            t = Variable(sample['t'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            k = Variable(sample['k'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            a_index = Variable(sample['a_index'].cuda(async=True),
                               requires_grad=False)
            f = Variable(sample['f'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            indexes = sample['i']

            value, q, beta, reward, p, phi = self.model(s, a)
            _, _, _, _, _, phi_tag = self.model(s_tag, a_tag)

            m = (((f - value) > 0).float() +
                 1) if n > self.go_to_max else Variable(
                     torch.ones(f.data.shape).cuda())
            m = m.detach()

            loss_v = self.alpha_v * self.loss_fn_value(value * m, f * m)
            loss_q = self.alpha_q * self.loss_fn_q(q * m, f * m)
            loss_b = self.alpha_b * self.loss_fn_beta(beta * m.repeat(1, 3),
                                                      a * m.repeat(1, 3))

            loss_r = self.alpha_r * self.loss_fn_r(reward, r)

            phi_tag = Variable(phi_tag.data, requires_grad=False)
            loss_p = self.alpha_p * self.loss_fn_p(p, phi_tag)

            if self.alpha_v:
                self.optimizer_v.zero_grad()
                loss_v.backward(retain_graph=True)
                self.optimizer_v.step()

            if self.alpha_q:
                self.optimizer_q.zero_grad()
                loss_q.backward(retain_graph=True)
                self.optimizer_q.step()

            if self.alpha_b:
                self.optimizer_beta.zero_grad()
                loss_b.backward(retain_graph=True)
                self.optimizer_beta.step()

            if self.alpha_r:
                self.optimizer_r.zero_grad()
                loss_r.backward(retain_graph=True)
                self.optimizer_r.step()

            if self.alpha_p:
                self.optimizer_p.zero_grad()
                loss_p.backward()
                self.optimizer_p.step()

            # add results
            results['loss_q'].append(loss_q.data.cpu().numpy()[0])
            results['loss_v'].append(loss_v.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_r'].append(loss_r.data.cpu().numpy()[0])
            results['loss_p'].append(loss_p.data.cpu().numpy()[0])
            results['n'].append(n)

            if not n % self.update_target_interval:
                # self.update_target()
                self.scheduler_v.step()
                self.scheduler_beta.step()
                self.scheduler_q.step()
                self.scheduler_r.step()
                self.scheduler_p.step()

            # if an index is rolled more than once during the update_memory_interval period, only the last occurrence affects the probability update
            if not (
                    n + 1
            ) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            # update a global n_step parameter
            if not (n + 1) % self.update_n_steps_interval:
                self.train_dataset.update_n_step(n + 1)

            if not (n + 1) % n_interval:
                yield results
                self.model.train()
                # self.target.eval()
                results = {key: [] for key in results}

    def test_deterministic(self, n_interval, n_tot):

        self.model.eval()
        # self.target.eval()

        results = {
            'n': [],
            'loss_v': [],
            'loss_b': [],
            'loss_q': [],
            'loss_p': [],
            'loss_r': [],
            'act_diff': [],
            'a_agent': [],
            'a_player': []
        }

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(async=True), requires_grad=False)
            s_tag = Variable(sample['s_tag'].cuda(async=True),
                             requires_grad=False)
            a = Variable(sample['a'].cuda(async=True), requires_grad=False)
            a_tag = Variable(sample['a_tag'].cuda(async=True),
                             requires_grad=False)
            r = Variable(sample['r'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            t = Variable(sample['t'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            k = Variable(sample['k'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            a_index = Variable(sample['a_index'].cuda(async=True),
                               requires_grad=False)
            f = Variable(sample['f'].cuda(async=True).unsqueeze(1),
                         requires_grad=False)
            indexes = sample['i']

            value, q, beta, reward, p, phi = self.model(s, a)
            _, _, _, _, _, phi_tag = self.model(s_tag, a_tag)

            m = (((f - value) > 0).float() +
                 1) if n > self.go_to_max else Variable(
                     torch.ones(f.data.shape).cuda())
            m = m.detach()

            loss_v = self.alpha_v * self.loss_fn_value(value * m, f * m)
            loss_q = self.alpha_q * self.loss_fn_q(q * m, f * m)

            loss_b = self.alpha_b * self.loss_fn_beta(beta * m.repeat(1, 3),
                                                      a * m.repeat(1, 3))

            loss_r = self.alpha_r * self.loss_fn_r(reward, r)

            phi_tag = Variable(phi_tag.data, requires_grad=False)
            loss_p = self.alpha_p * self.loss_fn_p(p, phi_tag)

            # calculate action imitation statistics
            beta_index = (beta.sign().int() *
                          (beta.abs() > 0.5).int()).data.cpu().numpy()
            beta_index[:, 0] = abs(beta_index[:, 0])
            beta_index = np.array(
                [self.reverse_excitation_index[tuple(i)] for i in beta_index])
            a_index_np = a_index.data.cpu().numpy()
            act_diff = (a_index_np != beta_index).astype(np.int)

            # add results
            results['loss_q'].append(loss_q.data.cpu().numpy()[0])
            results['loss_v'].append(loss_v.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_r'].append(loss_r.data.cpu().numpy()[0])
            results['loss_p'].append(loss_p.data.cpu().numpy()[0])
            results['act_diff'].append(act_diff)
            results['a_agent'].append(beta_index)
            results['a_player'].append(a_index_np)

            results['n'].append(n)

            if not (n + 1) % n_interval:
                results['s'] = s.data.cpu()
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_agent'] = np.concatenate(results['a_agent'])
                results['a_player'] = np.concatenate(results['a_player'])
                yield results
                self.model.eval()
                # self.target.eval()
                results = {key: [] for key in results}
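
    # Deterministic action decoding used above (hypothetical values): a raw
    # beta of [0.9, -0.7, 0.2] is mapped component-wise to sign(beta) where
    # |beta| > 0.5 and to 0 otherwise, giving (1, -1, 0); the first component
    # is forced non-negative and the resulting tuple is looked up in
    # consts.reverse_excitation_index to recover a discrete action id.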

    def play_stochastic(self, n_tot):
        raise NotImplementedError
        # self.model.eval()
        # env = Env()
        # render = args.render
        #
        # n_human = 60
        # humans_trajectories = iter(self.data)
        #
        # for i in range(n_tot):
        #
        #     env.reset()
        #
        #     observation = next(humans_trajectories)
        #     print("Observation %s" % observation)
        #     trajectory = self.data[observation]
        #
        #     j = 0
        #
        #     while not env.t:
        #
        #         if j < n_human:
        #             a = trajectory[j, self.meta['action']]
        #
        #         else:
        #
        #             if self.cuda:
        #                 s = Variable(env.s.cuda(), requires_grad=False)
        #             else:
        #                 s = Variable(env.s, requires_grad=False)
        #             _, q, _, _, _, _ = self.model(s, self.actions_matrix)
        #
        #             q = q.squeeze(2)
        #
        #             q = q.data.cpu().numpy()
        #             a = np.argmax(q)
        #
        #         env.step(a)
        #
        #         j += 1
        #
        #     yield {'o': env.s.cpu().numpy(),
        #            'score': env.score}

    def play_episode_stochastic(self, n_tot):

        self.model.eval()
        env = Env()

        n_human = 300
        humans_trajectories = iter(self.data)
        softmax = torch.nn.Softmax()

        # self.actions_matrix = torch.FloatTensor([[0, 0, 0], [1, 0, 0],[0, 1, 0], [0, 0, 1]])

        for i in range(n_tot):

            env.reset()
            observation = next(humans_trajectories)
            trajectory = self.data[observation]
            choices = np.arange(self.global_action_space, dtype=np.int)

            j = 0

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)
                v, q, beta, _, _, phi = self.model(s, self.actions_matrix)
                beta = beta.squeeze(0)
                q = q.squeeze(2)
                q = q.squeeze(0)
                # beta[0] = 0

                if self.imitation:
                    pi = (beta > 5).float() * (q / q.max())
                else:
                    pi = q / q.max()  # q.max() is the temperature

                beta_prob = softmax(pi)

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:
                    # a = np.random.choice(choices)
                    if self.greedy:
                        a = pi.data.cpu().numpy()
                        a = np.argmax(a)
                    else:
                        a = softmax(pi).data.cpu().numpy()
                        a = np.random.choice(choices, p=a)

                env.step(a)

                # x = phi.squeeze(0).data.cpu().numpy()
                # print(np.mean(abs(x)))
                # yield v, q, beta, r, p, s
                yield {
                    'o': env.s.cpu().numpy(),
                    'v': v.data.cpu().numpy(),
                    's': phi.data.cpu().numpy(),
                    'score': env.score,
                    'beta': beta_prob.data.cpu().numpy(),
                    'phi': phi.squeeze(0).data.cpu().numpy()
                }

                j += 1

        return
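
    # In play_episode_stochastic, (beta > 5).float() keeps only actions the
    # behavioral head is confident about, and dividing q by q.max() acts as a
    # temperature before the softmax (see the inline comment above).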

    def play_episode_deterministic(self, n_tot):
        self.model.eval()
        env = Env()

        n_human = 300
        humans_trajectories = iter(self.data)
        reverse_excitation_index = consts.reverse_excitation_index

        for i in range(n_tot):

            env.reset()
            observation = next(humans_trajectories)
            trajectory = self.data[observation]

            j = 0

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)
                v, q, beta, r, p, phi = self.model(s)
                beta = beta.squeeze(0)

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:

                    beta_index = (beta.sign().int() *
                                  (beta.abs() > 0.5).int()).data.cpu().numpy()
                    beta_index[0] = abs(beta_index[0])
                    a = reverse_excitation_index[tuple(beta_index)]

                env.step(a)

                # x = phi.squeeze(0).data.cpu().numpy()
                # print(np.mean(abs(x)))
                # yield v, q, beta, r, p, s
                yield {
                    'o': env.s.cpu().numpy(),
                    'v': v.data.cpu().numpy(),
                    's': phi.data.cpu().numpy(),
                    'score': env.score,
                    'beta': beta.data.cpu().numpy(),
                    'phi': phi.squeeze(0).data.cpu().numpy()
                }

                j += 1

        return

    def play_deterministic(self, n_tot):

        self.model.eval()
        env = Env()
        render = args.render

        n_human = 60
        humans_trajectories = iter(self.data)
        reverse_excitation_index = consts.reverse_excitation_index

        for i in range(n_tot):

            env.reset()

            observation = next(humans_trajectories)
            print("Observation %s" % observation)
            trajectory = self.data[observation]

            j = 0

            while not env.t:

                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:

                    if self.cuda:
                        s = Variable(env.s.cuda(), requires_grad=False)
                    else:
                        s = Variable(env.s, requires_grad=False)
                    _, _, beta, _, _, _ = self.model(s)

                    beta = beta.squeeze(0)
                    beta = (beta.sign().int() * (beta.abs() > 0.5).int()).data
                    if self.cuda:
                        beta = beta.cpu().numpy()
                    else:
                        beta = beta.numpy()
                    beta[0] = abs(beta[0])
                    a = reverse_excitation_index[tuple(beta)]

                env.step(a)

                j += 1

            yield {'o': env.s.cpu().numpy(), 'score': env.score}
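
# End-to-end usage sketch for BehavioralAgent (hypothetical driver code; the
# project's real training loop lives outside this excerpt):
#
#   agent = BehavioralAgent(load_dataset=True)
#   for res in agent.learn(n_interval=100, n_tot=10 ** 5):
#       print('beta loss', float(np.mean(res['loss_b'])))
#       agent.save_checkpoint('behavioral.pt', aux={'n': res['n'][-1]})
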
class BehavioralEmbeddedAgent(Agent):
    def __init__(self, load_dataset=True):

        super(BehavioralEmbeddedAgent, self).__init__()

        self.meta, self.data = preprocess_demonstrations()

        if load_dataset:
            # demonstration source
            self.meta = divide_dataset(self.meta)

            # datasets
            self.train_dataset = DemonstrationMemory("train", self.meta,
                                                     self.data)
            self.val_dataset = DemonstrationMemory("val", self.meta, self.data)
            self.test_dataset = DemonstrationMemory("test", self.meta,
                                                    self.data)
            self.full_dataset = DemonstrationMemory("full", self.meta,
                                                    self.data)

            self.train_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                           train=True)
            self.val_sampler = DemonstrationBatchSampler(self.train_dataset,
                                                         train=False)
            self.test_sampler = DemonstrationBatchSampler(self.test_dataset,
                                                          train=False)
            self.episodic_sampler = SequentialDemonstrationSampler(
                self.full_dataset)

            self.train_loader = torch.utils.data.DataLoader(
                self.train_dataset,
                batch_sampler=self.train_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset,
                batch_sampler=self.test_sampler,
                num_workers=args.cpu_workers,
                pin_memory=True,
                drop_last=False)

        self.loss_v_beta = torch.nn.KLDivLoss()
        self.loss_q_beta = torch.nn.KLDivLoss()

        self.loss_v_pi = torch.nn.KLDivLoss()
        self.loss_q_pi = torch.nn.KLDivLoss()

        self.histogram = torch.from_numpy(self.meta['histogram']).float()

        w_f, w_v, w_h = calc_hist_weights(self.histogram)

        w_f = torch.clamp(w_f, 0, 10).cuda()
        w_v = torch.clamp(w_v, 0, 10).cuda()
        w_h = torch.clamp(w_h, 0, 10).cuda()

        self.loss_beta_f = torch.nn.CrossEntropyLoss(size_average=True,
                                                     weight=w_f)
        self.loss_beta_v = torch.nn.CrossEntropyLoss(size_average=True,
                                                     weight=w_v)
        self.loss_beta_h = torch.nn.CrossEntropyLoss(size_average=True,
                                                     weight=w_h)

        self.loss_pi_f = torch.nn.CrossEntropyLoss(size_average=False)
        self.loss_pi_v = torch.nn.CrossEntropyLoss(size_average=False)
        self.loss_pi_h = torch.nn.CrossEntropyLoss(size_average=False)

        self.behavioral_model = BehavioralDistEmbedding()
        self.behavioral_model.cuda()

        # actor critic setting

        self.actor_critic_model = ActorCritic()
        self.actor_critic_model.cuda()

        self.actor_critic_target = ActorCritic()
        self.actor_critic_target.cuda()

        # configure learning

        cnn_params = [
            p[1] for p in self.behavioral_model.named_parameters()
            if "cnn" in p[0]
        ]
        emb_params = [
            p[1] for p in self.behavioral_model.named_parameters()
            if "emb" in p[0]
        ]

        v_beta_params = [
            p[1] for p in self.behavioral_model.named_parameters()
            if "fc_v" in p[0]
        ]
        a_beta_params = [
            p[1] for p in self.behavioral_model.named_parameters()
            if "fc_adv" in p[0]
        ]

        beta_f_params = [
            p[1] for p in self.behavioral_model.named_parameters()
            if "fc_beta_f" in p[0]
        ]
        beta_v_params = [
            p[1] for p in self.behavioral_model.named_parameters()
            if "fc_beta_v" in p[0]
        ]
        beta_h_params = [
            p[1] for p in self.behavioral_model.named_parameters()
            if "fc_beta_h" in p[0]
        ]

        v_pi_params = [
            p[1] for p in self.actor_critic_model.named_parameters()
            if "critic_v" in p[0]
        ]
        a_pi_params = [
            p[1] for p in self.actor_critic_model.named_parameters()
            if "critic_adv" in p[0]
        ]

        pi_f_params = [
            p[1] for p in self.actor_critic_model.named_parameters()
            if "fc_actor_f" in p[0]
        ]
        pi_v_params = [
            p[1] for p in self.actor_critic_model.named_parameters()
            if "fc_actor_v" in p[0]
        ]
        pi_h_params = [
            p[1] for p in self.actor_critic_model.named_parameters()
            if "fc_actor_h" in p[0]
        ]

        # IT IS IMPORTANT TO ASSIGN MODEL TO CUDA/PARALLEL BEFORE DEFINING OPTIMIZER

        self.optimizer_critic_v = BehavioralEmbeddedAgent.set_optimizer(
            v_pi_params, 0.0008)
        self.scheduler_critic_v = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_critic_v, self.decay)

        self.optimizer_critic_q = BehavioralEmbeddedAgent.set_optimizer(
            v_pi_params + a_pi_params, 0.0008)
        self.scheduler_critic_q = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_critic_q, self.decay)

        self.optimizer_v_beta = BehavioralEmbeddedAgent.set_optimizer(
            cnn_params + emb_params + v_beta_params, 0.0008)
        self.scheduler_v_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_v_beta, self.decay)

        self.optimizer_q_beta = BehavioralEmbeddedAgent.set_optimizer(
            cnn_params + emb_params + v_beta_params + a_beta_params, 0.0008)
        self.scheduler_q_beta = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_q_beta, self.decay)

        self.optimizer_beta_f = BehavioralEmbeddedAgent.set_optimizer(
            cnn_params + emb_params + beta_f_params, 0.0008)
        self.scheduler_beta_f = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta_f, self.decay)

        self.optimizer_beta_v = BehavioralEmbeddedAgent.set_optimizer(
            cnn_params + emb_params + beta_v_params, 0.0008)
        self.scheduler_beta_v = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta_v, self.decay)

        self.optimizer_beta_h = BehavioralEmbeddedAgent.set_optimizer(
            cnn_params + emb_params + beta_h_params, 0.0008)
        self.scheduler_beta_h = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_beta_h, self.decay)

        self.optimizer_pi_f = BehavioralEmbeddedAgent.set_optimizer(
            pi_f_params, 0.0008)
        self.scheduler_pi_f = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi_f, self.decay)

        self.optimizer_pi_v = BehavioralEmbeddedAgent.set_optimizer(
            pi_v_params, 0.0008)
        self.scheduler_pi_v = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi_v, self.decay)

        self.optimizer_pi_h = BehavioralEmbeddedAgent.set_optimizer(
            pi_h_params, 0.0008)
        self.scheduler_pi_h = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_pi_h, self.decay)

        actions = torch.LongTensor(consts.hotvec_matrix).cuda()
        self.actions_matrix = actions.unsqueeze(0)

        self.q_bins = consts.q_bins[args.game][:-1] / self.meta['avg_score']
        # the long bins are already normalized
        self.v_bins = consts.v_bins[args.game][:-1] / self.meta['avg_score']

        self.q_bins_torch = Variable(torch.from_numpy(
            consts.q_bins[args.game] / self.meta['avg_score']),
                                     requires_grad=False).cuda()
        self.v_bins_torch = Variable(torch.from_numpy(
            consts.v_bins[args.game] / self.meta['avg_score']),
                                     requires_grad=False).cuda()

        self.batch_range = np.arange(self.batch)

        self.zero = Variable(torch.zeros(1))
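
    # Optimizer layout above: the behavioral (beta) heads are optimized jointly
    # with the shared cnn/emb trunk, while the actor-critic heads get optimizers
    # over their own fc parameters only, so any head can be frozen simply by
    # skipping its optimizer step.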

    def flip_grad(self, parameters):
        for p in parameters:
            p.requires_grad = not p.requires_grad

    @staticmethod
    def individual_loss_fn_l2(argument):
        return abs(argument.data.cpu().numpy())**2

    @staticmethod
    def individual_loss_fn_l1(argument):
        return abs(argument.data.cpu().numpy())

    def save_checkpoint(self, path, aux=None):

        state = {
            'behavioral_model': self.behavioral_model.state_dict(),
            'actor_critic_model': self.actor_critic_model.state_dict(),
            'optimizer_critic_v': self.optimizer_critic_v.state_dict(),
            'optimizer_critic_q': self.optimizer_critic_q.state_dict(),
            'optimizer_v_beta': self.optimizer_v_beta.state_dict(),
            'optimizer_q_beta': self.optimizer_q_beta.state_dict(),
            'optimizer_beta_f': self.optimizer_beta_f.state_dict(),
            'optimizer_beta_v': self.optimizer_beta_v.state_dict(),
            'optimizer_beta_h': self.optimizer_beta_h.state_dict(),
            'optimizer_pi_f': self.optimizer_pi_f.state_dict(),
            'optimizer_pi_v': self.optimizer_pi_v.state_dict(),
            'optimizer_pi_h': self.optimizer_pi_h.state_dict(),
            'aux': aux
        }

        torch.save(state, path)

    def load_checkpoint(self, path):

        state = torch.load(path)
        self.behavioral_model.load_state_dict(state['behavioral_model'])
        self.actor_critic_model.load_state_dict(state['actor_critic_model'])
        self.optimizer_critic_v.load_state_dict(state['optimizer_critic_v'])
        self.optimizer_critic_q.load_state_dict(state['optimizer_critic_q'])
        self.optimizer_v_beta.load_state_dict(state['optimizer_v_beta'])
        self.optimizer_q_beta.load_state_dict(state['optimizer_q_beta'])
        self.optimizer_beta_f.load_state_dict(state['optimizer_beta_f'])
        self.optimizer_beta_v.load_state_dict(state['optimizer_beta_v'])
        self.optimizer_beta_h.load_state_dict(state['optimizer_beta_h'])
        self.optimizer_pi_f.load_state_dict(state['optimizer_pi_f'])
        self.optimizer_pi_v.load_state_dict(state['optimizer_pi_v'])
        self.optimizer_pi_h.load_state_dict(state['optimizer_pi_h'])

        return state['aux']

    def resume(self, model_path):

        aux = self.load_checkpoint(model_path)
        # self.update_target()
        return aux

    def update_target(self):
        self.actor_critic_target.load_state_dict(
            self.actor_critic_model.state_dict())

    def batched_interp(self, x, xp, fp):
        # implemented with numpy
        x = x.data.cpu().numpy()
        xp = xp.data.cpu().numpy()
        fp = fp.data.cpu().numpy()
        y = np.zeros(x.shape)

        for i, (xl, xpl, fpl) in enumerate(zip(x, xp, fp)):
            y[i] = np.interp(xl, xpl, fpl)

        return Variable(torch.from_numpy(y).float().cuda(), requires_grad=False)

    def new_distribution(self, q, beta, r, bin):
        bin = bin.repeat(self.batch, self.global_action_space, 1)
        r = r.unsqueeze(1).repeat(1, bin.shape[0])
        beta = beta.unsqueeze(1)

        # dimensions:
        # bins [batch, actions, bins]
        # beta [batch, 1, actions]
        # new_bin = torch.baddbmm(r, beta, , alpha=self.discount)
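            # NOTE: incomplete as written; q_back, x, xp and fp are not defined
            # in this scope, so calling new_distribution raises a NameError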
        q_back.squeeze(1)
        return self.batched_interp(x, xp, fp)

    def learn(self, n_interval, n_tot):

        self.behavioral_model.train()
        self.actor_critic_model.train()
        self.actor_critic_target.eval()

        results = {
            'n': [],
            'loss_vs': [],
            'loss_vl': [],
            'loss_b': [],
            'loss_qs': [],
            'loss_ql': [],
            'loss_pi_s': [],
            'loss_pi_l': [],
            'loss_pi_s_tau': [],
            'loss_pi_l_tau': []
        }

        # assumed initial phase: train the behavioral heads first
        # (train_net is flipped periodically in the scheduler block below)
        train_net = True

        for n, sample in tqdm(enumerate(self.train_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda(), requires_grad=False)

            a_index = Variable(sample['a_index'].cuda(async=True),
                               requires_grad=False)

            rl = np.digitize(sample['score'].numpy(),
                             self.long_bins,
                             right=True)
            rs = np.digitize(sample['f'].numpy(), self.short_bins, right=True)

            Rl = Variable(sample['score'].cuda(), requires_grad=False)
            Rs = Variable(sample['f'].cuda(), requires_grad=False)

            rl = Variable(torch.from_numpy(rl).cuda(), requires_grad=False)
            rs = Variable(torch.from_numpy(rs).cuda(), requires_grad=False)

            vs, vl, beta, qs, ql, phi, pi_s, pi_l, pi_s_tau, pi_l_tau = self.model(
                s, a)

            # policy learning

            if self.alpha_vs and train_net:
                loss_vs = self.alpha_vs * self.loss_fn_vs(vs, rs)
                self.optimizer_vs.zero_grad()
                loss_vs.backward(retain_graph=True)
                self.optimizer_vs.step()
            else:
                loss_vs = self.zero

            if self.alpha_vl and train_net:
                loss_vl = self.alpha_vl * self.loss_fn_vl(vl, rl)
                self.optimizer_vl.zero_grad()
                loss_vl.backward(retain_graph=True)
                self.optimizer_vl.step()
            else:
                loss_vl = self.zero

            if self.alpha_b and train_net:
                loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)
                self.optimizer_beta.zero_grad()
                loss_b.backward(retain_graph=True)
                self.optimizer_beta.step()
            else:
                loss_b = self.zero

            if self.alpha_qs and train_net:
                loss_qs = self.alpha_qs * self.loss_fn_qs(qs, rs)
                self.optimizer_qs.zero_grad()
                loss_qs.backward(retain_graph=True)
                self.optimizer_qs.step()
            else:
                loss_qs = self.zero

            if self.alpha_ql and train_net:
                loss_ql = self.alpha_ql * self.loss_fn_ql(ql, rl)
                self.optimizer_ql.zero_grad()
                loss_ql.backward(retain_graph=True)
                self.optimizer_ql.step()
            else:
                loss_ql = self.zero

            a_index_np = sample['a_index'].numpy()
            self.batch_range = np.arange(self.batch)

            beta_sfm = F.softmax(beta, 1)
            pi_s_sfm = F.softmax(pi_s, 1)
            pi_l_sfm = F.softmax(pi_l, 1)
            pi_s_tau_sfm = F.softmax(pi_s_tau, 1)
            pi_l_tau_sfm = F.softmax(pi_l_tau, 1)

            beta_fix = Variable(beta_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_fix = Variable(pi_s_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_l_fix = Variable(pi_l_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_tau_fix = Variable(pi_s_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)
            pi_l_tau_fix = Variable(pi_l_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)

            if self.alpha_pi_s and not train_net:
                loss_pi_s = self.alpha_pi_s * self.loss_fn_pi_s(pi_s, a_index)
                loss_pi_s = (loss_pi_s * Rs *
                             self.off_factor(pi_s_fix, beta_fix)).mean()
                self.optimizer_pi_s.zero_grad()
                loss_pi_s.backward(retain_graph=True)
                self.optimizer_pi_s.step()
            else:
                loss_pi_s = self.zero

            if self.alpha_pi_l and not train_net:
                loss_pi_l = self.alpha_pi_l * self.loss_fn_pi_l(pi_l, a_index)
                loss_pi_l = (loss_pi_l * Rl *
                             self.off_factor(pi_l_fix, beta_fix)).mean()
                self.optimizer_pi_l.zero_grad()
                loss_pi_l.backward(retain_graph=True)
                self.optimizer_pi_l.step()
            else:
                loss_pi_l = self.zero

            if self.alpha_pi_s_tau and not train_net:
                loss_pi_s_tau = self.alpha_pi_s_tau * self.loss_fn_pi_s_tau(
                    pi_s_tau, a_index)
                w = self.get_weighted_loss(F.softmax(qs, 1),
                                           self.short_bins_torch)
                loss_pi_s_tau = (
                    loss_pi_s_tau * w *
                    self.off_factor(pi_s_tau_fix, beta_fix)).mean()
                self.optimizer_pi_s_tau.zero_grad()
                loss_pi_s_tau.backward(retain_graph=True)
                self.optimizer_pi_s_tau.step()
            else:
                loss_pi_s_tau = self.zero

            if self.alpha_pi_l_tau and not train_net:
                loss_pi_l_tau = self.alpha_pi_l_tau * self.loss_fn_pi_l_tau(
                    pi_l_tau, a_index)
                w = self.get_weighted_loss(F.softmax(ql, 1),
                                           self.long_bins_torch)
                loss_pi_l_tau = (
                    loss_pi_l_tau * w *
                    self.off_factor(pi_l_tau_fix, beta_fix)).mean()
                self.optimizer_pi_l_tau.zero_grad()
                loss_pi_l_tau.backward()
                self.optimizer_pi_l_tau.step()
            else:
                loss_pi_l_tau = self.zero

            # add results
            results['loss_vs'].append(loss_vs.data.cpu().numpy()[0])
            results['loss_vl'].append(loss_vl.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_qs'].append(loss_qs.data.cpu().numpy()[0])
            results['loss_ql'].append(loss_ql.data.cpu().numpy()[0])
            results['loss_pi_s'].append(loss_pi_s.data.cpu().numpy()[0])
            results['loss_pi_l'].append(loss_pi_l.data.cpu().numpy()[0])
            results['loss_pi_s_tau'].append(
                loss_pi_s_tau.data.cpu().numpy()[0])
            results['loss_pi_l_tau'].append(
                loss_pi_l_tau.data.cpu().numpy()[0])
            results['n'].append(n)

            # if not n % self.update_target_interval:
            #     # self.update_target()

            # if an index is rolled more than once during the update_memory_interval period, only the last occurrence affects the probability update
            if not (
                    n + 1
            ) % self.update_memory_interval and self.prioritized_replay:
                self.train_dataset.update_probabilities()

            # update a global n_step parameter

            if not (n + 1) % self.update_n_steps_interval:
                # self.train_dataset.update_n_step(n + 1)
                d = np.divmod(n + 1, self.update_n_steps_interval)[0]
                if d % 10 == 1:
                    self.flip_grad(self.parameters_group_b +
                                   self.parameters_group_a)
                    train_net = not train_net
                if d % 10 == 2:
                    self.flip_grad(self.parameters_group_b +
                                   self.parameters_group_a)
                    train_net = not train_net

                    self.scheduler_pi_s.step()
                    self.scheduler_pi_l.step()
                    self.scheduler_pi_s_tau.step()
                    self.scheduler_pi_l_tau.step()
                else:
                    self.scheduler_vs.step()
                    self.scheduler_beta.step()
                    self.scheduler_vl.step()
                    self.scheduler_qs.step()
                    self.scheduler_ql.step()

            if not (n + 1) % n_interval:
                yield results
                self.model.train()
                # self.target.eval()
                results = {key: [] for key in results}

    def off_factor(self, pi, beta):
        return torch.clamp(pi / beta, 0, 1)
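
    # off_factor computes a truncated importance ratio min(pi(a|s)/beta(a|s), 1):
    # in learn/test it scales down policy-gradient terms for actions that the
    # learned policy makes less likely than the demonstrator's behavioral policy,
    # and never lets the ratio exceed one (clipped importance sampling).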

    def test(self, n_interval, n_tot):

        self.model.eval()
        # self.target.eval()

        results = {
            'n': [],
            'loss_vs': [],
            'loss_b': [],
            'loss_vl': [],
            'loss_qs': [],
            'loss_ql': [],
            'act_diff': [],
            'a_agent': [],
            'a_player': [],
            'loss_pi_s': [],
            'loss_pi_l': [],
            'loss_pi_s_tau': [],
            'loss_pi_l_tau': []
        }

        for n, sample in tqdm(enumerate(self.test_loader)):

            s = Variable(sample['s'].cuda(), requires_grad=False)
            a = Variable(sample['a'].cuda().unsqueeze(1), requires_grad=False)

            a_index = Variable(sample['a_index'].cuda(async=True),
                               requires_grad=False)

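            # Discretize the scalar long/short returns into the distributional
            # support bins (right-closed intervals); the resulting bin indices
            # serve as the targets for the value and Q heads below.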
            rl = np.digitize(sample['score'].numpy(),
                             self.long_bins,
                             right=True)
            rs = np.digitize(sample['f'].numpy(), self.short_bins, right=True)

            Rl = Variable(sample['score'].cuda(), requires_grad=False)
            Rs = Variable(sample['f'].cuda(), requires_grad=False)

            rl = Variable(torch.from_numpy(rl).cuda(), requires_grad=False)
            rs = Variable(torch.from_numpy(rs).cuda(), requires_grad=False)

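            # A single forward pass returns the short/long value distributions
            # (vs, vl), the behavior-policy logits (beta), the short/long Q
            # distributions (qs, ql), the feature embedding (phi) and four
            # policy heads (pi_s, pi_l, pi_s_tau, pi_l_tau).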
            vs, vl, beta, qs, ql, phi, pi_s, pi_l, pi_s_tau, pi_l_tau = self.model(
                s, a)

            qs = qs.squeeze(1)
            ql = ql.squeeze(1)

            # policy learning

            loss_vs = self.alpha_vs * self.loss_fn_vs(vs, rs)
            loss_vl = self.alpha_vl * self.loss_fn_vl(vl, rl)
            loss_b = self.alpha_b * self.loss_fn_beta(beta, a_index)
            loss_qs = self.alpha_qs * self.loss_fn_qs(qs, rs)
            loss_ql = self.alpha_ql * self.loss_fn_ql(ql, rl)

            a_index_np = sample['a_index'].numpy()
            self.batch_range = np.arange(self.batch)

            beta_sfm = F.softmax(beta, 1)
            pi_s_sfm = F.softmax(pi_s, 1)
            pi_l_sfm = F.softmax(pi_l, 1)
            pi_s_tau_sfm = F.softmax(pi_s_tau, 1)
            pi_l_tau_sfm = F.softmax(pi_l_tau, 1)

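            # Probability assigned to the demonstrated action by each head,
            # read off the softmax outputs and detached from the graph; these
            # fixed values feed the off_factor correction below.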
            beta_fix = Variable(beta_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_fix = Variable(pi_s_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_l_fix = Variable(pi_l_sfm.data[self.batch_range, a_index_np],
                                requires_grad=False)
            pi_s_tau_fix = Variable(pi_s_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)
            pi_l_tau_fix = Variable(pi_l_tau_sfm.data[self.batch_range,
                                                      a_index_np],
                                    requires_grad=False)

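            # Policy losses: cross-entropy against the demonstrated action,
            # weighted by the observed return (Rs / Rl) or by the
            # Q-distribution-derived weight w, and by the clamped pi / beta
            # off-policy factor.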
            loss_pi_s = self.alpha_pi_s * self.loss_fn_pi_s(pi_s, a_index)
            loss_pi_s = (loss_pi_s * Rs *
                         self.off_factor(pi_s_fix, beta_fix)).mean()

            loss_pi_l = self.alpha_pi_l * self.loss_fn_pi_l(pi_l, a_index)
            loss_pi_l = (loss_pi_l * Rl *
                         self.off_factor(pi_l_fix, beta_fix)).mean()

            loss_pi_s_tau = self.alpha_pi_s_tau * self.loss_fn_pi_s_tau(
                pi_s_tau, a_index)
            w = self.get_weighted_loss(F.softmax(qs, 1), self.short_bins_torch)
            loss_pi_s_tau = (loss_pi_s_tau * w *
                             self.off_factor(pi_s_tau_fix, beta_fix)).mean()

            loss_pi_l_tau = self.alpha_pi_l_tau * self.loss_fn_pi_l_tau(
                pi_l_tau, a_index)
            w = self.get_weighted_loss(F.softmax(ql, 1), self.long_bins_torch)
            loss_pi_l_tau = (loss_pi_l_tau * w *
                             self.off_factor(pi_l_tau_fix, beta_fix)).mean()

            # collect action statistics
            a_index_np = a_index.data.cpu().numpy()

            _, beta_index = beta.data.cpu().max(1)
            beta_index = beta_index.numpy()
            act_diff = (a_index_np != beta_index).astype(np.int)

            # add results
            results['act_diff'].append(act_diff)
            results['a_agent'].append(beta_index)
            results['a_player'].append(a_index_np)
            results['loss_vs'].append(loss_vs.data.cpu().numpy()[0])
            results['loss_vl'].append(loss_vl.data.cpu().numpy()[0])
            results['loss_b'].append(loss_b.data.cpu().numpy()[0])
            results['loss_qs'].append(loss_qs.data.cpu().numpy()[0])
            results['loss_ql'].append(loss_ql.data.cpu().numpy()[0])
            results['loss_pi_s'].append(loss_pi_s.data.cpu().numpy()[0])
            results['loss_pi_l'].append(loss_pi_l.data.cpu().numpy()[0])
            results['loss_pi_s_tau'].append(
                loss_pi_s_tau.data.cpu().numpy()[0])
            results['loss_pi_l_tau'].append(
                loss_pi_l_tau.data.cpu().numpy()[0])
            results['n'].append(n)

            if not (n + 1) % n_interval:
                results['s'] = s.data.cpu()
                results['act_diff'] = np.concatenate(results['act_diff'])
                results['a_agent'] = np.concatenate(results['a_agent'])
                results['a_player'] = np.concatenate(results['a_player'])
                yield results
                self.model.eval()
                # self.target.eval()
                results = {key: [] for key in results}

    def play_stochastic(self, n_tot):
        raise NotImplementedError

    def play_episode(self, n_tot):

        self.model.eval()
        env = Env()

        n_human = 120
        humans_trajectories = iter(self.data)
        softmax = torch.nn.Softmax()

        # mask = torch.FloatTensor(consts.actions_mask[args.game])
        # mask = Variable(mask.cuda(), requires_grad=False)

        vsx = torch.FloatTensor(consts.short_bins[args.game])
        vlx = torch.FloatTensor(consts.long_bins[args.game])

        for i in range(n_tot):

            env.reset()
            observation = next(humans_trajectories)
            trajectory = self.data[observation]
            choices = np.arange(self.global_action_space, dtype=np.int)

            j = 0

            while not env.t:

                s = Variable(env.s.cuda(), requires_grad=False)
                vs, vl, beta, qs, ql, phi, pi_s, pi_l, pi_s_tau, pi_l_tau = self.model(
                    s, self.actions_matrix)
                beta = beta.squeeze(0)
                pi_l = pi_l.squeeze(0)
                pi_s = pi_s.squeeze(0)
                pi_l_tau = pi_l_tau.squeeze(0)
                pi_s_tau = pi_s_tau.squeeze(0)

                temp = 1

                # Rank actions by the behavior-policy probabilities; the
                # ranking (and the maskb mask below) is currently unused,
                # since pi is taken directly from beta.
                beta_np = beta.data.cpu().numpy()
                indices = np.argsort(beta_np)

                maskb = Variable(torch.FloatTensor(
                    [0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
                                 requires_grad=False).cuda()
                # maskb = Variable(torch.FloatTensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
                #                  requires_grad=False).cuda()

                # pi = maskb * (beta / beta.max())

                pi = beta
                self.greedy = True

                beta_prob = pi

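                # Action selection: follow the recorded human trajectory for
                # the first n_human frames, then act greedily with respect to
                # pi, falling back to sampling from the temperature-scaled
                # softmax with probability 0.1.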
                if j < n_human:
                    a = trajectory[j, self.meta['action']]

                else:
                    eps = np.random.rand()
                    # a = np.random.choice(choices)
                    if self.greedy and eps > 0.1:
                        a = pi.data.cpu().numpy()
                        a = np.argmax(a)
                    else:
                        a = softmax(pi / temp).data.cpu().numpy()
                        a = np.random.choice(choices, p=a)

                env.step(a)

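                # Collapse the distributional value heads to scalar estimates:
                # the expectation of the bin supports (vsx / vlx) under the
                # softmax-normalized output distributions.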
                vs = softmax(vs)
                vl = softmax(vl)
                vs = torch.sum(vsx * vs.data.cpu())
                vl = torch.sum(vlx * vl.data.cpu())

                yield {
                    'o': env.s.cpu().numpy(),
                    'vs': np.array([vs]),
                    'vl': np.array([vl]),
                    's': phi.data.cpu().numpy(),
                    'score': env.score,
                    'beta': beta_prob.data.cpu().numpy(),
                    'phi': phi.squeeze(0).data.cpu().numpy(),
                    'qs': qs.squeeze(0).data.cpu().numpy(),
                    'ql': ql.squeeze(0).data.cpu().numpy(),
                }

                j += 1

        return

    def policy(self, vs, vl, beta, qs, ql):
        pass