def main(args):
    # Select the hardware device to use for training.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default (the training/validation helpers are expected to re-enable them where needed).
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    if args.arch == 'hg1':
        model = hg1(pretrained=False)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False)
    else:
        raise ValueError(f'unrecognised model architecture: {args.arch}')

    model = DataParallel(model).to(device)

    optimizer = RMSprop(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

    best_acc = 0

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), f'no checkpoint found at {args.resume}'
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'))
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # create data loader
    train_dataset = Mpii(args.image_path, is_train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    val_dataset = Mpii(args.image_path, is_train=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    # train and eval
    lr = args.lr
    for epoch in trange(args.start_epoch,
                        args.epochs,
                        desc='Overall',
                        ascii=True):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)

        # train for one epoch
        train_loss, train_acc = do_training_epoch(train_loader,
                                                  model,
                                                  device,
                                                  Mpii.DATA_INFO,
                                                  optimizer,
                                                  acc_joints=Mpii.ACC_JOINTS)

        # evaluate on validation set
        valid_loss, valid_acc, predictions = do_validation_epoch(
            val_loader,
            model,
            device,
            Mpii.DATA_INFO,
            False,
            acc_joints=Mpii.ACC_JOINTS)

        # print metrics
        tqdm.write(
            f'[{epoch + 1:3d}/{args.epochs:3d}] lr={lr:0.2e} '
            f'train_loss={train_loss:0.4f} train_acc={100 * train_acc:0.2f} '
            f'valid_loss={valid_loss:0.4f} valid_acc={100 * valid_acc:0.2f}')

        # append metrics to the log file and refresh the plot
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])
        logger.plot_to_file(os.path.join(args.checkpoint, 'log.svg'),
                            ['Train Acc', 'Val Acc'])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            predictions,
            is_best,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
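
The loop above relies on an adjust_learning_rate helper that is not shown in this example. A minimal sketch of a step-decay version consistent with the args.schedule and args.gamma arguments (an assumption about the missing helper, not necessarily the original implementation) could look like this:

def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Multiply the learning rate by gamma whenever epoch reaches a milestone in schedule."""
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr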
Example no. 2
            # Build tensors for the current batch of samples and labels and move them to the device.
            batchSamples = torch.from_numpy(
                np.array(oneBatchSamples)[index]).long().to(device)
            batchLabels = torch.from_numpy(
                np.array(oneBatchLabels)[index]).long().to(device)
            optimizer.zero_grad()
            predictTensor = model(batchSamples)
            loss = lossCri(predictTensor, batchLabels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            if trainingTimes % display_step == 0:
                print("#################")
                print("Predict tensor is ", predictTensor)
                print("Labels are ", batchLabels)
                print("Learning rate is ",
                      optimizer.state_dict()['param_groups'][0]["lr"])
                print("Loss is ", loss)
                print("Training time is ", trainingTimes)
            # Manually write the scheduler's computed learning rate into the optimizer's param group.
            learning_rate = scheduler.calculateLearningRate()
            state_dic = optimizer.state_dict()
            state_dic["param_groups"][0]["lr"] = float(learning_rate)
            optimizer.load_state_dict(state_dic)
            trainingTimes += 1
            # Periodically save a snapshot of the model weights.
            if trainingTimes % save_model_steps == 0:
                torch.save(
                    model.state_dict(),
                    weight_save_path + "ALBERT_" + str(trainingTimes) + ".pth")
else:
    model.eval()
    model.load_state_dict(
        torch.load(weight_save_path + "ALBERT_" + str(testModelSelect) + ".pth"))
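
This snippet assumes a custom scheduler object exposing both step() and calculateLearningRate(); neither name is part of torch.optim.lr_scheduler. A hypothetical warmup-style scheduler with that interface, given purely for illustration, might be:

class WarmupScheduler:
    """Hypothetical scheduler with the interface used above: linear warmup, then inverse-sqrt decay."""

    def __init__(self, base_lr, warmup_steps):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        # Advance the internal counter once per optimizer update.
        self.step_num += 1

    def calculateLearningRate(self):
        # Linear warmup for the first warmup_steps updates, then 1/sqrt(step) decay.
        if self.step_num < self.warmup_steps:
            return self.base_lr * (self.step_num + 1) / self.warmup_steps
        return self.base_lr * (self.warmup_steps ** 0.5) / (self.step_num ** 0.5)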
Example no. 3
class DQNAgent(TrainingAgent):
    def __init__(self, input_shape, action_space, seed, device, model, gamma,
                 alpha, tau, batch_size, update, replay, buffer_size, env,
                 decay=200, path='model', num_epochs=0, max_step=50000, learn_interval=20):

        '''Initialise a DQNAgent object.

        buffer_size   : size of the replay buffer to sample from
        gamma         : discount rate
        alpha         : learning rate
        tau           : interpolation factor for soft target-network updates
        update        : soft-update the target network every `update` learning steps
        replay        : number of stored transitions required before learning starts
        learn_interval: learn from the replay buffer every `learn_interval` environment steps
        decay         : time constant of the exponential epsilon-decay schedule
        '''
        super(DQNAgent, self).__init__(input_shape, action_space, seed, device, model,
                                       gamma, alpha, tau, batch_size, max_step, env, num_epochs, path)
        self.buffer_size = buffer_size
        self.update = update
        self.replay = replay
        self.interval = learn_interval
        # Q-Network
        self.policy_net = self.model(input_shape, action_space).to(self.device)
        self.target_net = self.model(input_shape, action_space).to(self.device)
        self.optimiser = RMSprop(self.policy_net.parameters(), lr=self.alpha)
        # Replay Memory
        self.memory = ReplayMemory(self.buffer_size, self.batch_size, self.seed, self.device)
        # Timestep
        self.t_step = 0
        self.l_step = 0

        self.EPSILON_START = 1.0
        self.EPSILON_FINAL = 0.02
        self.EPS_DECAY = decay
        self.epsilon_delta = lambda frame_idx: self.EPSILON_FINAL + (self.EPSILON_START - self.EPSILON_FINAL) * exp(-1. * frame_idx / self.EPS_DECAY)

    def step(self, state, action, reward, next_state, done):
        '''
        Step of learning and taking environment action.
        '''

        # Save experience into replay buffer
        self.memory.add(state, action, reward, next_state, done)

        # Learn every `learn_interval` environment steps
        self.t_step = (self.t_step + 1) % self.interval

        if self.t_step == 0:
            # if there are enough samples in the memory, get a random subset and learn
            if len(self.memory) > self.replay:
                experience = self.memory.sample()
                print('learning')
                self.learn(experience)


    def action(self, state, eps=0.):
        ''' Returns action for given state as per current policy'''
        # Convert the numpy state to a batched tensor on the device
        state = torch.from_numpy(state).unsqueeze(0).to(self.device)
        if rand.rand() > eps:
            # Eps Greedy action selections
            action_val = self.policy_net(state)
            return np.argmax(action_val.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_space))

    def learn(self, exp):
        state, action, reward, next_state, done = exp

        # Get expected Q values from Policy Model
        Q_expt_current = self.policy_net(state)
        Q_expt = Q_expt_current.gather(1, action.unsqueeze(1)).squeeze(1)

        # Get max predicted Q values for next state from target model
        Q_target_next = self.target_net(next_state).detach().max(1)[0]
        # Compute Q targets for current states
        Q_target = reward + (self.gamma * Q_target_next * (1 - done))

        # Compute Loss
        loss = torch.nn.functional.mse_loss(Q_expt, Q_target)

        # Minimize loss
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        self.l_step = (self.l_step + 1) % self.update
        if self.l_step == 0:
            # Soft-update the target network every `update` learning steps.
            self.soft_update(self.policy_net, self.target_net, self.tau)

    def model_dict(self)-> dict:
        ''' To save models'''
        return {'policy_net': self.policy_net.state_dict(), 'target_net': self.target_net.state_dict(),
                'optimizer': self.optimiser.state_dict(), 'num_epoch': self.num_epochs,'scores': self.scores}

    def load_model(self, state_dict, eval=True):
        '''Load Parameters and Model Information from prior training for continuation of training'''
        self.policy_net.load_state_dict(state_dict['policy_net'])
        self.target_net.load_state_dict(state_dict['target_net'])
        self.optimiser.load_state_dict(state_dict['optimizer'])
        self.scores = state_dict['scores']
        if eval:
            self.policy_net.eval()
            self.target_net.eval()
        else:
            self.policy_net.train()
            self.target_net.train()
        # Restore the epoch counter so training can resume where it left off
        self.num_epochs = state_dict['num_epoch']

    # θ'=θ×τ+θ'×(1−τ)
    def soft_update(self, policy_model, target_model, tau):
        for t_param, p_param in zip(target_model.parameters(), policy_model.parameters()):
            t_param.data.copy_(tau * p_param.data + (1.0 - tau) * t_param.data)

    def train(self, n_episodes=1000, render=False):
        """
        n_episodes: maximum number of training episodes
        Saves the model and a score plot every 100 episodes.
        """
        filename = get_filename()

        # Toggle environment rendering on or off.
        self.env.render(render)
        for i_episode in range(n_episodes):
            self.num_epochs += 1
            state = self.stack_frames(None, self.reset(), True)
            score = 0
            eps = self.epsilon_delta(self.num_epochs)

            while True:
                action = self.action(state, eps)

                next_state, reward, done, info = self.env.step(action)

                score += reward

                next_state = self.stack_frames(state, next_state, False)

                self.step(state, action, reward, next_state, done)
                state = next_state
                if done:
                    break
            self.scores.append(score)  # save most recent score

            # Every 100 episodes, save the model and plot the score history
            if i_episode % 100 == 0:
                self.save_obj(self.model_dict(), os.path.join(self.path, filename))
                print("Creating plot")
                # Create a new figure for the score plot
                fig = plt.figure()

                # Add a subplot
                # ax = fig.add_subplot(111)

                # Plot the graph
                plt.plot(np.arange(len(self.scores)), self.scores)

                # Add labels
                plt.xlabel('Episode #')
                plt.ylabel('Score')

                # Save the plot and close the figure so open figures don't accumulate
                plt.savefig(f'{i_episode} plot.png')
                plt.close(fig)
                print("Plot saved")

        # Return the scores.
        return self.scores
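
For context, a minimal wiring of the class above might look as follows; the environment factory make_env(), the network class ConvQNetwork, and all hyperparameter values are assumptions for illustration and are not part of the original example.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = make_env()  # hypothetical environment factory

agent = DQNAgent(input_shape=(4, 84, 84),
                 action_space=env.action_space.n,
                 seed=0,
                 device=device,
                 model=ConvQNetwork,  # hypothetical Q-network class
                 gamma=0.99,
                 alpha=2.5e-4,
                 tau=1e-3,
                 batch_size=32,
                 update=4,
                 replay=10000,
                 buffer_size=100000,
                 env=env)

scores = agent.train(n_episodes=1000)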
Example no. 4
class OffPGLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mac = mac
        self.logger = logger

        self.last_target_update_step = 0
        self.critic_training_steps = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = OffPGCritic(scheme, args)
        self.mixer = QMixer(args)
        self.target_critic = copy.deepcopy(self.critic)
        self.target_mixer = copy.deepcopy(self.mixer)

        self.agent_params = list(mac.parameters())
        self.critic_params = list(self.critic.parameters())
        self.mixer_params = list(self.mixer.parameters())
        self.params = self.agent_params + self.critic_params
        self.c_params = self.critic_params + self.mixer_params

        self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr)
        self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.lr)
        self.mixer_optimiser = RMSprop(params=self.mixer_params, lr=args.lr)

        print('Critic + mixer size: ')
        print(get_parameters_num(list(self.c_params)))

    def train(self, batch: EpisodeBatch, t_env: int, log):
        # Get the relevant quantities
        bs = batch.batch_size
        max_t = batch.max_seq_length
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        avail_actions = batch["avail_actions"][:, :-1]
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        mask = mask.repeat(1, 1, self.n_agents).view(-1)
        states = batch["state"][:, :-1]

        #build q
        inputs = self.critic._build_inputs(batch, bs, max_t)
        q_vals = self.critic.forward(inputs).detach()[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        # Calculate the baseline
        q_taken = th.gather(q_vals, dim=3, index=actions).squeeze(3)
        pi = mac_out.view(-1, self.n_actions)
        baseline = th.sum(mac_out * q_vals, dim=-1).view(-1).detach()

        # Calculate policy grad with mask
        pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = th.log(pi_taken)
        coe = self.mixer.k(states).view(-1)

        advantages = (q_taken.view(-1) - baseline)
        # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        coma_loss = - ((coe * advantages.detach() * log_pi_taken) * mask).sum() / mask.sum()
        
        # dist_entropy = Categorical(pi).entropy().view(-1)
        # dist_entropy[mask == 0] = 0 # fill nan
        # entropy_loss = (dist_entropy * mask).sum() / mask.sum()
 
        # loss = coma_loss - self.args.ent_coef * entropy_loss / entropy_loss.item()
        loss = coma_loss

        # Optimise agents
        self.agent_optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip)
        self.agent_optimiser.step()

        #compute parameters sum for debugging
        p_sum = 0.
        for p in self.agent_params:
            p_sum += p.data.abs().sum().item() / 100.0


        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(log["critic_loss"])
            for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean", "q_max_mean", "q_min_mean", "q_max_var", "q_min_var"]:
                self.logger.log_stat(key, sum(log[key])/ts_logged, t_env)
            self.logger.log_stat("q_max_first", log["q_max_first"], t_env)
            self.logger.log_stat("q_min_first", log["q_min_first"], t_env)
            #self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env)
            # self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env)
            self.logger.log_stat("coma_loss", coma_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env)
            self.log_stats_t = t_env

    def train_critic(self, on_batch, best_batch=None, log=None):
        bs = on_batch.batch_size
        max_t = on_batch.max_seq_length
        rewards = on_batch["reward"][:, :-1]
        actions = on_batch["actions"][:, :]
        terminated = on_batch["terminated"][:, :-1].float()
        mask = on_batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = on_batch["avail_actions"][:]
        states = on_batch["state"]

        #build_target_q
        target_inputs = self.target_critic._build_inputs(on_batch, bs, max_t)
        target_q_vals = self.target_critic.forward(target_inputs).detach()
        targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)
        target_q = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda).detach()

        inputs = self.critic._build_inputs(on_batch, bs, max_t)


        if best_batch is not None:
            best_target_q, best_inputs, best_mask, best_actions, best_mac_out= self.train_critic_best(best_batch)
            log["best_reward"] = th.mean(best_batch["reward"][:, :-1].squeeze(2).sum(-1), dim=0)
            target_q = th.cat((target_q, best_target_q), dim=0)
            inputs = th.cat((inputs, best_inputs), dim=0)
            mask = th.cat((mask, best_mask), dim=0)
            actions = th.cat((actions, best_actions), dim=0)
            states = th.cat((states, best_batch["state"]), dim=0)

        #train critic
        for t in range(max_t - 1):
            mask_t = mask[:, t:t+1]
            if mask_t.sum() < 0.5:
                continue
            q_vals = self.critic.forward(inputs[:, t:t+1])
            q_ori = q_vals
            q_vals = th.gather(q_vals, 3, index=actions[:, t:t+1]).squeeze(3)
            q_vals = self.mixer.forward(q_vals, states[:, t:t+1])
            target_q_t = target_q[:, t:t+1].detach()
            q_err = (q_vals - target_q_t) * mask_t
            critic_loss = (q_err ** 2).sum() / mask_t.sum()

            self.critic_optimiser.zero_grad()
            self.mixer_optimiser.zero_grad()
            critic_loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(self.c_params, self.args.grad_norm_clip)
            self.critic_optimiser.step()
            self.mixer_optimiser.step()
            self.critic_training_steps += 1

            log["critic_loss"].append(critic_loss.item())
            log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            log["td_error_abs"].append((q_err.abs().sum().item() / mask_elems))
            log["target_mean"].append((target_q_t * mask_t).sum().item() / mask_elems)
            log["q_taken_mean"].append((q_vals * mask_t).sum().item() / mask_elems)
            log["q_max_mean"].append((th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_mean"].append((th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_max_var"].append((th.var(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_var"].append((th.var(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)

            if t == 0:
                log["q_max_first"] = (th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems
                log["q_min_first"] = (th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems

        #update target network
        if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps



    def train_critic_best(self, batch):
        bs = batch.batch_size
        max_t = batch.max_seq_length
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:]
        states = batch["state"]

        with th.no_grad():
            # pr for all actions of the episode
            mac_out = []
            self.mac.init_hidden(bs)
            for i in range(max_t):
                agent_outs = self.mac.forward(batch, t=i)
                mac_out.append(agent_outs)
            mac_out = th.stack(mac_out, dim=1).detach()
            # Mask out unavailable actions, renormalise (as in action selection)
            mac_out[avail_actions == 0] = 0
            mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
            mac_out[avail_actions == 0] = 0
            critic_mac = th.gather(mac_out, 3, actions).squeeze(3).prod(dim=2, keepdim=True)

            #target_q take
            target_inputs = self.target_critic._build_inputs(batch, bs, max_t)
            target_q_vals = self.target_critic.forward(target_inputs).detach()
            targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)

            #expected q
            exp_q = self.build_exp_q(target_q_vals, mac_out, states).detach()
            # td-error
            targets_taken[:, -1] = targets_taken[:, -1] * (1 - th.sum(terminated, dim=1))
            exp_q[:, -1] = exp_q[:, -1] * (1 - th.sum(terminated, dim=1))
            targets_taken[:, :-1] = targets_taken[:, :-1] * mask
            exp_q[:, :-1] = exp_q[:, :-1] * mask
            td_q = (rewards + self.args.gamma * exp_q[:, 1:] - targets_taken[:, :-1]) * mask

            #compute target
            target_q =  build_target_q(td_q, targets_taken[:, :-1], critic_mac, mask, self.args.gamma, self.args.tb_lambda, self.args.step).detach()

            inputs = self.critic._build_inputs(batch, bs, max_t)

        return target_q, inputs, mask, actions, mac_out


    def build_exp_q(self, target_q_vals, mac_out, states):
        target_exp_q_vals = th.sum(target_q_vals * mac_out, dim=3)
        target_exp_q_vals = self.target_mixer.forward(target_exp_q_vals, states)
        return target_exp_q_vals

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.mixer.cuda()
        self.target_critic.cuda()
        self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path))
        th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path))
        th.save(self.mixer_optimiser.state_dict(), "{}/mixer_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage))
        self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage))
        # Not quite right but I don't want to save target networks
        # self.target_critic.load_state_dict(self.critic.agent.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage))
        self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
        self.mixer_optimiser.load_state_dict(th.load("{}/mixer_opt.th".format(path), map_location=lambda storage, loc: storage))
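
train_critic and train_critic_best above rely on a build_td_lambda_targets helper that is not reproduced in this example. In PyMARL-derived codebases it is typically a backwards recursion over the episode; a sketch of that common pattern, offered as an approximation rather than this repository's exact code, is:

import torch as th

def build_td_lambda_targets(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda):
    # Typical shapes: rewards / terminated / mask are [B, T-1, 1]; target_qs is [B, T, ...].
    ret = target_qs.new_zeros(*target_qs.shape)
    # Bootstrap from the final timestep only for episodes that never terminated.
    ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1))
    # Backwards recursive computation of the lambda-return ("forward view").
    for t in range(ret.shape[1] - 2, -1, -1):
        ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] * (
            rewards[:, t] + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t]))
    # Return targets for timesteps 0 .. T-2.
    return ret[:, 0:-1]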