예제 #1
0
class Experiment(object):
    """Train, validate and test a DenseNet classifier on a ``Dataset``.

    Args:
        directory: data location, forwarded to ``Dataset``.
        epochs: number of training epochs run by ``train``.
        cuda: if true, the model and every minibatch are moved to the GPU.
        save: if true, ``train`` writes ``signet.<unix time>.pth`` at the end.
        log_interval: minimum number of seconds between progress messages.
        load: optional path of a checkpoint written by ``train`` (a dict with
            ``'model'`` and ``'optim'`` state dicts) to resume from.
        split, cache, minibatch_size: forwarded to ``Dataset``.
        pretrained: forwarded to ``DenseNet``.
    """

    # ``test`` builds a confusion matrix over class labels 0..NUM_CLASSES-1.
    NUM_CLASSES = 7

    def __init__(self,
                 directory,
                 epochs=1,
                 cuda=False,
                 save=False,
                 log_interval=30,
                 load=None,
                 split=(0.6, 0.2, 0.2),
                 cache=False,
                 minibatch_size=10,
                 pretrained=False):
        self.dataset = Dataset(directory,
                               split=split,
                               cache=cache,
                               minibatch_size=minibatch_size)
        self.epochs = epochs
        self.cuda = cuda
        self.save = save
        self.log_interval = log_interval
        self.model = DenseNet(pretrained)
        # Move the model to the GPU *before* creating the optimizer and
        # restoring its state: Optimizer.load_state_dict casts the saved
        # per-parameter state (Adam moments) to the device the parameters
        # live on at load time, so loading first and moving afterwards
        # would leave that state on the CPU.
        if cuda:
            self.model = self.model.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        if load is not None:
            # Deserialize onto the CPU so GPU-saved checkpoints also load on
            # CPU-only machines; load_state_dict copies onto the right device.
            state = torch.load(load,
                               map_location=lambda storage, loc: storage)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])

    @staticmethod
    def _scalar(value):
        """Return a one-element tensor/Variable as a Python number."""
        # .item() exists from PyTorch 0.4 on; older releases index the data.
        if hasattr(value, 'item'):
            return value.item()
        return value.data[0]

    def _prepare(self, minibatch, targets, volatile=False):
        """Stack a minibatch into ``(inputs, labels)`` on the right device.

        ``minibatch`` is a sequence of equally-shaped tensors (it is passed to
        ``torch.stack``) and ``targets`` a sequence of integer class labels.
        ``volatile`` requests inference-only Variables on PyTorch versions
        that still honour that flag.
        """
        if volatile:
            inputs = Variable(torch.stack(minibatch), volatile=True)
            labels = Variable(torch.LongTensor(targets), volatile=True)
        else:
            inputs = Variable(torch.stack(minibatch))
            labels = Variable(torch.LongTensor(targets))
        if self.cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        return inputs, labels

    def train(self):
        """Train for ``self.epochs`` epochs, validating after each one.

        The learning rate is reduced by ``ReduceLROnPlateau`` based on the
        sample-weighted mean validation loss of each epoch. If ``save`` was
        set, model and optimizer state are written to disk at the end.

        Raises:
            ValueError: if the validation split yields no samples.
        """
        print('Training %s epochs.' % self.epochs)
        loss_fun = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                               'min',
                                                               verbose=True,
                                                               patience=3)
        last_print = time.time()
        for epoch in range(self.epochs):
            tprint('Starting epoch: %s' % epoch)
            self.model.train()
            self.optimizer.zero_grad()
            for minibatch, targets in self.dataset.train:
                inputs, labels = self._prepare(minibatch, targets)
                out = self.model(inputs)
                loss = loss_fun(out, labels)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                if time.time() - last_print > self.log_interval:
                    last_print = time.time()
                    numer, denom = self.dataset.train.progress()
                    tprint('Training: %s, %s/%s' % (epoch, numer, denom))
            tprint('Training complete. Beginning validation.')
            self.dataset.train.reload()
            # Step on the whole-epoch validation loss, not on whichever
            # minibatch happened to come last.
            scheduler.step(self._validate(epoch, loss_fun))
        if self.save:
            torch.save(
                {
                    'model': self.model.state_dict(),
                    'optim': self.optimizer.state_dict(),
                }, 'signet.%s.pth' % int(time.time()))

    def _validate(self, epoch, loss_fun):
        """Return the sample-weighted mean of ``loss_fun`` over the validation split.

        Raises:
            ValueError: if the validation split yielded no samples.
        """
        self.model.eval()
        total_loss = 0.0
        total_count = 0
        last_print = time.time()
        for minibatch, targets in self.dataset.validate:
            inputs, labels = self._prepare(minibatch, targets, volatile=True)
            out = self.model(inputs)
            batch_size = labels.size(0)
            # The default CrossEntropyLoss averages over the batch; re-weight
            # by batch size so a short final batch does not skew the mean.
            total_loss += self._scalar(loss_fun(out, labels)) * batch_size
            total_count += batch_size
            if time.time() - last_print > self.log_interval:
                last_print = time.time()
                numer, denom = self.dataset.validate.progress()
                tprint('Validating: %s, %s/%s' % (epoch, numer, denom))
        self.dataset.validate.reload()
        if total_count == 0:
            raise ValueError('Validation split is empty; cannot compute the '
                             'validation loss for epoch %s.' % epoch)
        return total_loss / total_count

    def test(self):
        """Evaluate on the test split; print a confusion matrix and its stats."""
        tprint('Beginning testing.')
        # test() may run without a preceding train() (e.g. right after
        # ``load``), so switch to inference mode explicitly.
        self.model.eval()
        class_labels = list(range(self.NUM_CLASSES))
        # ``np.int`` (removed in NumPy 1.24) was an alias of the builtin int.
        confusion_matrix = np.zeros((self.NUM_CLASSES, self.NUM_CLASSES),
                                    dtype=int)
        last_print = time.time()
        for minibatch, targets in self.dataset.test:
            inputs, labels = self._prepare(minibatch, targets, volatile=True)
            out = self.model(inputs)
            _, predicted = torch.max(out.data, 1)
            predicted = predicted.cpu().numpy()
            truth = labels.data.cpu().numpy()
            # NOTE(review): sklearn expects (y_true, y_pred); passing
            # (predicted, truth) makes rows the predicted class and columns
            # the true class. Kept as-is because ``stats`` may rely on this
            # orientation -- confirm.
            confusion_matrix += sklearn.metrics.confusion_matrix(
                predicted, truth, labels=class_labels).astype(int)
            if time.time() - last_print > self.log_interval:
                last_print = time.time()
                numer, denom = self.dataset.test.progress()
                tprint('Testing: %s/%s' % (numer, denom))
        tprint('Testing complete.')
        print(confusion_matrix)
        print(tabulate.tabulate(stats(confusion_matrix)))
예제 #2
0
    # Evaluate `model` on `val_loader`: print the mean per-sample
    # cross-entropy and the top-1 accuracy.
    # NOTE(review): batches are moved with .cuda() unconditionally, so this
    # assumes a CUDA device is available.
    model.eval()
    validation_loss = 0.0
    correct = 0
    num_samples = len(val_loader.dataset)

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            # Sum (not mean) over the batch so dividing by the dataset size
            # below yields the exact per-sample average even when the last
            # batch is smaller. .item() works on the 0-dim result, unlike
            # .data[0], which raises on PyTorch >= 0.5.
            validation_loss += F.cross_entropy(
                output, target, reduction='sum').item()
            # index of the highest-scoring class for each sample
            pred = output.max(1, keepdim=True)[1]
            # .item() keeps `correct` a plain int (prints as a number, not a
            # tensor repr).
            correct += pred.eq(target.view_as(pred)).sum().item()

    # Avoid ZeroDivisionError on an empty validation set.
    denom = max(num_samples, 1)
    validation_loss /= denom
    print(
        '\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.
        format(validation_loss, correct, num_samples,
               100.0 * float(correct) / denom))


# Run args.epochs epochs (numbered from 1); after each one, validate and
# write the weights to model_<epoch>.pth.
for epoch_idx in range(1, args.epochs + 1):
    train(epoch_idx)
    validation()
    checkpoint_path = 'model_' + str(epoch_idx) + '.pth'
    torch.save(model.state_dict(), checkpoint_path)
    message = ('\nSaved model to ' + checkpoint_path +
               '. You can run `python evaluate.py ' + checkpoint_path +
               '` to generate the Kaggle formatted csv file')
    print(message)
예제 #3
0
class Solver(object):
    """Build, train and evaluate a DenseNet classifier.

    Every key of ``config`` (merged over ``DEFAULTS``) becomes an instance
    attribute. The methods below read ``config``, ``input_channels``,
    ``class_count``, ``num_features``, ``compress_factor``,
    ``expand_factor``, ``growth_rate``, ``lr``, ``momentum``,
    ``weight_decay``, ``use_gpu``, ``pretrained_model``,
    ``model_save_path``, ``num_epochs``, ``loss_log_step``,
    ``model_save_step`` and ``train_eval_step``.
    """

    DEFAULTS = {}

    def __init__(self, version, data_loader, config):
        """
        Initializes a Solver object.

        Args:
            version: sub-directory of ``model_save_path`` that checkpoints
                are written to.
            data_loader: iterable of ``(images, labels)`` batches used for
                training and evaluation.
            config: dict of hyper-parameters, stored as attributes.
        """
        # hyper-parameters become attributes; config overrides DEFAULTS
        self.__dict__.update(Solver.DEFAULTS, **config)
        self.version = version
        self.data_loader = data_loader

        self.build_model()

        # TODO: build tensorboard

        # start with a pre-trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """
        Instantiates the model, loss criterion, and optimizer.
        """
        self.model = DenseNet(config=self.config,
                              channels=self.input_channels,
                              class_count=self.class_count,
                              num_features=self.num_features,
                              compress_factor=self.compress_factor,
                              expand_factor=self.expand_factor,
                              growth_rate=self.growth_rate)

        self.criterion = nn.CrossEntropyLoss()

        self.optimizer = optim.SGD(params=self.model.parameters(),
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay,
                                   nesterov=True)

        self.print_network(self.model, 'DenseNet')

        # use gpu only if requested AND available
        if torch.cuda.is_available() and self.use_gpu:
            self.model.cuda()
            self.criterion.cuda()

    def print_network(self, model, name):
        """
        Prints the structure of the network and the total number of parameters.
        """
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def load_pretrained_model(self):
        """
        Loads ``<model_save_path>/<pretrained_model>.pth`` into the model.
        """
        path = os.path.join(self.model_save_path,
                            '{}.pth'.format(self.pretrained_model))
        # Deserialize onto the CPU so GPU-saved weights also load on
        # CPU-only machines; load_state_dict copies them onto the model.
        state = torch.load(path, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(state)
        print('loaded trained model ver {}'.format(self.pretrained_model))

    def print_loss_log(self, start_time, iters_per_epoch, e, i, loss):
        """
        Prints the loss together with elapsed and estimated remaining time.
        """
        total_iter = self.num_epochs * iters_per_epoch
        cur_iter = e * iters_per_epoch + i

        # NOTE(review): when resuming (start epoch > 0) `elapsed` covers only
        # the resumed run while `cur_iter` counts from epoch 0, so the time
        # estimates are optimistic.
        elapsed = time.time() - start_time
        total_time = (total_iter - cur_iter) * elapsed / (cur_iter + 1)
        epoch_time = (iters_per_epoch - i) * elapsed / (cur_iter + 1)

        epoch_time = str(datetime.timedelta(seconds=epoch_time))
        total_time = str(datetime.timedelta(seconds=total_time))
        elapsed = str(datetime.timedelta(seconds=elapsed))

        log = "Elapsed {}/{} -- {}, Epoch [{}/{}], Iter [{}/{}], " \
              "loss: {:.4f}".format(elapsed,
                                    epoch_time,
                                    total_time,
                                    e + 1,
                                    self.num_epochs,
                                    i + 1,
                                    iters_per_epoch,
                                    loss)

        # TODO: add tensorboard

        print(log)

    def save_model(self, e):
        """
        Saves the model weights to ``<model_save_path>/<version>/<e + 1>.pth``,
        creating the version directory if needed.
        """
        save_dir = os.path.join(self.model_save_path, str(self.version))
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        path = os.path.join(save_dir, '{}.pth'.format(e + 1))

        torch.save(self.model.state_dict(), path)

    def model_step(self, images, labels):
        """
        Runs one optimization step on a batch and returns the loss tensor.
        """
        self.model.train()
        self.optimizer.zero_grad()
        output = self.model(images)
        loss = self.criterion(output, labels.squeeze())
        loss.backward()
        self.optimizer.step()
        return loss

    def train(self):
        """
        Training process.

        Logs, checkpoints and evaluates on the training set every
        ``loss_log_step`` / ``model_save_step`` / ``train_eval_step`` epochs,
        then prints the collected losses and accuracies.

        Raises:
            ValueError: if ``data_loader`` has no batches.
        """
        self.losses = []
        self.top_1_acc = []
        self.top_5_acc = []

        iters_per_epoch = len(self.data_loader)
        if iters_per_epoch == 0:
            raise ValueError('data_loader is empty; nothing to train on.')

        # resume after the checkpoint's epoch: pretrained_model is
        # '<version>/<epoch>', matching the layout written by save_model
        if self.pretrained_model:
            start = int(self.pretrained_model.split('/')[-1])
        else:
            start = 0

        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (images, labels) in enumerate(tqdm(self.data_loader)):
                images = to_var(images, self.use_gpu)
                labels = to_var(labels, self.use_gpu)

                loss = self.model_step(images, labels)

            # Keep a plain float: holding the last loss tensor would keep it
            # (and its autograd history) referenced across epochs.
            loss_value = loss.item()

            if (e + 1) % self.loss_log_step == 0:
                self.print_loss_log(start_time, iters_per_epoch, e, i,
                                    loss_value)
                self.losses.append((e, loss_value))

            if (e + 1) % self.model_save_step == 0:
                self.save_model(e)

            if (e + 1) % self.train_eval_step == 0:
                top_1_acc, top_5_acc = self.train_evaluate(e)
                self.top_1_acc.append((e, top_1_acc))
                self.top_5_acc.append((e, top_5_acc))

        print('\n--Losses--')
        for e, loss in self.losses:
            print(e, '{:.4f}'.format(loss))

        print('\n--Top 1 accuracy--')
        for e, acc in self.top_1_acc:
            print(e, '{:.4f}'.format(acc))

        print('\n--Top 5 accuracy--')
        for e, acc in self.top_5_acc:
            print(e, '{:.4f}'.format(acc))

    def eval(self, data_loader):
        """
        Returns the count of top 1 and top 5 correct predictions.

        Returns:
            ``(top_1_correct, top_5_correct, total)`` as Python ints. With
            fewer than 5 classes, "top 5" uses every class.
        """
        self.model.eval()

        top_1_correct = 0
        top_5_correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in data_loader:
                images = to_var(images, self.use_gpu)
                labels = to_var(labels, self.use_gpu)

                output = self.model(images)
                # one class index per sample, flattened to shape (batch,)
                targets = labels.reshape(-1)
                total += targets.size(0)

                _, top_1_output = torch.max(output, dim=1)
                top_1_correct += top_1_output.eq(targets).sum().item()

                # topk indices in a row are distinct, so each sample matches
                # at most once and the match count equals the hit count.
                k = min(5, output.size(1))
                _, top_5_output = torch.topk(output, k=k, dim=1)
                top_5_correct += top_5_output.eq(
                    targets.view(-1, 1)).sum().item()

        return top_1_correct, top_5_correct, total

    @staticmethod
    def _rates(top_1_correct, top_5_correct, total):
        """Convert correct-prediction counts into accuracy fractions."""
        if total == 0:
            raise ValueError('cannot compute accuracy over an empty data '
                             'loader')
        # float() keeps true division under Python 2 as well
        return float(top_1_correct) / total, float(top_5_correct) / total

    def train_evaluate(self, e):
        """
        Evaluates the model on the train dataset; returns (top1, top5) accuracy.
        """
        top_1_acc, top_5_acc = self._rates(*self.eval(self.data_loader))
        log = "Epoch [{}/{}]--top_1_acc: {:.4f}--top_5_acc: {:.4f}".format(
            e + 1,
            self.num_epochs,
            top_1_acc,
            top_5_acc
        )
        print(log)
        return top_1_acc, top_5_acc

    def test(self):
        """
        Evaluates the model on ``self.data_loader`` and prints the accuracies.
        """
        top_1_acc, top_5_acc = self._rates(*self.eval(self.data_loader))
        log = "top_1_acc: {:.4f}--top_5_acc: {:.4f}".format(
            top_1_acc,
            top_5_acc
        )
        print(log)
예제 #4
0
class Trainer(object):
    """
    The Trainer class encapsulates all the logic necessary for
    training the DenseNet model. It uses SGD to update the weights
    of the model given hyperparameter constraints provided by the
    user in the config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Params
        ------
        - config: object containing command line arguments.
        - data_loader: data iterator; a (train, valid) pair when
          ``config.is_train`` is set, otherwise the test loader.
        """
        self.config = config
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
        else:
            self.test_loader = data_loader

        # network params
        self.num_blocks = config.num_blocks
        self.num_layers_total = config.num_layers_total
        self.growth_rate = config.growth_rate
        self.bottleneck = config.bottleneck
        self.theta = config.compression

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.best_valid_acc = 0.
        self.init_lr = config.init_lr
        self.lr = self.init_lr
        self.is_decay = True
        self.momentum = config.momentum
        self.weight_decay = config.weight_decay
        self.dropout_rate = config.dropout_rate
        # lr_sched: comma-separated epoch fractions at which the learning
        # rate is divided by 10 (see anneal_learning_rate); '' disables decay
        if config.lr_sched == '':
            self.is_decay = False
        else:
            self.lr_decay = [float(x) for x in config.lr_sched.split(',')]

        # other params
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.num_gpu = config.num_gpu
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.dataset = config.dataset
        if self.dataset == 'cifar10':
            self.num_classes = 10
        elif self.dataset == 'cifar100':
            self.num_classes = 100
        else:
            self.num_classes = 1000

        # build densenet model
        self.model = DenseNet(self.num_blocks, self.num_layers_total,
                              self.growth_rate, self.num_classes,
                              self.bottleneck, self.dropout_rate, self.theta)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # define loss and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.init_lr,
                                         momentum=self.momentum,
                                         weight_decay=self.weight_decay)

        if self.num_gpu > 0:
            self.model.cuda()
            self.criterion.cuda()

        # finally configure tensorboard logging
        if self.use_tensorboard:
            # os.path.join: plain concatenation dropped the separator when
            # logs_dir had no trailing slash
            tensorboard_dir = os.path.join(self.logs_dir,
                                           self.get_model_name())
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

    def _to_device(self, image, target):
        """Move a batch to the GPU when ``num_gpu > 0``; else return it as-is."""
        if self.num_gpu > 0:
            image = image.cuda()
            # non_blocking replaces the old ``async`` keyword argument, which
            # is a reserved word (SyntaxError) since Python 3.7.
            target = target.cuda(non_blocking=True)
        return image, target

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        for epoch in trange(self.start_epoch, self.epochs):

            # decay learning rate (only when a schedule was configured)
            if self.is_decay:
                self.anneal_learning_rate(epoch)

            # train for 1 epoch (switches the model to train mode)
            self.train_one_epoch(epoch)

            # evaluate on validation set (switches the model to eval mode)
            valid_acc = self.validate(epoch)

            is_best = valid_acc > self.best_valid_acc
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                # momentum buffers, so resuming continues the same trajectory
                'optim_state': self.optimizer.state_dict(),
                'best_valid_acc': self.best_valid_acc}, is_best)

    def test(self):
        """
        Test the model on the held-out test data.

        This function should only be called at the very
        end once the model has finished training.
        """
        # switch to test mode for dropout / batch-norm
        self.model.eval()

        accs = AverageMeter()
        batch_time = AverageMeter()

        # load the best checkpoint
        self.load_checkpoint(best=True)

        tic = time.time()
        # no autograd graph is needed for evaluation
        with torch.no_grad():
            for i, (image, target) in enumerate(self.test_loader):
                image, target = self._to_device(image, target)

                # forward pass
                output = self.model(image)

                # compute accuracy
                acc = self.accuracy(output, target)
                accs.update(acc, image.size()[0])

                # measure per-batch elapsed time
                toc = time.time()
                batch_time.update(toc - tic)
                tic = toc

                # print to screen
                if i % self.print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Test Acc: {acc.val:.3f} ({acc.avg:.3f})'.format(
                              i, len(self.test_loader), batch_time=batch_time,
                              acc=accs))

        print('[*] Test Acc: {acc.avg:.3f}'.format(acc=accs))

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        # validate() leaves the model in eval mode, so re-enable training
        # behaviour (dropout, batch-norm statistics updates) every epoch
        self.model.train()

        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        for i, (image, target) in enumerate(self.train_loader):
            image, target = self._to_device(image, target)

            # forward pass
            output = self.model(image)

            # compute loss & accuracy
            loss = self.criterion(output, target)
            acc = self.accuracy(output.detach(), target)
            losses.update(loss.item(), image.size()[0])
            accs.update(acc, image.size()[0])

            # compute gradients and update SGD
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure per-batch elapsed time (reset the clock each batch)
            toc = time.time()
            batch_time.update(toc - tic)
            tic = toc

            # print to screen
            if i % self.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Train Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Train Acc: {acc.val:.3f} ({acc.avg:.3f})'.format(
                          epoch, i, len(self.train_loader),
                          batch_time=batch_time, loss=losses, acc=accs))

        # log to tensorboard
        if self.use_tensorboard:
            log_value('train_loss', losses.avg, epoch)
            log_value('train_acc', accs.avg, epoch)

    def validate(self, epoch):
        """
        Evaluate the model on the validation set; returns the mean accuracy (%).
        """
        # inference mode: dropout off, batch-norm uses running statistics
        # and is not updated with validation data
        self.model.eval()

        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with torch.no_grad():
            for i, (image, target) in enumerate(self.valid_loader):
                image, target = self._to_device(image, target)

                # forward pass
                output = self.model(image)

                # compute loss & accuracy
                loss = self.criterion(output, target)
                acc = self.accuracy(output, target)
                losses.update(loss.item(), image.size()[0])
                accs.update(acc, image.size()[0])

                # measure per-batch elapsed time
                toc = time.time()
                batch_time.update(toc - tic)
                tic = toc

                # print to screen
                if i % self.print_freq == 0:
                    print('Valid: [{0}/{1}]\t'
                          'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Valid Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Valid Acc: {acc.val:.3f} ({acc.avg:.3f})'.format(
                              i, len(self.valid_loader), batch_time=batch_time,
                              loss=losses, acc=accs))

        print('[*] Valid Acc: {acc.avg:.3f}'.format(acc=accs))

        # log to tensorboard
        if self.use_tensorboard:
            log_value('val_loss', losses.avg, epoch)
            log_value('val_acc', accs.avg, epoch)

        return accs.avg

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        Furthermore, the model with the highest accuracy is saved
        under a special name.
        """
        print("[*] Saving model to {}".format(self.ckpt_dir))

        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir)

        filename = self.get_model_name() + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.get_model_name() + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path,
                            os.path.join(self.ckpt_dir, filename))
            print("[*] ==== Best Valid Acc Achieved ====")

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.get_model_name() + '_ckpt.pth.tar'
        if best:
            filename = self.get_model_name() + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # deserialize onto the CPU so GPU-saved checkpoints load anywhere;
        # load_state_dict copies onto the model's / optimizer's device
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['state_dict'])
        # optional: checkpoints that stored only the model lack this key
        if 'optim_state' in ckpt:
            self.optimizer.load_state_dict(ckpt['optim_state'])

        print("[*] Loaded {} checkpoint @ epoch {} with best valid acc of {:.3f}".format(
                    filename, ckpt['epoch'], ckpt['best_valid_acc']))

    def anneal_learning_rate(self, epoch):
        """
        This function decays the learning rate at 2 instances.

        - The initial learning rate is divided by 10 from epoch
          t1*epochs on.
        - It is further divided by 10 from epoch t2*epochs on.

        t1 and t2 are floats specified by the user (``config.lr_sched``).
        The default values used by the authors of the paper are 0.5 and 0.75.
        """
        sched1 = int(self.lr_decay[0] * self.epochs)
        sched2 = int(self.lr_decay[1] * self.epochs)

        # Count milestones passed instead of using epoch // sched, which
        # decays again at every multiple of sched1 (e.g. t1 < 0.5) and
        # divides by zero when a milestone rounds down to epoch 0.
        num_decays = int(epoch >= sched1) + int(epoch >= sched2)
        self.lr = self.init_lr * (0.1 ** num_decays)

        # log to tensorboard
        if self.use_tensorboard:
            log_value('learning_rate', self.lr, epoch)

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def get_model_name(self):
        """
        Returns the name of the model based on the configuration
        parameters.

        The name takes the form ``DenseNet-BC-X-Z`` when
        ``config.bottleneck`` is set, else ``DenseNet-X-Z``, where:

        - X: total number of layers specified by `config.num_layers_total`.
        - Z: name of the dataset specified by `config.dataset`.

        For example, given 169 layers with bottleneck on CIFAR-10, this
        function will output `DenseNet-BC-169-cifar10`.
        """
        if self.bottleneck:
            return 'DenseNet-BC-{}-{}'.format(self.num_layers_total,
                                              self.dataset)
        return 'DenseNet-{}-{}'.format(self.num_layers_total,
                                       self.dataset)

    def accuracy(self, predicted, ground_truth):
        """
        Utility function for calculating the accuracy of the model.

        Params
        ------
        - predicted: (torch.FloatTensor) per-class scores, one row per sample.
        - ground_truth: (torch.LongTensor) class index per sample.

        Returns
        -------
        - acc: (float) % accuracy; 0.0 for an empty batch.
        """
        predicted = torch.max(predicted, 1)[1]
        total = len(ground_truth)
        if total == 0:
            return 0.0
        # .item() yields a Python int; dividing two integer tensors
        # truncates on older PyTorch, which made the accuracy 0 or 100.
        correct = (predicted == ground_truth).sum().item()
        return 100.0 * correct / total
예제 #5
0
class TetrisQLearn:
    def __init__(self, games_state, savename, dirname='logs', **kwargs):

        self.simulator = games_state

        # Q learn basic params
        self.explore_val = 1  # probability to explore v exploit
        self.explore_decay = 0.999  # explore chance is reduced as Q resolves
        self.gamma = 1  # short-term/long-term trade-off param
        self.num_episodes = 500  # number of episodes of simulation to perform
        self.save_weight_freq = 10  # controls how often (in number of episodes) the weights of Q are saved
        self.memory = []
        self._process_mask = []
        self.processed_memory = []  # memory container

        # fitted Q-Learning params
        self.episode_update = 1  # after how many episodes should we update Q?
        self.batch_size = 10  # length of memory replay (in episodes)

        self.schedule = False
        self.refresh_target = 1

        self.renderpath = None

        # let user define each of the params above
        if "gamma" in kwargs:
            self.gamma = kwargs['gamma']
        if 'explore_val' in kwargs:
            self.explore_val = kwargs['explore_val']
        if 'explore_decay' in kwargs:
            self.explore_decay = kwargs['explore_decay']
        if 'num_episodes' in kwargs:
            self.num_episodes = kwargs['num_episodes']
        if 'episode_update' in kwargs:
            self.episode_update = kwargs['episode_update']
        if 'exit_level' in kwargs:
            self.exit_level = kwargs['exit_level']
        if 'exit_window' in kwargs:
            self.exit_window = kwargs['exit_window']
        if 'save_weight_freq' in kwargs:
            self.save_weight_freq = kwargs['save_weight_freq']
        if 'batch_size' in kwargs:
            self.batch_size = kwargs['batch_size']
        if 'episode_update' in kwargs:
            self.episode_update = kwargs['episode_update']
        if 'schedule' in kwargs:
            self.schedule = kwargs['schedule']
        if 'memory_length' in kwargs:
            self.memory_length = kwargs['memory_length']
        if 'refresh_target' in kwargs:
            self.refresh_target = kwargs['refresh_target']
        if 'minibatch_size' in kwargs:
            self.minibatch_size = kwargs['minibatch_size']
        if 'render_path' in kwargs:
            self.renderpath = kwargs['render_path']
        if 'use_target' in kwargs:
            self.use_target = kwargs['use_target']

        # get simulation-specific variables from simulator
        self.num_actions = self.simulator.output_dimension
        self.training_reward = []
        self.savename = savename

        # create text file for training log
        self.logname = os.path.join(dirname, 'training_logs',
                                    savename + '.txt')
        self.reward_logname = os.path.join(dirname, 'reward_logs',
                                           savename + '.txt')
        if not os.path.exists(
                os.path.join(dirname, 'saved_model_weights', savename)):
            os.mkdir(os.path.join(dirname, 'saved_model_weights', savename))
        if self.renderpath and not os.path.exists(
                os.path.join(self.renderpath, self.savename)):
            os.makedirs(os.path.join(self.renderpath, self.savename),
                        exist_ok=True)
        self.renderpath = os.path.join(self.renderpath, self.savename)

        self.weights_folder = os.path.join(dirname, 'saved_model_weights',
                                           savename)
        self.reward_table = os.path.join(dirname, 'reward_logs_extended',
                                         savename + '.csv')
        self.weights_idx = 0

        self.init_log(self.logname)
        self.init_log(self.reward_logname)
        self.init_log(self.reward_table)

        self.write_header = True

    def render_model(self, epoch):
        """Render a demo game played by the current model and save it as a GIF.

        The GIF is written to ``<self.renderpath>/<savename>-<epoch>.gif`` via
        ``render.render_from_model``. The model is put in eval mode for the
        render and is always switched back to train mode afterwards, even if
        rendering fails; the Matplotlib figure is always closed.
        """
        fig = plt.figure()
        # Figure.gca() no longer accepts keyword arguments in recent
        # Matplotlib releases; add_subplot(..., projection='3d') is the
        # supported way to create a 3D axes.
        ax = fig.add_subplot(111, projection='3d')

        print('rendering...')
        tick = time.time()
        axis_extents = self.simulator.board_extents
        ax.set_xlim3d(0, axis_extents[0])
        ax.set_ylim3d(0, axis_extents[1])
        ax.set_zlim3d(0, axis_extents[2])
        demo_game = GameState(board_shape=axis_extents,
                              rewards=self.simulator.rewards)
        self.model.eval()
        path = os.path.join(self.renderpath,
                            self.savename + '-' + str(epoch) + '.gif')
        try:
            render.render_from_model(self.model,
                                     fig,
                                     ax,
                                     demo_game,
                                     path,
                                     device=self.device)
        finally:
            # Restore training mode and release the figure so repeated
            # renders during training do not accumulate open figures.
            self.model.train()
            plt.close(fig)
        print('rendered %s in %.2fs' % (path, time.time() - tick))

    # Logging stuff
    def init_log(self, logname):
        # delete log if old version exists
        if os.path.exists(logname):
            os.remove(logname)

    def update_log(self, logname, update, epoch=None):
        if type(update) == str:
            logfile = open(logname, "a")
            logfile.write(update)
            logfile.close()
        else:
            mod = self.model.state_dict()
            torch.save(
                mod,
                os.path.join(self.weights_folder,
                             self.savename + str(epoch) + '.pth'))
            self.weights_idx += 1

    def log_reward(self, reward_dict):

        keys = reward_dict.keys()

        if self.write_header:
            with open(self.reward_table, 'a') as output_file:
                dict_writer = csv.DictWriter(output_file, keys)
                dict_writer.writeheader()
            self.write_header = False

        with open(self.reward_table, 'a') as output_file:
            dict_writer = csv.DictWriter(output_file, keys)
            dict_writer.writerow(reward_dict)

    # Q Learning Stuff
    def initialize_Q(self, model_path=None, alpha=None, **kwargs):
        """Build the Q network, optional target network, loss and optimizer.

        Parameters
        ----------
        model_path : str, optional
            Path of a saved state dict to load into the new network.
        alpha : float, optional
            Learning rate for RMSprop; a falsy value falls back to 1e-2.
        **kwargs
            Forwarded to ``DenseNet``.
        """
        lr = 10**(-2)
        if alpha:
            lr = alpha

        # Output size of the network: one Q value per action.
        output_dim = self.num_actions

        self.model = DenseNet(output_dim, **kwargs)
        if model_path is not None:
            print('loading check point %s' % model_path)
            # Load onto the CPU first so a checkpoint saved on a GPU also
            # loads on a CPU-only machine; the model is moved to the target
            # device below.
            self.model.load_state_dict(
                torch.load(model_path, map_location='cpu'))
        self.Q = self.model

        # use_target may not have been set by the constructor; default to
        # sharing one network for prediction and targets.
        self.use_target = getattr(self, 'use_target', False)
        if self.use_target:
            # NOTE(review): relies on DenseNet providing copy(); torch's
            # nn.Module itself does not define one.
            self.target_network = self.Q.copy()
        else:
            self.target_network = self.Q

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        self.target_network.to(self.device)
        self.model.to(self.device)

        self.loss = nn.SmoothL1Loss()
        self.max_lr = lr
        self.optimizer = optim.RMSprop(self.model.parameters(), lr)
        if self.schedule:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=10, verbose=True)
        # Clamp every parameter gradient to [-1, 1] during backward.
        for p in self.model.parameters():
            p.register_hook(lambda grad: torch.clamp(grad, -1, 1))

    def _embed_piece(self, piece, embedding_size=(4, 4, 4)):
        out = np.zeros((piece.shape[0], piece.shape[1], *embedding_size))
        for coord in product(*[range(x) for x in piece.shape]):
            out[coord] = piece[coord]
        return torch.from_numpy(out).to(self.device)

    def memory_replay(self):
        # process transitions using target network
        q_vals = []
        states = []
        pieces = []
        locats = []
        episode_loss = 0

        total_processed = 0

        tick = time.time()
        for i in range(len(self.memory)):
            episode_data = self.memory[i]
            if self.processed_memory[i] is None:
                self.processed_memory[i] = [None] * len(episode_data)

            for j in range(len(episode_data)):
                # process the sample and put it into the processed_memory
                sample = episode_data[j]

                state, piece, locat = sample[0]
                next_state, next_piece, next_locat = sample[1]
                action = sample[2]
                reward = sample[3]
                done = sample[4]

                if self.processed_memory[i][j] is None:
                    q = reward

                    # preprocess q using target network
                    if not done:
                        next_state = torch.tensor(next_state).to(self.device)
                        next_piece = torch.tensor(next_piece).to(self.device)
                        next_locat = torch.tensor(next_locat).to(self.device)

                        qs = self.target_network(next_state, next_piece,
                                                 next_locat)  #should be target
                        q += self.gamma * torch.max(qs)

                    state = torch.tensor(state).to(self.device)
                    piece = torch.tensor(piece).to(self.device)
                    locat = torch.tensor(locat).to(self.device)

                    # q is our experientially validated move score. Anchor it on our prediction vector
                    q_update = self.target_network(
                        state, piece, locat).squeeze(0)  # should be target
                    q_update[action] = q
                    processed = q_update.detach().cpu().numpy()
                    q_vals = q_vals + [processed]
                    self.processed_memory[i][j] = processed
                    total_processed += 1

                    ## WE HAVE NOW PREPROCESSED W TARGET NETWORK

                    # clear up the vram
                    state = state.cpu().squeeze(0).numpy()
                    piece = piece.float().squeeze(0).cpu().numpy()
                    locat = locat.float().cpu().numpy()

                else:
                    q_vals = q_vals + [self.processed_memory[i][j]]

                # its goofy but it will work
                if state.ndim > 4:
                    state = state.squeeze(0)
                assert state.ndim == 4

                if piece.ndim > 4:
                    piece = piece.squeeze(0)
                assert piece.ndim == 4

                if locat.ndim > 2:
                    locat = locat.squeeze(0)
                assert locat.ndim == 2

                states.append(state)
                pieces.append(piece)
                locats.append(locat)

        elapsed_time = time.time() - tick
        print('process time: %.2f' % elapsed_time)
        print('total processed: %d (%.2f/s)' %
              (total_processed, total_processed / elapsed_time))

        # take descent step
        memory = MemoryDset(states, pieces, locats, q_vals)

        if self.minibatch_size > 0:
            ids = random.sample(range(len(memory)),
                                min(self.minibatch_size,
                                    len(memory) - 1))
            memory = torch.utils.data.Subset(memory, ids)
        dataloader = DataLoader(memory,
                                batch_size=self.batch_size,
                                shuffle=True)

        tick = time.time()
        for s, p, l, q in dataloader:
            self.optimizer.zero_grad()

            out = self.Q(s, p, l)
            loss = self.loss(out, q)
            loss.backward()
            self.optimizer.step()
            episode_loss += loss.item()
            s.detach_()
        print('fit time: %.2f' % (time.time() - tick))

        if self.schedule:
            self.scheduler.step(episode_loss)
        return episode_loss / len(self.memory) / len(dataloader)

    def update_target(self):
        if self.use_target:
            self.target_network = self.Q.copy()
            self.target_network.to(self.device)
        self.processed_memory = [None] * len(self.processed_memory)

    def update_memory(self, episode_data):
        # add most recent trial data to memory
        self.memory.append(episode_data)
        self.processed_memory.append(None)

        # clip memory if it gets too long
        num_episodes = len(self.memory)
        if num_episodes >= self.memory_length:
            num_delete = num_episodes - self.memory_length
            self.memory[:num_delete] = []
            self.processed_memory[:num_delete] = []

    def make_torch(self, array):
        tens = torch.from_numpy(array.copy())
        tens = tens.float()
        tens = tens.unsqueeze(0)
        tens = tens.unsqueeze(0)
        return tens.to(self.device)

    # choose next action
    def choose_action(self, state, piece, location):
        # pick action at random
        p = np.random.rand(1)

        action = np.random.randint(len(self.simulator.action_space))

        # pick action based on exploiting
        qs = self.Q(state, piece, location)

        if p > self.explore_val:
            action = torch.argmax(qs)
        return action

    def renormalize_vec(self, tensor, idx):
        tensor = tensor.squeeze(0)
        tensor[idx] = 0
        sum = torch.sum(tensor)
        return tensor / sum

    def run(self):
        """Run the main Q-learning loop for ``self.num_episodes`` episodes.

        Per episode: reset the simulator, decay ``explore_val`` on update
        episodes, play until ``done`` while recording transitions, store the
        episode with ``update_memory``, refresh the target network (every
        ``refresh_target`` episodes when ``use_target`` is set, otherwise every
        episode), fit via ``memory_replay`` every ``episode_update`` episodes,
        then write the progress line, reward logs, periodic weight checkpoints
        and (if ``renderpath`` is set) periodic renders.

        Reads ``self.exit_window`` here and ``self.memory_length`` via
        ``update_memory``; these are only set when passed to ``__init__`` as
        keyword arguments.
        """

        print("num_episodes: %s" % self.num_episodes)

        # start main Q-learning loop
        for n in range(self.num_episodes):
            # pick this episode's starting position
            state = self.simulator.reset()
            total_episode_reward = 0
            done = False

            # decay the exploration probability on update episodes; decay
            # stops once it is at or below 0.01
            if self.explore_val > 0.01 and (n % self.episode_update) == 0:
                old_explore = self.explore_val
                self.explore_val *= self.explore_decay
                # restore the initial learning rate if exploration dropped by
                # more than 0.25 in a single step
                if old_explore - self.explore_val > 0.25:
                    for param in self.optimizer.param_groups:
                        print('resetting to max learning rate: %s' %
                              self.max_lr)
                        param['lr'] = self.max_lr

            # run episode
            step = 0
            episode_data = []
            ep_start_time = time.time()
            ep_rew_dict = None  # running sum of this episode's reward components

            while done is False:

                # choose next action
                board = self.make_torch(state.board)
                piece = self.make_torch(state.current.matrix)
                loc = torch.tensor(state.current.location).unsqueeze(0).to(
                    self.device)
                action = self.choose_action(board, piece, loc)

                # transition to next state, get associated reward
                next_state, reward_dict, done = self.simulator(
                    self.simulator.action_space[action])
                if ep_rew_dict is None:
                    ep_rew_dict = reward_dict
                else:
                    ep_rew_dict = add_reward_dicts(ep_rew_dict, reward_dict)
                next_board = self.make_torch(next_state.board)
                next_piece = self.make_torch(next_state.current.matrix)
                next_locat = torch.tensor(
                    next_state.current.location).unsqueeze(0)

                reward = reward_dict['total']

                # move board back to cpu to clear up vram
                board = board.cpu().numpy()
                next_board = next_board.cpu().numpy()
                piece = piece.cpu().numpy()
                location = loc.cpu().numpy()
                next_piece = next_piece.cpu().numpy()
                next_locat = next_locat.cpu().numpy()

                # store data for transition after episode ends
                episode_data.append([(board, piece, location),
                                     (next_board, next_piece, next_locat),
                                     action, reward, done])

                # update total reward from this episode
                total_episode_reward += reward
                state = copy.deepcopy(next_state)
                step += 1

            # update memory with this episode's data
            self.update_memory(episode_data)

            # only scales the loss value that is reported/logged below
            LOSS_SCALING = 100

            # update the target network
            if self.use_target:
                if np.mod(n, self.refresh_target) == 0:
                    self.update_target()
            else:
                self.update_target()

            # train model
            episode_loss = 0
            if np.mod(n, self.episode_update) == 0:
                episode_loss = self.memory_replay() * LOSS_SCALING

            # average reward over the last exit_window completed episodes
            # (this episode's own reward until that much history exists);
            # exit_level is not used here
            exit_ave = total_episode_reward
            if n >= self.exit_window:
                exit_ave = np.sum(
                    np.array(self.training_reward[-self.exit_window:])
                ) / self.exit_window

            # print out updates
            # NOTE: `update` is reused below for several unrelated values (this
            # progress line, a state dict, the reward line, the final message);
            # each use overwrites the previous one, so statement order matters.
            update = 'episode ' + str(n + 1) + ' of ' + str(
                self.num_episodes
            ) + ' complete, ' + 'loss x%s = ' % LOSS_SCALING + str(
                np.round(episode_loss, 3)) + ' explore val = ' + str(
                    np.round(self.explore_val, 3)
                ) + ', episode reward = ' + str(
                    np.round(
                        total_episode_reward, 1)) + ', ave reward = ' + str(
                            np.round(exit_ave, 3)) + ', episode_time = ' + str(
                                np.round(time.time() - ep_start_time, 3))

            self.update_log(self.logname, update + '\n')

            # highlight episodes on which the network was fitted
            if np.mod(n, self.episode_update) == 0:
                print(colored(update, 'red'))
            else:
                print(update)

            # save latest weights from this episode
            if np.mod(n, self.save_weight_freq) == 0:
                update = self.model.state_dict()
                self.update_log(self.weights_folder, update, epoch=n)

            # render a demo on the last episode of each save_weight_freq block
            if self.renderpath and n % self.save_weight_freq == self.save_weight_freq - 1:
                self.render_model(n + 1)

            update = str(total_episode_reward) + '\n'
            self.update_log(self.reward_logname, update)
            self.log_reward(ep_rew_dict)

            # record this episode's reward for the moving average above
            self.training_reward.append(total_episode_reward)

        update = 'q-learning algorithm complete'
        self.update_log(self.logname, update + '\n')
        print(update)