# Example 1 (scraped snippet; the loop/function header is outside this fragment)
        # Backprop the criterion and take an optimizer step for this batch.
        loss = criterion(log_outputs, labels)
        loss.backward()
        optimizer.step()

        accs.append(metrics.logit2acc(log_outputs.data, labels))
        # NOTE(review): `.data.numpy()[0]` is the pre-0.4 PyTorch idiom for
        # extracting a scalar loss; modern code would use `loss.item()`.
        training_loss += loss.cpu().data.numpy()[0]

    # Log per-epoch training averages.
    logger.add(epoch, tr_loss=training_loss/steps, tr_acc=np.mean(accs))

    # Ens 100 test: average predictions over 100 stochastic forward passes
    # (net kept in train mode so stochastic layers stay active).
    net.train()
    acc, nll = utils.evaluate(net, testloader, num_ens=100)
    logger.add(epoch, te_nll_ens100=nll, te_acc_ens100=acc)

    # Stochastic test: a single stochastic forward pass.
    net.train()
    acc, nll = utils.evaluate(net, testloader, num_ens=1)
    logger.add(epoch, te_nll_stoch=nll, te_acc_stoch=acc)

    # Test-time averaging over 10 stochastic samples.
    net.train()
    acc, nll = utils.evaluate(net, testloader, num_ens=10)
    logger.add(epoch, te_nll_ens10=nll, te_acc_ens10=acc)

    # Record wall-clock time for the epoch, then persist logs + weights.
    logger.add(epoch, time=time()-t0)
    logger.iter_info()
    logger.save(silent=True)
    torch.save(net.state_dict(), logger.checkpoint)

# Final (non-silent) save after all epochs.
logger.save()
# Example 2
class BaseAgent(object):
    """Common scaffolding for RL agents: env setup, episode sampling,
    running-average logging, and the outer train/evaluate loops.

    Subclasses must implement `step` and `update`.
    """

    def __init__(self, config):
        torch.manual_seed(config.experiment.seed)
        self.config = config

        self.env = config.env.init_env()

        # Smallest float32 increment; typically used to avoid division by
        # zero when normalizing returns.
        self.eps = np.finfo(np.float32).eps.item()
        self.device = config.training.device

        self.memory = Memory()
        self.logger = Logger(config)

        # Populated by subclasses.
        self.model = None
        self.policy = None

        self.episode = 1
        self.episode_steps = 0

    def convert_data(self, x):
        """Convert `x` to a plain-Python, log-friendly value.

        Tensors/arrays become lists, bools become ints, everything else
        passes through unchanged.
        """
        # isinstance() is the idiomatic (and subclass-safe) type check,
        # replacing the original `type(x) == T` comparisons.
        if isinstance(x, torch.Tensor):
            return x.data.cpu().tolist()
        elif isinstance(x, bool):
            return int(x)
        elif isinstance(x, np.ndarray):
            return list(x)
        else:
            return x

    def log(self, run_avg):
        """Print running averages and persist logger state, checkpoint,
        optional episode data, and plots."""
        print("Episode {} \t avg length: {} \t reward: {}".format(
            self.episode,
            round(run_avg.get_value("moves"), 2),
            round(run_avg.get_value("return"), 2),
        ))
        self.logger.save()
        self.logger.save_checkpoint(self)
        if self.config.experiment.save_episode_data:
            self.logger.save_episode_data(self.episode)
        self.logger.plot("return")
        self.logger.plot("moves")

    def collect_samples(self, run_avg, timestep):
        """Sample whole episodes until at least `update_every` env steps
        have been collected.

        Returns the total number of steps taken (may overshoot
        `update_every`, since episodes are never truncated here).
        """
        num_steps = 0
        while num_steps < self.config.training.update_every:
            print("Starting episode {} at timestep {}".format(
                self.episode, timestep + num_steps))
            episode_data, episode_return, episode_length = self.sample_episode(
                self.episode, timestep + num_steps, run_avg)
            num_steps += episode_length
            run_avg.update_variable("return", episode_return)
            run_avg.update_variable("moves", episode_length)
            self.episode += 1

            if (self.config.experiment.save_episode_data
                    and self.episode % self.config.experiment.every_n_episodes
                    == 0):
                print("Pushed episode data at episode: {}".format(
                    self.episode))
                self.logger.push_episode_data(episode_data)
            if self.episode % self.config.experiment.log_interval == 0:
                self.log(run_avg)

        return num_steps

    def sample_episode(self, episode, step, run_avg):
        """Roll out one episode, up to `max_episode_length` steps.

        Returns (episode_data, episode_return, episode_length) where
        episode_data maps transition keys to per-step lists.
        """
        episode_return = 0
        self.episode_steps = 0
        episode_data = defaultdict(list, {"episode": int(episode)})
        state = self.env.reset()
        for t in range(self.config.training.max_episode_length):
            state = torch.from_numpy(state).float().to(self.device)
            with torch.no_grad():
                transition, state, done = self.step(state)
            for key, val in transition.items():
                episode_data[key].append(self.convert_data(val))
            episode_return += transition["reward"]
            if self.config.experiment.render:
                self.env.render()
            # Periodically push a summary point for plotting.
            if (step + t +
                    1) % self.config.experiment.num_steps_between_plot == 0:
                summary = {
                    "steps": step + t + 1,
                    "return": run_avg.get_value("return"),
                    "moves": run_avg.get_value("moves"),
                }
                self.logger.push(summary)

            self.episode_steps += 1
            # Force termination at the episode-length cap.
            if self.episode_steps == self.config.training.max_episode_length:
                done = True
            if done:
                break
        episode_length = t + 1
        return episode_data, episode_return, episode_length

    def train(self):
        """Alternate sample collection and parameter updates until
        `max_timesteps` environment steps have elapsed."""
        run_avg = RunningAverage()
        timestep = 0
        print("Max Timesteps: {}".format(self.config.training.max_timesteps))
        while timestep <= self.config.training.max_timesteps:
            num_steps = self.collect_samples(run_avg, timestep)
            timestep += num_steps
            self.update()
            self.memory.clear()
        print("Training complete")

    def evaluate(self):
        """Collect episodes (no updates) for `n_eval_steps` env steps."""
        timestep = 0
        run_avg = RunningAverage()
        # Iterate through episodes
        while timestep <= self.config.eval.n_eval_steps:
            num_steps = self.collect_samples(run_avg, timestep)
            timestep += num_steps
        print("Evaluation complete")

    def step(self, state):
        """Take one env step from `state`; must return
        (transition, next_state, done). Implemented by subclasses.

        Fixed: the stub previously *returned* NotImplementedError instead
        of raising it, and took no `state` argument even though
        `sample_episode` calls `self.step(state)`.
        """
        raise NotImplementedError

    def update(self):
        """Update model parameters from collected samples. Implemented by
        subclasses. (Fixed: previously returned NotImplementedError.)"""
        raise NotImplementedError
# Example 3 (scraped snippet; the loop/function header is outside this fragment)
        # Backprop the criterion and take an optimizer step for this batch.
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        accs.append(metrics.logit2acc(outputs.data, labels))  # probably a bad way to calculate accuracy
        # NOTE(review): `.data.numpy()[0]` is the pre-0.4 PyTorch idiom for
        # extracting a scalar loss; modern code would use `loss.item()`.
        training_loss += loss.cpu().data.numpy()[0]

    # Log per-epoch training averages.
    logger.add(epoch, tr_loss=training_loss/steps, tr_acc=np.mean(accs))

    # Deterministic test: eval mode disables the stochastic layers.
    net.eval()
    acc, nll = utils.evaluate(net, testloader, num_ens=1)
    logger.add(epoch, te_nll_det=nll, te_acc_det=acc)

    # Stochastic test: a single stochastic forward pass (train mode keeps
    # the stochastic layers active).
    net.train()
    acc, nll = utils.evaluate(net, testloader, num_ens=1)
    logger.add(epoch, te_nll_stoch=nll, te_acc_stoch=acc)

    # Test-time averaging over 20 stochastic samples.
    net.train()
    acc, nll = utils.evaluate(net, testloader, num_ens=20)
    logger.add(epoch, te_nll_ens=nll, te_acc_ens=acc)

    # Record wall-clock time for the epoch, then persist logs + weights.
    logger.add(epoch, time=time()-t0)
    logger.iter_info()
    logger.save(silent=True)
    torch.save(net.state_dict(), logger.checkpoint)

# Final (non-silent) save after all epochs.
logger.save()
# Example 4
def _make_mnist_loader(args, train):
    """Build a MNIST DataLoader (random order for training, sequential for
    test) and return (loader, dataset)."""
    dataset = torchvision.datasets.MNIST(root='./data',
                                         train=train,
                                         download=True,
                                         transform=transforms.ToTensor())
    base_sampler = (torch.utils.data.RandomSampler(dataset) if train
                    else torch.utils.data.SequentialSampler(dataset))
    batch_sampler = torch.utils.data.BatchSampler(base_sampler,
                                                  batch_size=args.batch_size,
                                                  drop_last=False)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_sampler=batch_sampler,
                                         num_workers=args.workers,
                                         pin_memory=True)
    return loader, dataset


def _flagged_eval(net, testloader, args, num_ens, flag=None):
    """Evaluate `net` in stochastic (train) mode with `num_ens` samples,
    optionally enabling a dense1 flag ('zero_mean' / 'permute_sigma') for
    the duration. The flag is restored even if evaluation raises (the
    original code would leak an enabled flag on error)."""
    net.train()
    if flag is not None:
        net.dense1.set_flag(flag, True)
    try:
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=num_ens)
    finally:
        if flag is not None:
            net.dense1.set_flag(flag, False)
    return acc, nll


def main():
    """Train LeNet-5 with variational dropout (VDO) on MNIST.

    Each epoch trains with the SGVLB criterion, then evaluates under
    several regimes — deterministic, stochastic, 20-sample ensembling, and
    zero-mean / permuted-sigma ablations of dense1 (single-sample and
    ensembled) — logging metrics and saving a checkpoint.
    """
    # Column formats for the metric logger.
    fmt = {
        'tr_loss': '3.1e',
        'tr_acc': '.4f',
        'te_acc_det': '.4f',
        'te_acc_stoch': '.4f',
        'te_acc_ens': '.4f',
        'te_acc_perm_sigma': '.4f',
        'te_acc_zero_mean': '.4f',
        'te_acc_perm_sigma_ens': '.4f',
        'te_acc_zero_mean_ens': '.4f',
        'te_nll_det': '.4f',
        'te_nll_stoch': '.4f',
        'te_nll_ens': '.4f',
        'te_nll_perm_sigma': '.4f',
        'te_nll_zero_mean': '.4f',
        'te_nll_perm_sigma_ens': '.4f',
        'te_nll_zero_mean_ens': '.4f',
        'time': '.3f'
    }
    # Per-layer mean log-alpha diagnostic columns la0..la3.
    fmt = {**fmt, **{'la%d' % i: '.4f' for i in range(4)}}
    args = get_args()
    logger = Logger("lenet5-VDO", fmt=fmt)

    trainloader, trainset = _make_mnist_loader(args, train=True)
    testloader, _ = _make_mnist_loader(args, train=False)

    net = LeNet5().to(device=args.device, dtype=args.dtype)
    if args.print_model:
        logger.print(net)
    # SGVLB needs the dataset size to scale the likelihood term.
    criterion = metrics.SGVLB(net, len(trainset)).to(device=args.device,
                                                     dtype=args.dtype)
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    epochs = args.epochs
    lr_start = args.learning_rate
    for epoch in trange(epochs):  # loop over the dataset multiple times
        t0 = time()
        # Linear learning-rate decay over the run.
        utils.adjust_learning_rate(
            optimizer, metrics.lr_linear(epoch, 0, epochs, lr_start))
        net.train()
        training_loss = 0
        accs = []
        steps = 0
        for i, (inputs, labels) in enumerate(tqdm(trainloader), 0):
            steps += 1
            inputs = inputs.to(device=args.device, dtype=args.dtype)
            labels = labels.to(device=args.device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            accs.append(metrics.logit2acc(
                outputs.data,
                labels))  # probably a bad way to calculate accuracy
            training_loss += loss.item()

        logger.add(epoch, tr_loss=training_loss / steps, tr_acc=np.mean(accs))

        # Deterministic test: eval mode disables the stochastic layers.
        net.eval()
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=1)
        logger.add(epoch, te_nll_det=nll, te_acc_det=acc)

        # Stochastic single-sample test.
        acc, nll = _flagged_eval(net, testloader, args, num_ens=1)
        logger.add(epoch, te_nll_stoch=nll, te_acc_stoch=acc)

        # Test-time averaging over 20 stochastic samples.
        acc, nll = _flagged_eval(net, testloader, args, num_ens=20)
        logger.add(epoch, te_nll_ens=nll, te_acc_ens=acc)

        # Zero-mean ablation on dense1.
        acc, nll = _flagged_eval(net, testloader, args, num_ens=1,
                                 flag='zero_mean')
        logger.add(epoch, te_nll_zero_mean=nll, te_acc_zero_mean=acc)

        # Permuted-sigma ablation on dense1.
        acc, nll = _flagged_eval(net, testloader, args, num_ens=1,
                                 flag='permute_sigma')
        logger.add(epoch, te_nll_perm_sigma=nll, te_acc_perm_sigma=acc)

        # Zero-mean ablation with test-time averaging.
        acc, nll = _flagged_eval(net, testloader, args, num_ens=20,
                                 flag='zero_mean')
        logger.add(epoch, te_nll_zero_mean_ens=nll, te_acc_zero_mean_ens=acc)

        # Permuted-sigma ablation with test-time averaging.
        acc, nll = _flagged_eval(net, testloader, args, num_ens=20,
                                 flag='permute_sigma')
        logger.add(epoch, te_nll_perm_sigma_ens=nll, te_acc_perm_sigma_ens=acc)

        logger.add(epoch, time=time() - t0)
        # Mean log-alpha per layer (variational-dropout rate diagnostic).
        las = [np.mean(layer.log_alpha.data.cpu().numpy())
               for layer in (net.conv1, net.conv2, net.dense1, net.dense2)]
        logger.add(epoch, **{'la%d' % i: la for i, la in enumerate(las)})
        logger.iter_info()
        logger.save(silent=True)
        torch.save(net.state_dict(), logger.checkpoint)

    logger.save()