Example #1
# NOTE: this snippet assumes the surrounding project provides VAE, GoalVAE,
# DataLoader, get_vae_defaults, to_var, save_state, load_seed, load_net_state,
# load_opt_state and load_args; only the imports below are certain.
import os
import time

import numpy as np
import torch
from tensorboardX import SummaryWriter  # or torch.utils.tensorboard


def main(args):

    # cfg_file = os.path.join(args.example_config_path, args.primitive) + ".yaml"
    cfg = get_vae_defaults()
    # cfg.merge_from_file(cfg_file)
    cfg.freeze()

    batch_size = args.batch_size
    dataset_size = args.total_data_size

    if args.experiment_name is None:
        experiment_name = args.model_name
    else:
        experiment_name = args.experiment_name

    if not os.path.exists(os.path.join(args.log_dir, experiment_name)):
        os.makedirs(os.path.join(args.log_dir, experiment_name))

    # input() instead of Python 2's raw_input; text mode so writing str works
    description_txt = input('Please enter experiment notes: \n')
    if isinstance(description_txt, str):
        with open(
                os.path.join(args.log_dir, experiment_name,
                             experiment_name + '_description.txt'), 'w') as f:
            f.write(description_txt)

    writer = SummaryWriter(os.path.join(args.log_dir, experiment_name))

    # torch_seed = np.random.randint(low=0, high=1000)
    # np_seed = np.random.randint(low=0, high=1000)
    torch_seed = 0
    np_seed = 0

    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)

    trained_model_path = os.path.join(args.model_path, args.model_name)
    if not os.path.exists(trained_model_path):
        os.makedirs(trained_model_path)

    if args.task == 'contact':
        if args.start_rep == 'keypoints':
            start_dim = 24
        elif args.start_rep == 'pose':
            start_dim = 7

        if args.goal_rep == 'keypoints':
            goal_dim = 24
        elif args.goal_rep == 'pose':
            goal_dim = 7

        if args.skill_type == 'pull':
            # + 7 because single arm palm pose
            input_dim = start_dim + goal_dim + 7
        else:
            # + 14 because both arms palm pose
            input_dim = start_dim + goal_dim + 14
        output_dim = 7
        decoder_input_dim = start_dim + goal_dim

        vae = VAE(input_dim,
                  output_dim,
                  args.latent_dimension,
                  decoder_input_dim,
                  hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP,
                  lr=args.learning_rate)
    elif args.task == 'goal':
        if args.start_rep == 'keypoints':
            start_dim = 24
        elif args.start_rep == 'pose':
            start_dim = 7

        if args.goal_rep == 'keypoints':
            goal_dim = 24
        elif args.goal_rep == 'pose':
            goal_dim = 7

        input_dim = start_dim + goal_dim
        output_dim = goal_dim
        decoder_input_dim = start_dim
        vae = GoalVAE(input_dim,
                      output_dim,
                      args.latent_dimension,
                      decoder_input_dim,
                      hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP,
                      lr=args.learning_rate)
    elif args.task == 'transformation':
        input_dim = args.input_dimension
        output_dim = args.output_dimension
        decoder_input_dim = args.input_dimension - args.output_dimension
        vae = GoalVAE(input_dim,
                      output_dim,
                      args.latent_dimension,
                      decoder_input_dim,
                      hidden_layers=cfg.ENCODER_HIDDEN_LAYERS_MLP,
                      lr=args.learning_rate)
    else:
        raise ValueError('training task not recognized')

    if torch.cuda.is_available():
        vae.encoder.cuda()
        vae.decoder.cuda()

    if args.start_epoch > 0:
        start_epoch = args.start_epoch
        num_epochs = args.num_epochs
        fname = os.path.join(
            trained_model_path,
            args.model_name + '_epoch_%d.pt' % args.start_epoch)
        torch_seed, np_seed = load_seed(fname)
        load_net_state(vae, fname)
        load_opt_state(vae, fname)
        args = load_args(fname)
        args.start_epoch = start_epoch
        args.num_epochs = num_epochs
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)

    data_dir = args.data_dir
    data_loader = DataLoader(data_dir=data_dir)

    data_loader.create_random_ordering(size=dataset_size)

    dataset = data_loader.load_dataset(start_rep=args.start_rep,
                                       goal_rep=args.goal_rep,
                                       task=args.task)

    total_loss = []
    start_time = time.time()
    print('Saving models to: ' + trained_model_path)
    kl_weight = 1.0
    print('Starting on epoch: ' + str(args.start_epoch))

    for epoch in range(args.start_epoch, args.start_epoch + args.num_epochs):
        print('Epoch: ' + str(epoch))
        epoch_total_loss = 0
        epoch_kl_loss = 0
        epoch_pos_loss = 0
        epoch_ori_loss = 0
        epoch_recon_loss = 0
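        # KL annealing: kl_weight decays geometrically by args.kl_anneal_rate
        # each epoch, so the KL coefficient 1 - kl_weight ramps from 0 toward
        # 1 and lets the reconstruction term dominate early in training.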
        kl_coeff = 1 - kl_weight
        kl_weight = args.kl_anneal_rate * kl_weight
        print('KL coeff: ' + str(kl_coeff))
        for i in range(0, dataset_size, batch_size):
            vae.optimizer.zero_grad()

            input_batch, decoder_input_batch, target_batch = \
                data_loader.sample_batch(dataset, i, batch_size)
            input_batch = to_var(torch.from_numpy(input_batch))
            decoder_input_batch = to_var(torch.from_numpy(decoder_input_batch))

            z, recon_mu, z_mu, z_logvar = vae.forward(input_batch,
                                                      decoder_input_batch)
            kl_loss = vae.kl_loss(z_mu, z_logvar)

            if args.task == 'contact':
                output_r, output_l = recon_mu
                if args.skill_type == 'grasp':
                    target_batch_right = to_var(
                        torch.from_numpy(target_batch[:, 0]))
                    target_batch_left = to_var(
                        torch.from_numpy(target_batch[:, 1]))

                    pos_loss_right = vae.mse(output_r[:, :3],
                                             target_batch_right[:, :3])
                    ori_loss_right = vae.rotation_loss(
                        output_r[:, 3:], target_batch_right[:, 3:])

                    pos_loss_left = vae.mse(output_l[:, :3],
                                            target_batch_left[:, :3])
                    ori_loss_left = vae.rotation_loss(output_l[:, 3:],
                                                      target_batch_left[:, 3:])

                    pos_loss = pos_loss_left + pos_loss_right
                    ori_loss = ori_loss_left + ori_loss_right
                elif args.skill_type == 'pull':
                    target_batch = to_var(
                        torch.from_numpy(target_batch.squeeze()))

                    #TODO add flags for when we're training both arms
                    # output = recon_mu[0]  # right arm is index [0]
                    # output = recon_mu[1]  # left arm is index [1]

                    pos_loss_right = vae.mse(output_r[:, :3],
                                             target_batch[:, :3])
                    ori_loss_right = vae.rotation_loss(output_r[:, 3:],
                                                       target_batch[:, 3:])

                    pos_loss = pos_loss_right
                    ori_loss = ori_loss_right

            elif args.task == 'goal':
                target_batch = to_var(torch.from_numpy(target_batch.squeeze()))
                output = recon_mu
                if args.goal_rep == 'pose':
                    pos_loss = vae.mse(output[:, :3], target_batch[:, :3])
                    ori_loss = vae.rotation_loss(output[:, 3:],
                                                 target_batch[:, 3:])
                elif args.goal_rep == 'keypoints':
                    pos_loss = vae.mse(output, target_batch)
                    # zeros_like keeps device/dtype consistent with pos_loss;
                    # keypoints carry no orientation term
                    ori_loss = torch.zeros_like(pos_loss)

            elif args.task == 'transformation':
                target_batch = to_var(torch.from_numpy(target_batch.squeeze()))
                output = recon_mu
                pos_loss = vae.mse(output[:, :3], target_batch[:, :3])
                ori_loss = vae.rotation_loss(output[:, 3:], target_batch[:,
                                                                         3:])

            recon_loss = pos_loss + ori_loss

            loss = kl_coeff * kl_loss + recon_loss
            loss.backward()
            vae.optimizer.step()

            epoch_total_loss = epoch_total_loss + loss.data
            epoch_kl_loss = epoch_kl_loss + kl_loss.data
            epoch_pos_loss = epoch_pos_loss + pos_loss.data
            epoch_ori_loss = epoch_ori_loss + ori_loss.data
            epoch_recon_loss = epoch_recon_loss + recon_loss.data

            # use a monotonically increasing global step so TensorBoard curves
            # do not overlap across epochs
            global_step = epoch * (dataset_size // batch_size) + i // batch_size
            writer.add_scalar('loss/train/ori_loss', ori_loss.data, global_step)
            writer.add_scalar('loss/train/pos_loss', pos_loss.data, global_step)
            writer.add_scalar('loss/train/kl_loss', kl_loss.data, global_step)

            if (i // batch_size) % args.batch_freq == 0:
                if args.skill_type == 'pull' or args.task == 'goal' or args.task == 'transformation':
                    print(
                        'Train Epoch: %d [%d/%d (%f)]\tLoss: %f\tKL: %f\tPos: %f\t Ori: %f'
                        % (epoch, i, dataset_size, 100.0 * i / dataset_size,
                           loss.item(), kl_loss.item(), pos_loss.item(),
                           ori_loss.item()))
                elif args.skill_type == 'grasp' and args.task == 'contact':
                    print(
                        'Train Epoch: %d [%d/%d (%f)]\tLoss: %f\tKL: %f\tR Pos: %f\t R Ori: %f\tL Pos: %f\tL Ori: %f'
                        % (epoch, i, dataset_size, 100.0 * i / dataset_size,
                           loss.item(), kl_loss.item(),
                           pos_loss_right.item(), ori_loss_right.item(),
                           pos_loss_left.item(), ori_loss_left.item()))
        print(' --average loss: ')
        print(epoch_total_loss / (dataset_size / batch_size))
        loss_dict = {
            'epoch_total': epoch_total_loss / (dataset_size / batch_size),
            'epoch_kl': epoch_kl_loss / (dataset_size / batch_size),
            'epoch_pos': epoch_pos_loss / (dataset_size / batch_size),
            'epoch_ori': epoch_ori_loss / (dataset_size / batch_size),
            'epoch_recon': epoch_recon_loss / (dataset_size / batch_size)
        }
        total_loss.append(loss_dict)

        if epoch % args.save_freq == 0:
            print('\n--Saving model\n')
            print('time: ' + str(time.time() - start_time))

            save_state(net=vae,
                       torch_seed=torch_seed,
                       np_seed=np_seed,
                       args=args,
                       fname=os.path.join(
                           trained_model_path,
                           args.model_name + '_epoch_' + str(epoch) + '.pt'))

            np.savez(os.path.join(
                trained_model_path,
                args.model_name + '_epoch_' + str(epoch) + '_loss.npz'),
                     loss=np.asarray(total_loss))

    print('Done!')
    save_state(net=vae,
               torch_seed=torch_seed,
               np_seed=np_seed,
               args=args,
               fname=os.path.join(
                   trained_model_path,
                   args.model_name + '_epoch_' + str(epoch) + '.pt'))
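
# A minimal invocation sketch for the script above. The flag names mirror the
# attributes read from `args` in main(); the defaults here are illustrative
# assumptions, not values from the original project.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--task', default='goal',
                        choices=['contact', 'goal', 'transformation'])
    parser.add_argument('--skill_type', default='pull')
    parser.add_argument('--start_rep', default='pose',
                        choices=['keypoints', 'pose'])
    parser.add_argument('--goal_rep', default='pose',
                        choices=['keypoints', 'pose'])
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--total_data_size', type=int, default=10000)
    parser.add_argument('--latent_dimension', type=int, default=3)
    parser.add_argument('--input_dimension', type=int, default=14)
    parser.add_argument('--output_dimension', type=int, default=7)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--kl_anneal_rate', type=float, default=0.9999)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=10)
    parser.add_argument('--experiment_name', default=None)
    parser.add_argument('--model_name', default='vae_model')
    parser.add_argument('--model_path', default='models')
    parser.add_argument('--log_dir', default='logs')
    parser.add_argument('--data_dir', default='data')
    main(parser.parse_args())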
Example #2
# NOTE: this snippet assumes the surrounding project provides BERT, VAE and
# my_loss; only the imports below are certain.
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader


class ReconstructionBERTTrainer:
    """
    BERTTrainer make the pretrained BERT model with two LM training method.

        1. Masked Language Model : 3.3.1 Task #1: Masked LM
        2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction

    please check the details on README.md with simple example.

    """
    def __init__(self,
                 bert: BERT,
                 vocab_size: int,
                 markdown_vocab_size,
                 markdown_emb_size,
                 train_dataloader: DataLoader,
                 test_dataloader: DataLoader,
                 lr: float = 1e-4,
                 betas=(0.9, 0.999),
                 weight_decay: float = 0.01,
                 warmup_steps=10000,
                 with_cuda: bool = True,
                 cuda_devices=None,
                 log_freq: int = 10,
                 pad_index=0,
                 loss_lambda=1,
                 model_path=None,
                 n_topics=50,
                 weak_supervise=False,
                 context=False,
                 markdown=False,
                 hinge_loss_start_point=20,
                 entropy_start_point=30):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param with_cuda: whether to train with CUDA
        :param log_freq: logging frequency of the batch iteration
        :param context: use information from neighbor cells
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        self.loss_lambda = loss_lambda
        self.n_topics = n_topics
        self.weak_supervise = weak_supervise
        self.context = context
        self.markdown = markdown
        self.hinge_loss_start_point = hinge_loss_start_point
        self.entropy_start_point = entropy_start_point

        cuda_condition = torch.cuda.is_available() and with_cuda

        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        # This BERT model will be saved every epoch
        self.bert = bert
        # Initialize the BERT Language Model, with BERT model
        self.model = VAE(bert,
                         vocab_size,
                         markdown_vocab_size,
                         markdown_emb_size,
                         n_topics=n_topics,
                         weak_supervise=weak_supervise,
                         context=context,
                         markdown=markdown).to(self.device)
        if model_path:
            self.model.load_state_dict(
                torch.load(model_path)["model_state_dict"])
            # parse N from the ".epN" suffix written by save()
            last_epoch = int(model_path.split('.')[-1][2:])
            self.last_epoch = last_epoch

        else:
            self.last_epoch = None
            # raise NotImplementedError
            # pdb.set_trace()
            # Distributed GPU training if CUDA can detect more than 1 GPU
        if with_cuda and torch.cuda.device_count() > 1:
            # pdb.set_trace()
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
        # pdb.set_trace()
        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.pad_index = pad_index
        # Setting the Adam optimizer with hyper-param
        # self.optim = Adam(self.model.parameters(), lr=lr,
        #                   betas=betas, weight_decay=weight_decay)
        # self.optim_schedule = ScheduledOptim(
        #     self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
        self.optim = SGD(self.model.parameters(), lr=lr, momentum=0.9)
        if self.last_epoch and self.last_epoch >= self.hinge_loss_start_point:
            self.optim = SGD(self.model.parameters(), lr=0.00002, momentum=0.9)

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        # self.criterion = nn.NLLLoss(ignore_index=self.pad_index)
        self.best_loss = None
        self.updated = False
        self.log_freq = log_freq
        self.cross_entropy = nn.CrossEntropyLoss(ignore_index=0)

        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.model.train()
        # self.optim.zero_grad()
        return self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.model.eval()
        with torch.no_grad():
            loss = self.iteration(epoch, self.test_data, train=False)
        return loss

    def api(self, data_loader=None):
        self.model.eval()
        # str_code = "train" if train else "test"
        if not data_loader:
            data_loader = self.test_data

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            # desc="EP_%s:%d" % (str_code, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        # for (i, data), (ni, ndata) in data_iter, neg_data_iter:
        phases = []
        stages = []
        stage_vecs = []
        with torch.no_grad():
            for i, item in data_iter:
                data = item[0]
                ndata = item[1]
                # 0. batch data will be sent to the device (GPU or CPU)
                data = {
                    key: value.to(self.device)
                    for key, value in data.items()
                }
                ndata = {
                    key: value.to(self.device)
                    for key, value in ndata.items()
                }
                # 1. forward the next_sentence_prediction and masked_lm model
                reconstructed_vec, graph_vec, origin_neg, topic_dist, stage_vec = self.model.forward(
                    data["bert_input"],
                    ndata["bert_input"],
                    data["segment_label"],
                    ndata["segment_label"],
                    data["adj_mat"],
                    ndata["adj_mat"],
                    train=False,
                    context_topic_dist=data["context_topic_vec"],
                    markdown_label=data["markdown_label"],
                    markdown_len=data["markdown_len"],
                    neg_markdown_label=ndata["markdown_label"],
                    neg_markdown_len=ndata["markdown_len"])
                data_loader.dataset.update_topic_dist(topic_dist, data["id"])

                phases += torch.max(topic_dist, 1)[-1].tolist()
                # print(torch.max(stage_vec, 1)[-1].tolist())
                stages += torch.max(stage_vec, 1)[-1].tolist()
                stage_vecs += stage_vec.tolist()
                # pdb.set_trace()
        return phases, stages, stage_vecs

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
        and also auto save the model every peoch

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0

        # def calculate_iter(data):

        for i, item in data_iter:
            data = item[0]
            ndata = item[1]

            data = {key: value.to(self.device) for key, value in data.items()}
            ndata = {
                key: value.to(self.device)
                for key, value in ndata.items()
            }

            # 1. forward the next_sentence_prediction and masked_lm model

            reconstructed_vec, graph_vec, origin_neg, topic_dist, stage_vec = self.model.forward(
                data["bert_input"],
                ndata["bert_input"],
                data["segment_label"],
                ndata["segment_label"],
                data["adj_mat"],
                ndata["adj_mat"],
                train=train,
                context_topic_dist=data["context_topic_vec"],
                markdown_label=data["markdown_label"],
                markdown_len=data["markdown_len"],
                neg_markdown_label=ndata["markdown_label"],
                neg_markdown_len=ndata["markdown_len"])
            # pdb.set_trace()
            if self.context:
                data_loader.dataset.update_topic_dist(topic_dist, data["id"])
            bs, hid_size = reconstructed_vec.shape
            nbs, hid_size = origin_neg.shape
            duplicate = nbs // bs  # negatives per positive (currently unused)

            hinge_loss = my_loss(reconstructed_vec, graph_vec, origin_neg)
            # Orthogonality regularizer on the topic reconstruction matrix:
            # penalize ||W^T W - I|| so topic vectors stay near-orthogonal.
            weight_loss = torch.norm(
                torch.mm(self.model.reconstruction.weight.T,
                         self.model.reconstruction.weight) -
                torch.eye(self.n_topics).cuda())

            c_entropy = self.cross_entropy(stage_vec, data['stage'])

            # Entropy of the predicted stage distribution, used as an extra
            # regularizer in the final training phase.
            entropy = -1 * (F.softmax(stage_vec, dim=1) *
                            F.log_softmax(stage_vec, dim=1)).sum()

            # Loss schedule: cross-entropy only, then add the hinge and
            # orthogonality terms, and finally the entropy term as well.
            if epoch < self.hinge_loss_start_point:
                loss = c_entropy
            elif epoch < self.entropy_start_point:
                loss = c_entropy + self.loss_lambda * weight_loss + hinge_loss
            else:
                loss = c_entropy + entropy + self.loss_lambda * weight_loss + hinge_loss

            if epoch == self.hinge_loss_start_point:
                self.optim = SGD(self.model.parameters(),
                                 lr=0.00002,
                                 momentum=0.9)

            # 3. backward and optimization only in train

            if train:
                self.optim.zero_grad()
                loss.backward()
                # self.optim.step_and_update_lr()
                self.optim.step()

            avg_loss += loss.item()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                # "avg_acc": total_correct / total_element * 100,
                "loss": loss.item(),
                "cross_entropy": c_entropy.item(),
                "entropy": entropy.item(),
                "hinge_loss": hinge_loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code),
              avg_loss / len(data_iter))
        return avg_loss / len(data_iter)

    def save(self, epoch, file_path="output/bert_trained.model"):
        """
        Saving the current BERT model on file_path

        :param epoch: current epoch number
        :param file_path: model output path; the file is saved as file_path + ".ep%d" % epoch
        :return: final_output_path
        """
        output_path = file_path + ".ep%d" % epoch
        # if self.updated:
        #     return output_path
        # torch.save(self.bert.cpu(), output_path)
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict()
                # 'optimizer_state_dict': optimizer.state_dict(),
                # 'loss': loss,
                # ...
            },
            output_path)
        # self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        # self.updated = True
        return output_path
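
# A minimal training-loop sketch for the trainer above. Constructing the BERT
# backbone and the (data, negative-data) DataLoaders is project-specific and
# only assumed here; `trainer` is a ReconstructionBERTTrainer instance.
def run_training(trainer, n_epochs, ckpt_path="output/bert_trained.model"):
    # one train pass, one evaluation pass and a checkpoint per epoch
    for epoch in range(n_epochs):
        trainer.train(epoch)
        trainer.test(epoch)
        trainer.save(epoch, file_path=ckpt_path)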
Example #3
# NOTE: this snippet assumes the surrounding project provides the `model`
# module it imports from, plus draw_pose_figure, a configured `log` logger and
# a module-level `model_dir`; only the imports below are certain.
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn_extra.cluster import KMedoids


class AudioToBodyDynamics(object):
    """
    Defines a wrapper class for training and evaluating a model.
    Inputs:
           args    (argparse object):      model settings
           generator (tuple of DataLoaders): a tuple of at least one DataLoader
    """
    def __init__(self, args, generator, freestyle=False):
        # TODO
        super(AudioToBodyDynamics, self).__init__()
        self.device = args.device
        self.log_frequency = args.log_frequency

        self.is_freestyle_mode = freestyle

        self.generator = generator
        self.model_name = args.model_name
        self.ident = args.ident

        input_dim, output_dim = generator[0].dataset.getDimsPerBatch()

        model_options = {
            'seq_len': args.seq_len,
            'device': args.device,
            'dropout': args.dp,
            'batch_size': args.batch_size,
            'hidden_dim': args.hidden_size,
            'input_dim': input_dim,
            'output_dim': output_dim,
            'trainable_init': args.trainable_init
        }

        if args.model_name == "AudioToJointsThree":
            from model import AudioToJointsThree
            self.model = AudioToJointsThree(model_options).cuda(args.device)
        elif args.model_name == 'AudioToJointsNonlinear':
            from model import AudioToJointsNonlinear
            self.model = AudioToJointsNonlinear(model_options).cuda(
                args.device)
        elif args.model_name == "AudioToJoints":
            from model import AudioToJoints
            self.model = AudioToJoints(model_options).cuda(args.device)
        elif args.model_name == 'JointsToJoints':
            from model import JointsToJoints
            self.model = JointsToJoints(model_options).cuda(
                args.device).double()
        elif args.model_name == 'LSTMToDense':
            from model import LSTMToDense
            self.model = LSTMToDense(model_options).cuda(args.device).double()
        elif args.model_name == 'AudioToJointsSeq2Seq':
            from model import AudioToJointsSeq2Seq
            self.model = AudioToJointsSeq2Seq(model_options).cuda(
                args.device).double()
        elif args.model_name == 'MDNRNN':
            from model import MDNRNN
            self.model = MDNRNN(model_options).cuda(args.device).double()
        elif args.model_name == 'VAE':
            from model import VAE
            self.model = VAE(model_options).cuda(args.device).double()
        else:
            raise ValueError('model name not recognized: %s' % args.model_name)

        # construct the optimizer
        self.optim = optim.Adam(self.model.parameters(), lr=args.lr)

        # Load checkpoint model; model_dir is assumed to be a module-level path
        if self.is_freestyle_mode:
            path = f"{model_dir}{args.model_name}_{str(args.ident)}.pth"
            print(path)
            self.loadModelCheckpoint(path)

    # general loss function
    def buildLoss(self, predictions, targets):
        square_diff = (predictions - targets)**2
        out = torch.sum(square_diff, -1, keepdim=True)
        return torch.mean(out)

    def mdn_loss(self, y, pi, mu, sigma):
        """
        Mixture density network loss: the negative log-likelihood of the
        target under a mixture of Gaussians,
        -log(sum_k pi_k * N(y | mu_k, sigma_k)), averaged over the batch.
        """
        m = torch.distributions.Normal(loc=mu, scale=sigma)
        loss = torch.exp(m.log_prob(y))
        loss = torch.sum(loss * pi, dim=2)
        loss = -torch.log(loss)
        return torch.mean(loss)

    # Loss function from https://github.com/pytorch/examples/blob/master/vae/main.py;
    # see Appendix B of Kingma & Welling, "Auto-Encoding Variational Bayes"
    # (https://arxiv.org/abs/1312.6114)
    def vae_loss(self, targets, recon_targets, mu, logvar):
        BCE = nn.functional.binary_cross_entropy(recon_targets,
                                                 targets,
                                                 reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

    def saveModel(self, state_info, path):
        torch.save(state_info, path)

    def loadModelCheckpoint(self, path):
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optim.load_state_dict(checkpoint['optim_state_dict'])

    def runNetwork(self, inputs, targets):
        """
        Train on one given mfcc pose pair
        Args:
             inputs (array): [batch, seq_len, mfcc_features * 3]
             targets (array): [batch, seq_len, 19 * 2 poses]
        Returns:
             predictions, truth, loss
        """
        def to_numpy(x):
            # move from the GPU device to the CPU and convert to numpy
            return x.cpu().data.numpy()

        inputs = Variable(torch.DoubleTensor(inputs.double()).to(self.device))

        # move targets to the device as double-precision tensors
        targets = Variable(torch.DoubleTensor(targets).to(self.device))

        if self.model_name == 'AudioToJointsSeq2Seq':
            predictions = self.model.forward(inputs, targets)
        elif self.model_name == 'VAE':
            predictions, mu, logvar = self.model.forward(inputs)
        else:
            predictions = self.model.forward(inputs)

        criterion = nn.L1Loss()
        if self.model_name == 'AudioToJointsSeq2Seq':
            loss = criterion(predictions.to(self.device),
                             targets.to(self.device).float())
        elif self.model_name == 'MDNRNN':
            # predictions = (pi, mu, sigma), (h, c)
            loss = self.mdn_loss(targets, predictions[0][0], predictions[0][1],
                                 predictions[0][2])
        elif self.model_name == 'VAE':
            loss = self.vae_loss(targets, predictions, mu, logvar)
        else:
            loss = criterion(predictions, targets)
        return (to_numpy(predictions), to_numpy(targets)), loss

    def runEpoch(self):
        # given one epoch
        train_losses = []
        val_losses = []
        predictions, targets = [], []

        if not self.is_freestyle_mode:  # train
            # for each data point
            for mfccs, poses in self.generator[0]:
                self.model.train()  # pass train flag to model

                pred_targs, train_loss = self.runNetwork(mfccs, poses)
                self.optim.zero_grad()
                train_loss.backward()
                self.optim.step()
                train_loss = train_loss.data.tolist()
                train_losses.append(train_loss)

            # validation loss
            for mfccs, poses in self.generator[1]:
                self.model.eval()
                pred_targs, val_loss = self.runNetwork(mfccs, poses)

                val_loss = val_loss.data.tolist()
                val_losses.append(val_loss)
                pred = pred_targs[0].reshape(
                    int(pred_targs[0].shape[0] * pred_targs[0].shape[1]), 19,
                    2)
                predictions.append(pred)
                targets.append(pred_targs[1])

        # test or predict / play w/ model
        if self.is_freestyle_mode:
            for mfccs, poses in self.generator[0]:
                self.model.eval()
                # mfccs = mfccs.float()
                pred_targs, val_loss = self.runNetwork(mfccs, poses)
                val_loss = val_loss.data.tolist()
                val_losses.append(val_loss)
                pred = pred_targs[0].reshape(
                    int(pred_targs[0].shape[0] * pred_targs[0].shape[1]), 19,
                    2)
                predictions.append(pred)
                targets.append(pred_targs[1])

        return train_losses, val_losses, predictions, targets

    def trainModel(self, max_epochs, logfldr, model_dir):
        # TODO
        log.debug("Training model")
        epoch_losses = []
        batch_losses = []
        val_losses = []
        i, best_loss, iters_without_improvement = 0, float('inf'), 0
        best_train_loss, best_val_loss = float('inf'), float('inf')

        if logfldr:
            if logfldr[-1] != '/':
                logfldr += '/'
        filename = f'{logfldr}epoch_of_model_{str(self.ident)}.txt'
        state_info = {
            'epoch': i,
            'epoch_losses': epoch_losses,
            'batch_losses': batch_losses,
            'validation_losses': val_losses,
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optim.state_dict(),
        }

        for i in range(max_epochs):
            if i % 10 == 0:  # log and checkpoint every 10 epochs
                if i == 0:
                    with open(filename, 'w') as f:
                        f.write(f"Epoch: {i} started\n")
                else:
                    with open(filename, 'a+') as f:
                        f.write(f"Epoch: {i} started\n")
                # save the model
                if model_dir:
                    if model_dir[-1] != '/':
                        model_dir += '/'
                path = f"{model_dir}{self.model_name}_{str(self.ident)}.pth"
                self.saveModel(state_info, path)

            # train_info, val_info, predictions, targets
            iter_train, iter_val, predictions, targets = self.runEpoch()

            iter_mean = np.mean(iter_train)
            iter_val_mean = np.mean(iter_val)
            # iter_val_mean = np.mean(iter_val[0]), np.mean(iter_val[1])

            epoch_losses.append(iter_mean)
            batch_losses.extend(iter_train)
            val_losses.append(iter_val_mean)

            log.info("Epoch {} / {}".format(i, max_epochs))
            log.info(f"Training Loss : {iter_mean}")
            log.info(f"Validation Loss : {iter_val_mean}")

            best_train_loss = iter_mean if iter_mean < best_train_loss else best_train_loss
            best_val_loss = iter_val_mean if iter_val_mean < best_val_loss else best_val_loss

        # Visualize VAE latent space
        if self.model_name == 'VAE':
            self.vae_plot()

        self.plotResults(logfldr, epoch_losses, batch_losses, val_losses)
        path = f"{model_dir}{self.model_name}_{str(self.ident)}.pth"
        self.saveModel(state_info, path)
        return best_train_loss, best_val_loss

    # plot random subset of poses in VAE latent space
    def vae_plot(self):
        z_list = torch.Tensor(1, 2)
        poses = []
        for input, output in self.generator[0]:  # iterate batches from the first loader
            for inp in input:
                poses.append(inp)
            mu, logvar = self.model.encode(input)
            z = self.model.reparameterize(mu, logvar)
            z2 = z[:, -1, :]
            z_list = torch.cat((z_list.double(), z2.double()), 0)

        indices = np.random.randint(low=1, high=z_list.shape[0], size=1000)
        coords = np.array([z_list[ind, :].detach().numpy() for ind in indices])

        # # k-means clustering for coloring
        # kmeans = KMeans(n_clusters=5).fit(coords)
        # y_kmeans = kmeans.predict(coords)
        # plt.scatter(coords[:,0], coords[:,1], c=y_kmeans, cmap='viridis')
        # plt.show()
        #
        # # draw each mean pose
        # centers = kmeans.cluster_centers_
        # recons = [self.model.decode(torch.from_numpy(center)).detach().numpy().reshape(19,2) for center in centers]

        # k-medoids clustering for coloring
        kmedoids = KMedoids(n_clusters=5).fit(coords)
        y_kmedoids = kmedoids.predict(coords)
        plt.scatter(coords[:, 0], coords[:, 1], c=y_kmedoids, cmap='viridis')
        plt.show()

        recons = []
        for center in kmedoids.cluster_centers_:
            c = np.array(center)
            for i in range(len(coords)):
                if np.array_equal(c, coords[i]):
                    recons.append(poses[indices[i] -
                                        1].detach().numpy().reshape(19, 2))

        self.draw_poses(np.array(recons))

    # Takes in np array of poses that are each 19x2 arrays
    def draw_poses(self, poses):
        count = 0
        shift_by = np.array([750, 800]) - poses[0][8]
        poses += shift_by
        for pose in poses:
            person_id = str(0) + ", " + str([0])
            canvas = draw_pose_figure(person_id, pose)
            file_name = "images/" + f"{count:05}.jpg"
            cv2.imwrite(file_name, canvas)
            count += 1

    def plotResults(self, logfldr, epoch_losses, batch_losses, val_losses):
        losses = [epoch_losses, batch_losses, val_losses]
        names = [["Epoch loss"], ["Batch loss"], ["Val loss"]]
        _, ax = plt.subplots(nrows=len(losses), ncols=1)
        for index, pair in enumerate(zip(losses, names)):
            data = [pair[0][j] for j in range(len(pair[0]))]
            ax[index].plot(data, label=pair[1])
            ax[index].legend()
        if logfldr:
            if logfldr[-1] != '/':
                logfldr += '/'
        save_filename = os.path.join(
            logfldr, f"{self.model_name}_{str(self.ident)}_results.png")
        plt.savefig(save_filename)
        plt.close()
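
# A minimal driver sketch for the wrapper above. Building the argparse
# namespace and the DataLoader pair is project-specific; max_epochs, logfldr
# and model_dir below are illustrative assumptions.
def run(args, train_loader, val_loader, max_epochs=100,
        logfldr="logs/", model_dir="models/"):
    # train, validate and checkpoint, then report the best losses
    dynamics = AudioToBodyDynamics(args, (train_loader, val_loader))
    best_train, best_val = dynamics.trainModel(max_epochs, logfldr, model_dir)
    print(f"best train loss: {best_train}, best val loss: {best_val}")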