Example #1
def main(cfg):
    device = set_env(cfg)

    logging.info('Loading the dataset.')
    train_criterion, val_criterion = get_criterion(cfg.optimization.criterion)
    train_dataloader, val_dataloader = get_dataloader(cfg)

    logging.info(f'Constructing model on {device}:{cfg.CUDA_DEVICE}.')
    model = MLP(**cfg.network).to(device)
    logging.info(model)

    # Set total steps for onecycleLR and cosineLR
    cfg.optimization.total_steps = len(train_dataloader) * cfg.optimization.epoch
    cfg.optimization.onecycle_scheduler.total_steps = \
        cfg.optimization.cosine_scheduler.T_max = cfg.optimization.total_steps

    optimizer, scheduler = get_optimization(cfg, model)

    best_loss = float("inf")
    for epoch in range(cfg.optimization.epoch):
        train_epoch(model,
                    train_dataloader,
                    train_criterion,
                    optimizer,
                    scheduler,
                    device,
                    epoch,
                    cfg)
        if cfg.optimization.scheduler in ['exp', 'step']:
            scheduler.step()

        val_loss = valid_epoch(model,
                               val_dataloader,
                               val_criterion,
                               device,
                               epoch,
                               cfg)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), f'{cfg.log_dir}/best_model.pth')
            logging.info(f'New best validation loss: {best_loss}')
    torch.save(model.state_dict(), f'{cfg.log_dir}/last_model.pth')

    logging.info(f'Best validation loss: {best_loss}')
    logging.info(cfg)
    with open(f'{cfg.log_dir}/config.pkl', 'wb') as f:
        pickle.dump(cfg, f)
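
For orientation, a short sketch of how the artifacts written above (config.pkl, best_model.pth) might be reloaded for evaluation. The log directory path is a made-up placeholder, and the MLP constructor usage is taken from the snippet itself, not a verified API:

import pickle

import torch

log_dir = 'runs/exp1'   # hypothetical log directory, standing in for cfg.log_dir above

# restore the pickled config and rebuild the network from it
with open(f'{log_dir}/config.pkl', 'rb') as f:
    cfg = pickle.load(f)
model = MLP(**cfg.network)

# load the best checkpoint saved during training and switch to eval mode
model.load_state_dict(torch.load(f'{log_dir}/best_model.pth', map_location='cpu'))
model.eval()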
Example #2
class NormalPolicy:
    def __init__(self, layers, sigma, activation=F.relu):
        # mean network; the `sigma` argument is kept for interface compatibility but is not used here
        self.mu_net = MLP(layers, activation)
        # standard-deviation network; softplus keeps its outputs positive
        self.sigma = MLP(layers, activation=F.softplus)

        # self.mu_net.fc1.weight.data = torch.zeros(self.mu_net.fc1.weight.data.shape)
        # self.mu_net.eta.data = torch.ones(1) * 2

    def get_mu(self, states):
        return self.mu_net(states)

    def get_sigma(self, states):
        return self.sigma(states)

    def get_action(self, state):
        # random action if untrained
        # if self.initial_policy is not None:
        #     return self.initial_policy.get_action(state)
        # sample from normal otherwise
        if state.dim() < 2:
            state = state.unsqueeze(0)
        mean = self.get_mu(state)
        std_dev = self.get_sigma(state)
        # squeeze returns a new tensor, so the result has to be reassigned
        mean = mean.squeeze()
        std_dev = std_dev.squeeze()
        action = torch.normal(mean, std_dev)
        return action.detach()

    def optimize(self, max_epochs_opt, train_dataset, val_dataset, batch_size, learning_rate, verbose=False):
        # init data loader
        train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        # one optimizer covering both the mean and the standard-deviation networks
        optimizer_mu = optim.Adagrad(
            [{'params': self.mu_net.parameters()},
             {'params': self.sigma.parameters()}],
            lr=learning_rate)
        # train on batches
        best_model = None
        last_loss_opt = None
        epochs_opt_no_decrease = 0
        epoch_opt = 0
        while (epoch_opt < max_epochs_opt) and (epochs_opt_no_decrease < 3):
            for batch_idx, batch in enumerate(train_data_loader):
                optimizer_mu.zero_grad()

                # forward pass
                mu = self.mu_net(batch[0])
                sigma = self.get_sigma(batch[0])
                loss = NormalPolicyLoss(mu, sigma, batch[1], batch[2])
                # backpropagate
                loss.backward()
                optimizer_mu.step()
            # calculate loss on validation data; no gradients are needed for this check
            with torch.no_grad():
                mu = self.get_mu(val_dataset[0])
                sigma = self.get_sigma(val_dataset[0])
                cur_loss_opt = NormalPolicyLoss(mu, sigma, val_dataset[1], val_dataset[2])
            # evaluate optimization iteration

            if verbose:
                sys.stdout.write('\r[policy] epoch: %d | loss: %f' % (epoch_opt+1, cur_loss_opt))
                sys.stdout.flush()
            if (last_loss_opt is None) or (cur_loss_opt < last_loss_opt - 1e-3):
                # clone the tensors so the snapshot is not overwritten by later updates
                best_model = {k: v.clone() for k, v in self.mu_net.state_dict().items()}
                epochs_opt_no_decrease = 0
                last_loss_opt = cur_loss_opt
            else:
                epochs_opt_no_decrease += 1
            epoch_opt += 1
        self.mu_net.load_state_dict(best_model)
        if verbose:
            sys.stdout.write('\r[policy] training complete (%d epochs, %f best loss)'
                             % (epoch_opt, last_loss_opt)
                             + ' ' * (len(str(epoch_opt)) * 2) + '\n')
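
A small usage sketch of NormalPolicy, assuming MLP builds a feed-forward net over the given layer sizes and that states are 4-dimensional with scalar actions. The layer sizes and the unused sigma argument are illustrative assumptions, not values from the source project:

import torch

# hypothetical layer sizes: 4-dimensional state in, scalar action out
policy = NormalPolicy(layers=[4, 64, 1], sigma=None)   # the sigma argument is ignored by __init__

state = torch.randn(4)               # a single state; get_action adds the batch dimension
action = policy.get_action(state)    # sample drawn from N(mu(state), softplus-sigma(state))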